diff --git a/Cargo.lock b/Cargo.lock index 171b37d0f27e..8dadc79bd3d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -530,8 +530,9 @@ dependencies = [ "napi", "napi-build", "napi-derive", + "react_compiler", "serde", - "swc_core", + "serde_json", "swc_ecma_react_compiler", "swc_malloc", "tracing", @@ -773,7 +774,7 @@ dependencies = [ "cap-primitives", "cap-std", "io-lifetimes 2.0.4", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -790,7 +791,7 @@ dependencies = [ "maybe-owned", "rustix 1.1.2", "rustix-linux-procfs", - "windows-sys 0.52.0", + "windows-sys 0.59.0", "winx", ] @@ -1966,6 +1967,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] [[package]] @@ -2129,7 +2131,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -2152,7 +2154,7 @@ checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", "rustix 1.1.2", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2247,7 +2249,7 @@ checksum = "94e7099f6313ecacbe1256e8ff9d617b75d1bcb16a6fddef94866d225a01a14a" dependencies = [ "io-lifetimes 2.0.4", "rustix 1.1.2", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2613,6 +2615,15 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + [[package]] name = "hstr" version = "3.0.4" @@ -3032,7 +3043,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"2285ddfe3054097ef4b2fe909ef8c3bcd1ea52a8f0d274416caebeef39f04a65" dependencies = [ "io-lifetimes 2.0.4", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4686,6 +4697,137 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "react_compiler" +version = "0.1.0" +dependencies = [ + "indexmap 2.12.0", + "react_compiler_ast", + "react_compiler_diagnostics", + "react_compiler_hir", + "react_compiler_inference", + "react_compiler_lowering", + "react_compiler_optimization", + "react_compiler_reactive_scopes", + "react_compiler_ssa", + "react_compiler_typeinference", + "react_compiler_validation", + "regex", + "serde", + "serde_json", +] + +[[package]] +name = "react_compiler_ast" +version = "0.1.0" +dependencies = [ + "indexmap 2.12.0", + "serde", + "serde_json", + "similar", + "walkdir", +] + +[[package]] +name = "react_compiler_diagnostics" +version = "0.1.0" +dependencies = [ + "serde", +] + +[[package]] +name = "react_compiler_hir" +version = "0.1.0" +dependencies = [ + "indexmap 2.12.0", + "react_compiler_diagnostics", + "serde", + "serde_json", +] + +[[package]] +name = "react_compiler_inference" +version = "0.1.0" +dependencies = [ + "indexmap 2.12.0", + "react_compiler_diagnostics", + "react_compiler_hir", + "react_compiler_lowering", + "react_compiler_optimization", + "react_compiler_ssa", + "react_compiler_utils", +] + +[[package]] +name = "react_compiler_lowering" +version = "0.1.0" +dependencies = [ + "indexmap 2.12.0", + "react_compiler_ast", + "react_compiler_diagnostics", + "react_compiler_hir", + "serde_json", +] + +[[package]] +name = "react_compiler_optimization" +version = "0.1.0" +dependencies = [ + "indexmap 2.12.0", + "react_compiler_diagnostics", + "react_compiler_hir", + "react_compiler_lowering", + "react_compiler_ssa", +] + +[[package]] +name = "react_compiler_reactive_scopes" +version = "0.1.0" +dependencies = [ + "hmac", + "indexmap 2.12.0", + "react_compiler_ast", + "react_compiler_diagnostics", + 
"react_compiler_hir", + "serde_json", + "sha2", +] + +[[package]] +name = "react_compiler_ssa" +version = "0.1.0" +dependencies = [ + "indexmap 2.12.0", + "react_compiler_diagnostics", + "react_compiler_hir", + "react_compiler_lowering", +] + +[[package]] +name = "react_compiler_typeinference" +version = "0.1.0" +dependencies = [ + "react_compiler_diagnostics", + "react_compiler_hir", + "react_compiler_ssa", +] + +[[package]] +name = "react_compiler_utils" +version = "0.1.0" +dependencies = [ + "indexmap 2.12.0", +] + +[[package]] +name = "react_compiler_validation" +version = "0.1.0" +dependencies = [ + "indexmap 2.12.0", + "react_compiler_diagnostics", + "react_compiler_hir", +] + [[package]] name = "redox_syscall" version = "0.5.3" @@ -4944,7 +5086,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.4.15", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4957,7 +5099,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.11.0", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -5486,7 +5628,7 @@ dependencies = [ "cfg-if", "libc", "psm", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -5552,6 +5694,12 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "swc" version = "60.0.0" @@ -6515,8 +6663,17 @@ dependencies = [ name = "swc_ecma_react_compiler" version = "17.0.0" dependencies = [ + "indexmap 2.12.0", + "react_compiler", + "react_compiler_ast", + "react_compiler_diagnostics", + "react_compiler_hir", + "serde", + "serde_json", + "swc_atoms", "swc_common", "swc_ecma_ast", + "swc_ecma_codegen", "swc_ecma_parser", "swc_ecma_visit", "testing", @@ -7600,7 +7757,7 @@ dependencies = [ 
"fd-lock", "io-lifetimes 2.0.4", "rustix 0.38.44", - "windows-sys 0.52.0", + "windows-sys 0.59.0", "winx", ] @@ -9287,7 +9444,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.48.0", + "windows-sys 0.61.2", ] [[package]] @@ -9647,7 +9804,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f3fd376f71958b862e7afb20cfe5a22830e1963462f3a17f49d82a6c1d1f42d" dependencies = [ "bitflags 2.10.0", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] diff --git a/bindings/binding_react_compiler_node/Cargo.toml b/bindings/binding_react_compiler_node/Cargo.toml index 64f37cf09845..c4f496735246 100644 --- a/bindings/binding_react_compiler_node/Cargo.toml +++ b/bindings/binding_react_compiler_node/Cargo.toml @@ -19,15 +19,10 @@ napi-build = { workspace = true } backtrace = { workspace = true } napi = { workspace = true, features = ["napi3", "serde-json"] } napi-derive = { workspace = true, features = ["type-def"] } +react_compiler = { path = "../../crates/react_compiler" } serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } tracing = { workspace = true, features = ["release_max_level_info"] } - -swc_core = { path = "../../crates/swc_core", features = [ - "allocator_node", - "common_sourcemap", - "ecma_ast", - "ecma_parser", -] } swc_ecma_react_compiler = { path = "../../crates/swc_ecma_react_compiler" } swc_malloc = { path = "../../crates/swc_malloc" } diff --git a/bindings/binding_react_compiler_node/src/lib.rs b/bindings/binding_react_compiler_node/src/lib.rs index 12ae4f151b13..06714b4e21a1 100644 --- a/bindings/binding_react_compiler_node/src/lib.rs +++ b/bindings/binding_react_compiler_node/src/lib.rs @@ -21,3 +21,11 @@ fn init() { })); } } + +/// Output returned by the native React Compiler binding. 
+#[napi(object)] +pub struct TransformOutput { + pub code: String, + pub map: Option, + pub diagnostics: Vec, +} diff --git a/bindings/binding_react_compiler_node/src/support.rs b/bindings/binding_react_compiler_node/src/support.rs index 9d3e2f2b336e..a6615f5952f8 100644 --- a/bindings/binding_react_compiler_node/src/support.rs +++ b/bindings/binding_react_compiler_node/src/support.rs @@ -1,8 +1,98 @@ use napi::bindgen_prelude::*; -use swc_core::{ - common::{sync::Lrc, FileName, SourceMap}, - ecma::{ast::EsVersion, parser::Syntax}, -}; +use react_compiler::entrypoint::plugin_options::{CompilerTarget, GatingConfig, PluginOptions}; +use serde::Deserialize; +use swc_ecma_react_compiler::{self, SourceParser, SourceSyntax}; + +#[derive(Clone, Copy, Debug, Deserialize)] +enum ParserSyntax { + #[serde(rename = "ecmascript")] + EcmaScript, + #[serde(rename = "typescript")] + TypeScript, +} + +#[derive(Clone, Debug, Default, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ParserOptions { + syntax: Option, + jsx: Option, + tsx: Option, + decorators: Option, +} + +impl ParserOptions { + fn into_source_parser(self) -> SourceParser { + let mut parser = SourceParser::default(); + + if let Some(syntax) = self.syntax { + parser.syntax = match syntax { + ParserSyntax::EcmaScript => SourceSyntax::EcmaScript, + ParserSyntax::TypeScript => SourceSyntax::TypeScript, + }; + } + if let Some(jsx) = self.jsx { + parser.jsx = jsx; + } + if let Some(tsx) = self.tsx { + parser.tsx = tsx; + } + if let Some(decorators) = self.decorators { + parser.decorators = decorators; + } + + parser + } +} + +#[derive(Clone, Debug, Default, Deserialize)] +#[serde(rename_all = "camelCase")] +struct TransformOptions { + parser: Option, + filename: Option, + is_dev: Option, + compilation_mode: Option, + panic_threshold: Option, + target: Option, + gating: Option, + enable_reanimated: Option, +} + +impl TransformOptions { + fn into_plugin_options(self, source_code: &str) -> PluginOptions { + 
PluginOptions { + should_compile: true, + enable_reanimated: self.enable_reanimated.unwrap_or(false), + is_dev: self.is_dev.unwrap_or(false), + filename: self.filename, + compilation_mode: self + .compilation_mode + .unwrap_or_else(|| String::from("infer")), + panic_threshold: self.panic_threshold.unwrap_or_else(|| String::from("none")), + target: self + .target + .unwrap_or_else(|| CompilerTarget::Version(String::from("19"))), + gating: self.gating, + dynamic_gating: None, + no_emit: false, + output_mode: None, + eslint_suppression_rules: None, + flow_suppressions: true, + ignore_use_no_forget: false, + custom_opt_out_directives: None, + environment: Default::default(), + source_code: Some(source_code.to_string()), + profiling: false, + debug: false, + } + } + + fn source_parser(&self) -> SourceParser { + self.parser + .clone() + .map(ParserOptions::into_source_parser) + .unwrap_or_default() + } +} struct IsReactCompilerRequiredTask { code: String, @@ -22,27 +112,81 @@ impl Task for IsReactCompilerRequiredTask { } } +struct TransformTask { + code: String, + options: Buffer, +} + +#[napi] +impl Task for TransformTask { + type JsValue = crate::TransformOutput; + type Output = crate::TransformOutput; + + fn compute(&mut self) -> napi::Result { + let code = std::mem::take(&mut self.code); + let options = deserialize_transform_options(self.options.as_ref())?; + transform_inner(&code, options) + } + + fn resolve(&mut self, _env: napi::Env, output: Self::Output) -> napi::Result { + Ok(output) + } +} + +fn decode_code(code: Buffer) -> String { + String::from_utf8_lossy(code.as_ref()).into_owned() +} + +fn deserialize_transform_options(input: &[u8]) -> napi::Result { + if input.is_empty() { + return Ok(TransformOptions::default()); + } + + serde_json::from_slice(input).map_err(|err| { + napi::Error::from_reason(format!("failed to parse transform options: {err}")) + }) +} + +fn transform_inner(code: &str, options: TransformOptions) -> napi::Result { + let parser = 
options.source_parser(); + let options = options.into_plugin_options(code); + + let result = + swc_ecma_react_compiler::try_transform_source_to_code_with_parser(code, options, parser) + .map_err(napi::Error::from_reason)?; + + Ok(crate::TransformOutput { + code: result.code, + map: result.map, + diagnostics: result.diagnostics, + }) +} + fn is_react_compiler_required_inner(code: &str) -> bool { - let cm = Lrc::new(SourceMap::default()); - let fm = cm.new_source_file(FileName::Anon.into(), code.to_string()); - - let program = swc_core::ecma::parser::parse_file_as_program( - &fm, - Syntax::Typescript(swc_core::ecma::parser::TsSyntax { - decorators: true, - tsx: true, - ..Default::default() - }), - EsVersion::latest(), - None, - &mut vec![], - ); - - let Ok(program) = program else { - return false; + swc_ecma_react_compiler::is_required_source_with_parser(code, SourceParser::default()) + .unwrap_or(false) +} + +#[napi] +fn transform( + code: Buffer, + options: Buffer, + signal: Option, +) -> AsyncTask { + let task = TransformTask { + code: decode_code(code), + options, }; - swc_ecma_react_compiler::fast_check::is_required(&program) + AsyncTask::with_optional_signal(task, signal) +} + +#[napi] +fn transform_sync(code: Buffer, options: Buffer) -> napi::Result { + transform_inner( + &decode_code(code), + deserialize_transform_options(options.as_ref())?, + ) } #[napi] @@ -50,16 +194,14 @@ fn is_react_compiler_required( code: Buffer, signal: Option, ) -> AsyncTask { - let code = String::from_utf8_lossy(code.as_ref()).into_owned(); - - let task = IsReactCompilerRequiredTask { code }; + let task = IsReactCompilerRequiredTask { + code: decode_code(code), + }; AsyncTask::with_optional_signal(task, signal) } #[napi] pub fn is_react_compiler_required_sync(code: Buffer) -> napi::Result { - let code = String::from_utf8_lossy(code.as_ref()).into_owned(); - - Ok(is_react_compiler_required_inner(&code)) + Ok(is_react_compiler_required_inner(&decode_code(code))) } diff --git 
a/crates/react_compiler/Cargo.toml b/crates/react_compiler/Cargo.toml new file mode 100644 index 000000000000..5403957c4c29 --- /dev/null +++ b/crates/react_compiler/Cargo.toml @@ -0,0 +1,23 @@ +[package] +description = "Vendored React Compiler core from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +react_compiler_ast = { path = "../react_compiler_ast" } +react_compiler_diagnostics = { path = "../react_compiler_diagnostics" } +react_compiler_hir = { path = "../react_compiler_hir" } +react_compiler_inference = { path = "../react_compiler_inference" } +react_compiler_lowering = { path = "../react_compiler_lowering" } +react_compiler_optimization = { path = "../react_compiler_optimization" } +react_compiler_reactive_scopes = { path = "../react_compiler_reactive_scopes" } +react_compiler_ssa = { path = "../react_compiler_ssa" } +react_compiler_typeinference = { path = "../react_compiler_typeinference" } +react_compiler_validation = { path = "../react_compiler_validation" } +indexmap = { workspace = true } +regex = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true, features = ["raw_value"] } diff --git a/crates/react_compiler/src/debug_print.rs b/crates/react_compiler/src/debug_print.rs new file mode 100644 index 000000000000..3d2314849c15 --- /dev/null +++ b/crates/react_compiler/src/debug_print.rs @@ -0,0 +1,736 @@ +use react_compiler_diagnostics::CompilerError; +use react_compiler_hir::{ + environment::Environment, + print::{self, PrintFormatter}, + BasicBlock, BlockId, HirFunction, Instruction, ParamPattern, Place, Terminal, +}; + +// ============================================================================= +// DebugPrinter struct — thin wrapper around PrintFormatter for HIR-specific +// logic +// 
============================================================================= + +struct DebugPrinter<'a> { + fmt: PrintFormatter<'a>, +} + +impl<'a> DebugPrinter<'a> { + fn new(env: &'a Environment) -> Self { + Self { + fmt: PrintFormatter::new(env), + } + } + + // ========================================================================= + // Function + // ========================================================================= + + fn format_function(&mut self, func: &HirFunction) { + self.fmt.indent(); + self.fmt.line(&format!( + "id: {}", + match &func.id { + Some(id) => format!("\"{}\"", id), + None => "null".to_string(), + } + )); + self.fmt.line(&format!( + "name_hint: {}", + match &func.name_hint { + Some(h) => format!("\"{}\"", h), + None => "null".to_string(), + } + )); + self.fmt.line(&format!("fn_type: {:?}", func.fn_type)); + self.fmt.line(&format!("generator: {}", func.generator)); + self.fmt.line(&format!("is_async: {}", func.is_async)); + self.fmt + .line(&format!("loc: {}", print::format_loc(&func.loc))); + + // params + self.fmt.line("params:"); + self.fmt.indent(); + for (i, param) in func.params.iter().enumerate() { + match param { + ParamPattern::Place(place) => { + self.fmt.format_place_field(&format!("[{}]", i), place); + } + ParamPattern::Spread(spread) => { + self.fmt.line(&format!("[{}] Spread:", i)); + self.fmt.indent(); + self.fmt.format_place_field("place", &spread.place); + self.fmt.dedent(); + } + } + } + self.fmt.dedent(); + + // returns + self.fmt.line("returns:"); + self.fmt.indent(); + self.fmt.format_place_field("value", &func.returns); + self.fmt.dedent(); + + // context + self.fmt.line("context:"); + self.fmt.indent(); + for (i, place) in func.context.iter().enumerate() { + self.fmt.format_place_field(&format!("[{}]", i), place); + } + self.fmt.dedent(); + + // aliasing_effects + match &func.aliasing_effects { + Some(effects) => { + self.fmt.line("aliasingEffects:"); + self.fmt.indent(); + for (i, eff) in 
effects.iter().enumerate() { + self.fmt + .line(&format!("[{}] {}", i, self.fmt.format_effect(eff))); + } + self.fmt.dedent(); + } + None => self.fmt.line("aliasingEffects: null"), + } + + // directives + self.fmt.line("directives:"); + self.fmt.indent(); + for (i, d) in func.directives.iter().enumerate() { + self.fmt.line(&format!("[{}] \"{}\"", i, d)); + } + self.fmt.dedent(); + + // return_type_annotation + self.fmt.line(&format!( + "returnTypeAnnotation: {}", + match &func.return_type_annotation { + Some(ann) => ann.clone(), + None => "null".to_string(), + } + )); + + self.fmt.line(""); + self.fmt.line("Blocks:"); + self.fmt.indent(); + for (block_id, block) in &func.body.blocks { + self.format_block(block_id, block, &func.instructions); + } + self.fmt.dedent(); + self.fmt.dedent(); + } + + // ========================================================================= + // Block + // ========================================================================= + + fn format_block( + &mut self, + block_id: &BlockId, + block: &BasicBlock, + instructions: &[Instruction], + ) { + self.fmt + .line(&format!("bb{} ({}):", block_id.0, block.kind)); + self.fmt.indent(); + + // preds + let preds: Vec = block.preds.iter().map(|p| format!("bb{}", p.0)).collect(); + self.fmt.line(&format!("preds: [{}]", preds.join(", "))); + + // phis + self.fmt.line("phis:"); + self.fmt.indent(); + for phi in &block.phis { + self.format_phi(phi); + } + self.fmt.dedent(); + + // instructions + self.fmt.line("instructions:"); + self.fmt.indent(); + for (index, instr_id) in block.instructions.iter().enumerate() { + let instr = &instructions[instr_id.0 as usize]; + self.format_instruction(instr, index); + } + self.fmt.dedent(); + + // terminal + self.fmt.line("terminal:"); + self.fmt.indent(); + self.format_terminal(&block.terminal); + self.fmt.dedent(); + + self.fmt.dedent(); + } + + // ========================================================================= + // Phi + // 
========================================================================= + + fn format_phi(&mut self, phi: &react_compiler_hir::Phi) { + self.fmt.line("Phi {"); + self.fmt.indent(); + self.fmt.format_place_field("place", &phi.place); + self.fmt.line("operands:"); + self.fmt.indent(); + for (block_id, place) in &phi.operands { + self.fmt.line(&format!("bb{}:", block_id.0)); + self.fmt.indent(); + self.fmt.format_place_field("value", place); + self.fmt.dedent(); + } + self.fmt.dedent(); + self.fmt.dedent(); + self.fmt.line("}"); + } + + // ========================================================================= + // Instruction + // ========================================================================= + + fn format_instruction(&mut self, instr: &Instruction, index: usize) { + self.fmt.line(&format!("[{}] Instruction {{", index)); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", instr.id.0)); + self.fmt.format_place_field("lvalue", &instr.lvalue); + self.fmt.line("value:"); + self.fmt.indent(); + // For the HIR printer, inner functions are formatted via format_function + self.fmt.format_instruction_value( + &instr.value, + Some(&|fmt: &mut PrintFormatter, func: &HirFunction| { + // We need to recursively format the inner function + // Use a temporary DebugPrinter that shares the formatter state + let mut inner = DebugPrinter { + fmt: PrintFormatter { + env: fmt.env, + seen_identifiers: std::mem::take(&mut fmt.seen_identifiers), + seen_scopes: std::mem::take(&mut fmt.seen_scopes), + output: Vec::new(), + indent_level: fmt.indent_level, + }, + }; + inner.format_function(func); + // Write the output lines into the parent formatter + for line in &inner.fmt.output { + fmt.line_raw(line); + } + // Copy back the seen state + fmt.seen_identifiers = inner.fmt.seen_identifiers; + fmt.seen_scopes = inner.fmt.seen_scopes; + }), + ); + self.fmt.dedent(); + match &instr.effects { + Some(effects) => { + self.fmt.line("effects:"); + self.fmt.indent(); + for (i, eff) in 
effects.iter().enumerate() { + self.fmt + .line(&format!("[{}] {}", i, self.fmt.format_effect(eff))); + } + self.fmt.dedent(); + } + None => self.fmt.line("effects: null"), + } + self.fmt + .line(&format!("loc: {}", print::format_loc(&instr.loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + + // ========================================================================= + // Terminal + // ========================================================================= + + fn format_terminal(&mut self, terminal: &Terminal) { + match terminal { + Terminal::If { + test, + consequent, + alternate, + fallthrough, + id, + loc, + } => { + self.fmt.line("If {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.format_place_field("test", test); + self.fmt.line(&format!("consequent: bb{}", consequent.0)); + self.fmt.line(&format!("alternate: bb{}", alternate.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Branch { + test, + consequent, + alternate, + fallthrough, + id, + loc, + } => { + self.fmt.line("Branch {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.format_place_field("test", test); + self.fmt.line(&format!("consequent: bb{}", consequent.0)); + self.fmt.line(&format!("alternate: bb{}", alternate.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Logical { + operator, + test, + fallthrough, + id, + loc, + } => { + self.fmt.line("Logical {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("operator: \"{}\"", operator)); + self.fmt.line(&format!("test: bb{}", test.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + 
self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Ternary { + test, + fallthrough, + id, + loc, + } => { + self.fmt.line("Ternary {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("test: bb{}", test.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Optional { + optional, + test, + fallthrough, + id, + loc, + } => { + self.fmt.line("Optional {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("optional: {}", optional)); + self.fmt.line(&format!("test: bb{}", test.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Throw { value, id, loc } => { + self.fmt.line("Throw {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.format_place_field("value", value); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Return { + value, + return_variant, + id, + loc, + effects, + } => { + self.fmt.line("Return {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt + .line(&format!("returnVariant: {:?}", return_variant)); + self.fmt.format_place_field("value", value); + match effects { + Some(e) => { + self.fmt.line("effects:"); + self.fmt.indent(); + for (i, eff) in e.iter().enumerate() { + self.fmt + .line(&format!("[{}] {}", i, self.fmt.format_effect(eff))); + } + self.fmt.dedent(); + } + None => self.fmt.line("effects: null"), + } + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Goto { + block, + variant, + id, + loc, + } => { + self.fmt.line("Goto {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + 
self.fmt.line(&format!("block: bb{}", block.0)); + self.fmt.line(&format!("variant: {:?}", variant)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Switch { + test, + cases, + fallthrough, + id, + loc, + } => { + self.fmt.line("Switch {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.format_place_field("test", test); + self.fmt.line("cases:"); + self.fmt.indent(); + for (i, case) in cases.iter().enumerate() { + match &case.test { + Some(p) => { + self.fmt.line(&format!("[{}] Case {{", i)); + self.fmt.indent(); + self.fmt.format_place_field("test", p); + self.fmt.line(&format!("block: bb{}", case.block.0)); + self.fmt.dedent(); + self.fmt.line("}"); + } + None => { + self.fmt + .line(&format!("[{}] Default {{ block: bb{} }}", i, case.block.0)); + } + } + } + self.fmt.dedent(); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::DoWhile { + loop_block, + test, + fallthrough, + id, + loc, + } => { + self.fmt.line("DoWhile {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loop: bb{}", loop_block.0)); + self.fmt.line(&format!("test: bb{}", test.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::While { + test, + loop_block, + fallthrough, + id, + loc, + } => { + self.fmt.line("While {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("test: bb{}", test.0)); + self.fmt.line(&format!("loop: bb{}", loop_block.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::For { + init, + test, + 
update, + loop_block, + fallthrough, + id, + loc, + } => { + self.fmt.line("For {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("init: bb{}", init.0)); + self.fmt.line(&format!("test: bb{}", test.0)); + self.fmt.line(&format!( + "update: {}", + match update { + Some(u) => format!("bb{}", u.0), + None => "null".to_string(), + } + )); + self.fmt.line(&format!("loop: bb{}", loop_block.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::ForOf { + init, + test, + loop_block, + fallthrough, + id, + loc, + } => { + self.fmt.line("ForOf {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("init: bb{}", init.0)); + self.fmt.line(&format!("test: bb{}", test.0)); + self.fmt.line(&format!("loop: bb{}", loop_block.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::ForIn { + init, + loop_block, + fallthrough, + id, + loc, + } => { + self.fmt.line("ForIn {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("init: bb{}", init.0)); + self.fmt.line(&format!("loop: bb{}", loop_block.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Label { + block, + fallthrough, + id, + loc, + } => { + self.fmt.line("Label {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("block: bb{}", block.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Sequence { + block, + fallthrough, + id, + loc, 
+ } => { + self.fmt.line("Sequence {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("block: bb{}", block.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Unreachable { id, loc } => { + self.fmt.line(&format!( + "Unreachable {{ id: {}, loc: {} }}", + id.0, + print::format_loc(loc) + )); + } + Terminal::Unsupported { id, loc } => { + self.fmt.line(&format!( + "Unsupported {{ id: {}, loc: {} }}", + id.0, + print::format_loc(loc) + )); + } + Terminal::MaybeThrow { + continuation, + handler, + id, + loc, + effects, + } => { + self.fmt.line("MaybeThrow {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt + .line(&format!("continuation: bb{}", continuation.0)); + self.fmt.line(&format!( + "handler: {}", + match handler { + Some(h) => format!("bb{}", h.0), + None => "null".to_string(), + } + )); + match effects { + Some(e) => { + self.fmt.line("effects:"); + self.fmt.indent(); + for (i, eff) in e.iter().enumerate() { + self.fmt + .line(&format!("[{}] {}", i, self.fmt.format_effect(eff))); + } + self.fmt.dedent(); + } + None => self.fmt.line("effects: null"), + } + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Scope { + fallthrough, + block, + scope, + id, + loc, + } => { + self.fmt.line("Scope {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.format_scope_field("scope", *scope); + self.fmt.line(&format!("block: bb{}", block.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::PrunedScope { + fallthrough, + block, + scope, + id, + loc, + } => { + self.fmt.line("PrunedScope {"); + self.fmt.indent(); + self.fmt.line(&format!("id: 
{}", id.0)); + self.fmt.format_scope_field("scope", *scope); + self.fmt.line(&format!("block: bb{}", block.0)); + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + Terminal::Try { + block, + handler_binding, + handler, + fallthrough, + id, + loc, + } => { + self.fmt.line("Try {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("block: bb{}", block.0)); + self.fmt.line(&format!("handler: bb{}", handler.0)); + match handler_binding { + Some(p) => self.fmt.format_place_field("handlerBinding", p), + None => self.fmt.line("handlerBinding: null"), + } + self.fmt.line(&format!("fallthrough: bb{}", fallthrough.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + } + } +} + +// ============================================================================= +// Entry point +// ============================================================================= + +pub fn debug_hir(hir: &HirFunction, env: &Environment) -> String { + let mut printer = DebugPrinter::new(env); + printer.format_function(hir); + + // Print outlined functions (matches TS DebugPrintHIR.ts: printDebugHIR) + for outlined in env.get_outlined_functions() { + printer.fmt.line(""); + printer.format_function(&outlined.func); + } + + printer.fmt.line(""); + printer.fmt.line("Environment:"); + printer.fmt.indent(); + printer.fmt.format_errors(&env.errors); + printer.fmt.dedent(); + + printer.fmt.to_string_output() +} + +// ============================================================================= +// Error formatting (kept for backward compatibility) +// ============================================================================= + +pub fn format_errors(error: &CompilerError) -> String { + let env = Environment::new(); + let mut fmt = PrintFormatter::new(&env); + fmt.format_errors(error); + 
fmt.to_string_output() +} + +/// Format an HIR function into a reactive PrintFormatter. +/// This bridges the two debug printers so inner functions in +/// FunctionExpression/ObjectMethod can be printed within the reactive function +/// output. +pub fn format_hir_function_into(reactive_fmt: &mut PrintFormatter, func: &HirFunction) { + // Create a temporary DebugPrinter that shares the same environment + let mut printer = DebugPrinter { + fmt: PrintFormatter { + env: reactive_fmt.env, + seen_identifiers: std::mem::take(&mut reactive_fmt.seen_identifiers), + seen_scopes: std::mem::take(&mut reactive_fmt.seen_scopes), + output: Vec::new(), + indent_level: reactive_fmt.indent_level, + }, + }; + printer.format_function(func); + + // Write the output lines into the reactive formatter + for line in &printer.fmt.output { + reactive_fmt.line_raw(line); + } + // Copy back the seen state + reactive_fmt.seen_identifiers = printer.fmt.seen_identifiers; + reactive_fmt.seen_scopes = printer.fmt.seen_scopes; +} + +// ============================================================================= +// Helpers for effect formatting (kept for backward compatibility) +// ============================================================================= + +#[allow(dead_code)] +fn format_place_short(place: &Place, env: &Environment) -> String { + let ident = &env.identifiers[place.identifier.0 as usize]; + let name = match &ident.name { + Some(name) => name.value().to_string(), + None => String::new(), + }; + let scope = match ident.scope { + Some(scope_id) => format!(":{}", scope_id.0), + None => String::new(), + }; + format!("{}${}{}", name, place.identifier.0, scope) +} diff --git a/crates/react_compiler/src/entrypoint/compile_result.rs b/crates/react_compiler/src/entrypoint/compile_result.rs new file mode 100644 index 000000000000..7d7f2e792306 --- /dev/null +++ b/crates/react_compiler/src/entrypoint/compile_result.rs @@ -0,0 +1,299 @@ +use react_compiler_ast::{ + expressions::Identifier as 
AstIdentifier, patterns::PatternLike, statements::BlockStatement, +}; +use react_compiler_diagnostics::SourceLocation; +use react_compiler_hir::ReactFunctionType; +use serde::Serialize; + +use crate::timing::TimingEntry; + +/// Source location with index and filename fields for logger event +/// serialization. Matches the Babel SourceLocation format that the TS compiler +/// emits in logger events. +#[derive(Debug, Clone, Serialize)] +pub struct LoggerSourceLocation { + pub start: LoggerPosition, + pub end: LoggerPosition, + #[serde(skip_serializing_if = "Option::is_none")] + pub filename: Option, + #[serde(rename = "identifierName", skip_serializing_if = "Option::is_none")] + pub identifier_name: Option, +} + +#[derive(Debug, Clone, Serialize)] +pub struct LoggerPosition { + pub line: u32, + pub column: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub index: Option, +} + +impl LoggerSourceLocation { + /// Create from a diagnostics SourceLocation, adding index and filename. + pub fn from_loc( + loc: &SourceLocation, + filename: Option<&str>, + start_index: Option, + end_index: Option, + ) -> Self { + Self { + start: LoggerPosition { + line: loc.start.line, + column: loc.start.column, + index: start_index, + }, + end: LoggerPosition { + line: loc.end.line, + column: loc.end.column, + index: end_index, + }, + filename: filename.map(|s| s.to_string()), + identifier_name: None, + } + } + + /// Create from a diagnostics SourceLocation without index or filename. + pub fn from_loc_simple(loc: &SourceLocation) -> Self { + Self { + start: LoggerPosition { + line: loc.start.line, + column: loc.start.column, + index: None, + }, + end: LoggerPosition { + line: loc.end.line, + column: loc.end.column, + index: None, + }, + filename: None, + identifier_name: None, + } + } +} + +/// A variable rename from lowering, serialized for the JS shim. 
+#[derive(Debug, Clone, Serialize)] +pub struct BindingRenameInfo { + pub original: String, + pub renamed: String, + #[serde(rename = "declarationStart")] + pub declaration_start: u32, +} + +/// Main result type returned by the compile function. +/// Serialized to JSON and returned to the JS shim. +#[derive(Debug, Serialize)] +#[serde(tag = "kind", rename_all = "lowercase")] +pub enum CompileResult { + /// Compilation succeeded (or no functions needed compilation). + /// `ast` is None if no changes were made to the program. + /// The AST is stored as a pre-serialized JSON string (RawValue) to avoid + /// double-serialization: File→Value→String becomes File→String directly. + Success { + ast: Option>, + events: Vec, + /// Unified ordered log interleaving events and debug entries. + /// Items appear in the order they were emitted during compilation. + /// The JS side uses this as the single source of truth (preferred over + /// separate events/debugLogs arrays). + #[serde(rename = "orderedLog", skip_serializing_if = "Vec::is_empty")] + ordered_log: Vec, + /// Variable renames from lowering, for applying back to the Babel AST. + /// Each entry maps an original binding name to its renamed version, + /// identified by the binding's declaration start position in the + /// source. + #[serde(skip_serializing_if = "Vec::is_empty")] + renames: Vec, + /// Timing data for profiling. Only populated when __profiling is + /// enabled. + #[serde(skip_serializing_if = "Vec::is_empty")] + timing: Vec, + }, + /// A fatal error occurred and panicThreshold dictates it should throw. + Error { + error: CompilerErrorInfo, + events: Vec, + #[serde(rename = "orderedLog", skip_serializing_if = "Vec::is_empty")] + ordered_log: Vec, + /// Timing data for profiling. Only populated when __profiling is + /// enabled. + #[serde(skip_serializing_if = "Vec::is_empty")] + timing: Vec, + }, +} + +/// An item in the ordered log, which can be either a logger event or a debug +/// entry. 
+#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum OrderedLogItem { + Event { event: LoggerEvent }, + Debug { entry: DebugLogEntry }, +} + +/// Structured error information for the JS shim. +#[derive(Debug, Clone, Serialize)] +pub struct CompilerErrorInfo { + pub reason: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub details: Vec, + /// When set, the JS shim should throw an Error with this exact message + /// instead of formatting through formatCompilerError(). This is used + /// for simulated unknown exceptions (throwUnknownException__testonly) + /// which in the TS compiler are plain Error objects, not CompilerErrors. + #[serde(rename = "rawMessage", skip_serializing_if = "Option::is_none")] + pub raw_message: Option, + /// Pre-formatted error message produced by Rust, matching the JS + /// formatCompilerError() output. When present, the JS shim uses this + /// directly instead of calling formatCompilerError() on the JS side. + #[serde(rename = "formattedMessage", skip_serializing_if = "Option::is_none")] + pub formatted_message: Option, +} + +/// Serializable error detail — flat plain object matching the TS +/// `formatDetailForLogging()` output. All fields are direct properties. +#[derive(Debug, Clone, Serialize)] +pub struct CompilerErrorDetailInfo { + pub category: String, + pub reason: String, + pub description: Option, + pub severity: String, + pub suggestions: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub loc: Option, +} + +/// Serializable suggestion info for logger events. +#[derive(Debug, Clone, Serialize)] +pub struct LoggerSuggestionInfo { + pub description: String, + pub op: LoggerSuggestionOp, + pub range: (usize, usize), + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, +} + +/// Numeric enum matching TS `CompilerSuggestionOperation`. 
+#[derive(Debug, Clone, Copy)] +pub enum LoggerSuggestionOp { + InsertBefore = 0, + InsertAfter = 1, + Remove = 2, + Replace = 3, +} + +impl serde::Serialize for LoggerSuggestionOp { + fn serialize(&self, serializer: S) -> Result { + serializer.serialize_u8(*self as u8) + } +} + +/// Individual error or hint item within a CompilerErrorDetailInfo. +#[derive(Debug, Clone, Serialize)] +pub struct CompilerErrorItemInfo { + pub kind: String, + pub loc: Option, + /// Serialized as `null` when None (not omitted), matching TS behavior. + pub message: Option, +} + +/// Debug log entry for debugLogIRs support. +/// Currently only supports the 'debug' variant (string values). +#[derive(Debug, Clone, Serialize)] +pub struct DebugLogEntry { + pub kind: &'static str, + pub name: String, + pub value: String, +} + +impl DebugLogEntry { + pub fn new(name: impl Into, value: impl Into) -> Self { + Self { + kind: "debug", + name: name.into(), + value: value.into(), + } + } +} + +/// Codegen output for a single compiled function. +/// Carries the generated AST fields needed to replace the original function. +#[derive(Debug, Clone)] +pub struct CodegenFunction { + pub loc: Option, + pub id: Option, + pub name_hint: Option, + pub params: Vec, + pub body: BlockStatement, + pub generator: bool, + pub is_async: bool, + pub memo_slots_used: u32, + pub memo_blocks: u32, + pub memo_values: u32, + pub pruned_memo_blocks: u32, + pub pruned_memo_values: u32, + pub outlined: Vec, +} + +/// An outlined function extracted during compilation. +#[derive(Debug, Clone)] +pub struct OutlinedFunction { + pub func: CodegenFunction, + pub fn_type: Option, +} + +/// Logger events emitted during compilation. +/// These are returned to JS for the logger callback. 
+#[derive(Debug, Clone, Serialize)] +#[serde(tag = "kind")] +pub enum LoggerEvent { + CompileSuccess { + #[serde(rename = "fnLoc")] + fn_loc: Option, + #[serde(rename = "fnName")] + fn_name: Option, + #[serde(rename = "memoSlots")] + memo_slots: u32, + #[serde(rename = "memoBlocks")] + memo_blocks: u32, + #[serde(rename = "memoValues")] + memo_values: u32, + #[serde(rename = "prunedMemoBlocks")] + pruned_memo_blocks: u32, + #[serde(rename = "prunedMemoValues")] + pruned_memo_values: u32, + }, + CompileError { + detail: CompilerErrorDetailInfo, + #[serde(rename = "fnLoc")] + fn_loc: Option, + }, + /// Same as CompileError but serializes fnLoc before detail (matching TS + /// program.ts output) + #[serde(rename = "CompileError")] + CompileErrorWithLoc { + #[serde(rename = "fnLoc")] + fn_loc: LoggerSourceLocation, + detail: CompilerErrorDetailInfo, + }, + CompileSkip { + #[serde(rename = "fnLoc")] + fn_loc: Option, + reason: String, + #[serde(skip_serializing_if = "Option::is_none")] + loc: Option, + }, + CompileUnexpectedThrow { + #[serde(rename = "fnLoc")] + fn_loc: Option, + data: String, + }, + PipelineError { + #[serde(rename = "fnLoc")] + fn_loc: Option, + data: String, + }, +} diff --git a/crates/react_compiler/src/entrypoint/gating.rs b/crates/react_compiler/src/entrypoint/gating.rs new file mode 100644 index 000000000000..cd7e64ba6acb --- /dev/null +++ b/crates/react_compiler/src/entrypoint/gating.rs @@ -0,0 +1,578 @@ +// Gating rewrite logic for compiled functions. +// +// When gating is enabled, the compiled function is wrapped in a conditional: +// `gating() ? optimized_fn : original_fn` +// +// For function declarations referenced before their declaration, a special +// hoisting pattern is used (see `insert_additional_function_declaration`). +// +// Ported from `Entrypoint/Gating.ts`. 
+ +use react_compiler_ast::{common::BaseNode, expressions::*, patterns::PatternLike, statements::*}; +use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory}; + +use super::{imports::ProgramContext, plugin_options::GatingConfig}; + +/// A compiled function node, can be any function type. +#[derive(Debug, Clone)] +pub enum CompiledFunctionNode { + FunctionDeclaration(FunctionDeclaration), + FunctionExpression(FunctionExpression), + ArrowFunctionExpression(ArrowFunctionExpression), +} + +/// Represents a compiled function that needs gating. +/// In the Rust version, we work with indices into the program body +/// rather than Babel paths. +pub struct GatingRewrite { + /// Index in program.body where the original function is + pub original_index: usize, + /// The compiled function AST node + pub compiled_fn: CompiledFunctionNode, + /// The gating config + pub gating: GatingConfig, + /// Whether the function is referenced before its declaration at top level + pub referenced_before_declared: bool, + /// Whether the parent statement is an ExportDefaultDeclaration + pub is_export_default: bool, +} + +/// Apply gating rewrites to the program. +/// This modifies program.body by replacing/inserting statements. +/// +/// Corresponds to `insertGatedFunctionDeclaration` in the TS version, +/// but batched: all rewrites are collected first, then applied in reverse +/// index order to maintain validity of earlier indices. +pub fn apply_gating_rewrites( + program: &mut react_compiler_ast::Program, + mut rewrites: Vec, + context: &mut ProgramContext, +) -> Result<(), CompilerDiagnostic> { + // Sort rewrites in reverse order by original_index so that insertions + // at higher indices don't invalidate lower indices. 
+ rewrites.sort_by(|a, b| b.original_index.cmp(&a.original_index)); + + for rewrite in rewrites { + let gating_imported_name = context + .add_import_specifier( + &rewrite.gating.source, + &rewrite.gating.import_specifier_name, + None, + ) + .name + .clone(); + + if rewrite.referenced_before_declared { + // The referenced-before-declared case only applies to FunctionDeclarations + if let CompiledFunctionNode::FunctionDeclaration(compiled) = rewrite.compiled_fn { + insert_additional_function_declaration( + &mut program.body, + rewrite.original_index, + compiled, + context, + &gating_imported_name, + )?; + } else { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected compiled node type to match input type: got non-FunctionDeclaration \ + but expected FunctionDeclaration", + None, + )); + } + } else { + let original_stmt = program.body[rewrite.original_index].clone(); + let original_fn = extract_function_node_from_stmt(&original_stmt)?; + + let gating_expression = + build_gating_expression(rewrite.compiled_fn, original_fn, &gating_imported_name); + + // Determine how to rewrite based on context + if !rewrite.is_export_default { + if let Some(fn_name) = get_fn_decl_name(&original_stmt) { + // Convert function declaration to: const fnName = gating() ? compiled : + // original + let var_decl = Statement::VariableDeclaration(VariableDeclaration { + base: BaseNode::default(), + declarations: vec![VariableDeclarator { + base: BaseNode::default(), + id: PatternLike::Identifier(make_identifier(&fn_name)), + init: Some(Box::new(gating_expression)), + definite: None, + }], + kind: VariableDeclarationKind::Const, + declare: None, + }); + program.body[rewrite.original_index] = var_decl; + } else { + // Replace with the conditional expression directly (e.g. 
arrow/expression) + let expr_stmt = Statement::ExpressionStatement(ExpressionStatement { + base: BaseNode::default(), + expression: Box::new(gating_expression), + }); + program.body[rewrite.original_index] = expr_stmt; + } + } else { + // ExportDefaultDeclaration case + if let Some(fn_name) = get_fn_decl_name_from_export_default(&original_stmt) { + // Named export default function: replace with const + re-export + // const fnName = gating() ? compiled : original; + // export default fnName; + let var_decl = Statement::VariableDeclaration(VariableDeclaration { + base: BaseNode::default(), + declarations: vec![VariableDeclarator { + base: BaseNode::default(), + id: PatternLike::Identifier(make_identifier(&fn_name)), + init: Some(Box::new(gating_expression)), + definite: None, + }], + kind: VariableDeclarationKind::Const, + declare: None, + }); + let re_export = Statement::ExportDefaultDeclaration( + react_compiler_ast::declarations::ExportDefaultDeclaration { + base: BaseNode::default(), + declaration: Box::new( + react_compiler_ast::declarations::ExportDefaultDecl::Expression( + Box::new(Expression::Identifier(make_identifier(&fn_name))), + ), + ), + export_kind: None, + }, + ); + // Replace the original statement with the var decl, then insert re-export after + program.body[rewrite.original_index] = var_decl; + program.body.insert(rewrite.original_index + 1, re_export); + } else { + // Anonymous export default or arrow: replace the declaration content + // with the conditional expression + let export_default = Statement::ExportDefaultDeclaration( + react_compiler_ast::declarations::ExportDefaultDeclaration { + base: BaseNode::default(), + declaration: Box::new( + react_compiler_ast::declarations::ExportDefaultDecl::Expression( + Box::new(gating_expression), + ), + ), + export_kind: None, + }, + ); + program.body[rewrite.original_index] = export_default; + } + } + } + } + Ok(()) +} + +/// Gating rewrite for function declarations which are referenced before their 
+/// declaration site. +/// +/// ```js +/// // original +/// export default React.memo(Foo); +/// function Foo() { ... } +/// +/// // React compiler optimized + gated +/// import {gating} from 'myGating'; +/// export default React.memo(Foo); +/// const gating_result = gating(); // <- inserted +/// function Foo_optimized() {} // <- inserted +/// function Foo_unoptimized() {} // <- renamed from Foo +/// function Foo() { // <- inserted, hoistable by JS engines +/// if (gating_result) return Foo_optimized(); +/// else return Foo_unoptimized(); +/// } +/// ``` +fn insert_additional_function_declaration( + body: &mut Vec, + original_index: usize, + mut compiled: FunctionDeclaration, + context: &mut ProgramContext, + gating_function_identifier_name: &str, +) -> Result<(), CompilerDiagnostic> { + // Extract the original function declaration from body + let original_fn = match &body[original_index] { + Statement::FunctionDeclaration(fd) => fd.clone(), + Statement::ExportNamedDeclaration(end) => { + if let Some(decl) = &end.declaration { + if let react_compiler_ast::declarations::Declaration::FunctionDeclaration(fd) = + decl.as_ref() + { + fd.clone() + } else { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected function declaration in export", + None, + )); + } + } else { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected declaration in export", + None, + )); + } + } + _ => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected function declaration at original_index", + None, + )); + } + }; + + let original_fn_name = original_fn + .id + .as_ref() + .expect("Expected function declaration referenced elsewhere to have a named identifier"); + let compiled_id = compiled + .id + .as_ref() + .expect("Expected compiled function declaration to have a named identifier"); + assert_eq!( + original_fn.params.len(), + compiled.params.len(), + "Expected compiled function to have the same number of parameters as 
source" + ); + + let _ = compiled_id; // used above for the assert + + // Generate unique names + let gating_condition_name = + context.new_uid(&format!("{}_result", gating_function_identifier_name)); + let unoptimized_fn_name = context.new_uid(&format!("{}_unoptimized", original_fn_name.name)); + let optimized_fn_name = context.new_uid(&format!("{}_optimized", original_fn_name.name)); + + // Step 1: rename existing functions + compiled.id = Some(make_identifier(&optimized_fn_name)); + + // Rename the original function in-place to *_unoptimized + rename_fn_decl_at(body, original_index, &unoptimized_fn_name)?; + + // Step 2: build new params and args for the dispatcher function + let mut new_params: Vec = Vec::new(); + let mut new_args_optimized: Vec = Vec::new(); + let mut new_args_unoptimized: Vec = Vec::new(); + + for (i, param) in original_fn.params.iter().enumerate() { + let arg_name = format!("arg{}", i); + match param { + PatternLike::RestElement(_) => { + new_params.push(PatternLike::RestElement( + react_compiler_ast::patterns::RestElement { + base: BaseNode::default(), + argument: Box::new(PatternLike::Identifier(make_identifier(&arg_name))), + type_annotation: None, + decorators: None, + }, + )); + new_args_optimized.push(Expression::SpreadElement(SpreadElement { + base: BaseNode::default(), + argument: Box::new(Expression::Identifier(make_identifier(&arg_name))), + })); + new_args_unoptimized.push(Expression::SpreadElement(SpreadElement { + base: BaseNode::default(), + argument: Box::new(Expression::Identifier(make_identifier(&arg_name))), + })); + } + _ => { + new_params.push(PatternLike::Identifier(make_identifier(&arg_name))); + new_args_optimized.push(Expression::Identifier(make_identifier(&arg_name))); + new_args_unoptimized.push(Expression::Identifier(make_identifier(&arg_name))); + } + } + } + + // Build the dispatcher function: + // function Foo(...args) { + // if (gating_result) return Foo_optimized(...args); + // else return 
Foo_unoptimized(...args); + // } + let dispatcher_fn = Statement::FunctionDeclaration(FunctionDeclaration { + base: BaseNode::default(), + id: Some(make_identifier(&original_fn_name.name)), + params: new_params, + body: BlockStatement { + base: BaseNode::default(), + body: vec![Statement::IfStatement(IfStatement { + base: BaseNode::default(), + test: Box::new(Expression::Identifier(make_identifier( + &gating_condition_name, + ))), + consequent: Box::new(Statement::ReturnStatement(ReturnStatement { + base: BaseNode::default(), + argument: Some(Box::new(Expression::CallExpression(CallExpression { + base: BaseNode::default(), + callee: Box::new(Expression::Identifier(make_identifier( + &optimized_fn_name, + ))), + arguments: new_args_optimized, + type_parameters: None, + type_arguments: None, + optional: None, + }))), + })), + alternate: Some(Box::new(Statement::ReturnStatement(ReturnStatement { + base: BaseNode::default(), + argument: Some(Box::new(Expression::CallExpression(CallExpression { + base: BaseNode::default(), + callee: Box::new(Expression::Identifier(make_identifier( + &unoptimized_fn_name, + ))), + arguments: new_args_unoptimized, + type_parameters: None, + type_arguments: None, + optional: None, + }))), + }))), + })], + directives: vec![], + }, + generator: false, + is_async: false, + declare: None, + return_type: None, + type_parameters: None, + predicate: None, + component_declaration: false, + hook_declaration: false, + }); + + // Build: const gating_result = gating(); + let gating_const = Statement::VariableDeclaration(VariableDeclaration { + base: BaseNode::default(), + declarations: vec![VariableDeclarator { + base: BaseNode::default(), + id: PatternLike::Identifier(make_identifier(&gating_condition_name)), + init: Some(Box::new(Expression::CallExpression(CallExpression { + base: BaseNode::default(), + callee: Box::new(Expression::Identifier(make_identifier( + gating_function_identifier_name, + ))), + arguments: vec![], + type_parameters: None, + 
type_arguments: None, + optional: None, + }))), + definite: None, + }], + kind: VariableDeclarationKind::Const, + declare: None, + }); + + // Build: the compiled (optimized) function declaration + let compiled_stmt = Statement::FunctionDeclaration(compiled); + + // Insert statements. In the TS version: + // fnPath.insertBefore(gating_const) + // fnPath.insertBefore(compiled) + // fnPath.insertAfter(dispatcher_fn) + // + // This means the final order is: + // [before original_index]: gating_const + // [before original_index]: compiled (optimized fn) + // [at original_index]: original fn (renamed to *_unoptimized) + // [after original_index]: dispatcher fn + // + // We insert in order: first the ones before, then the one after. + // Insert before original_index: gating_const, compiled + body.insert(original_index, compiled_stmt); + body.insert(original_index, gating_const); + // The original (now renamed) fn is now at original_index + 2 + // Insert dispatcher after it + body.insert(original_index + 3, dispatcher_fn); + Ok(()) +} + +/// Build a gating conditional expression: +/// `gating_fn() ? build_fn_expr(compiled) : build_fn_expr(original)` +fn build_gating_expression( + compiled: CompiledFunctionNode, + original: CompiledFunctionNode, + gating_name: &str, +) -> Expression { + Expression::ConditionalExpression(ConditionalExpression { + base: BaseNode::default(), + test: Box::new(Expression::CallExpression(CallExpression { + base: BaseNode::default(), + callee: Box::new(Expression::Identifier(make_identifier(gating_name))), + arguments: vec![], + type_parameters: None, + type_arguments: None, + optional: None, + })), + consequent: Box::new(build_function_expression(compiled)), + alternate: Box::new(build_function_expression(original)), + }) +} + +/// Convert a compiled function node to an expression. +/// Function declarations are converted to function expressions; +/// arrow functions and function expressions are returned as-is. 
+fn build_function_expression(node: CompiledFunctionNode) -> Expression { + match node { + CompiledFunctionNode::ArrowFunctionExpression(arrow) => { + Expression::ArrowFunctionExpression(arrow) + } + CompiledFunctionNode::FunctionExpression(func_expr) => { + Expression::FunctionExpression(func_expr) + } + CompiledFunctionNode::FunctionDeclaration(func_decl) => { + // Convert FunctionDeclaration to FunctionExpression + Expression::FunctionExpression(FunctionExpression { + base: func_decl.base, + params: func_decl.params, + body: func_decl.body, + id: func_decl.id, + generator: func_decl.generator, + is_async: func_decl.is_async, + return_type: func_decl.return_type, + type_parameters: func_decl.type_parameters, + }) + } + } +} + +/// Helper to create a simple Identifier with the given name and default +/// BaseNode. +fn make_identifier(name: &str) -> Identifier { + Identifier { + base: BaseNode::default(), + name: name.to_string(), + type_annotation: None, + optional: None, + decorators: None, + } +} + +/// Extract the function name from a top-level Statement if it is a +/// FunctionDeclaration with an id. +fn get_fn_decl_name(stmt: &Statement) -> Option { + match stmt { + Statement::FunctionDeclaration(fd) => fd.id.as_ref().map(|id| id.name.clone()), + _ => None, + } +} + +/// Extract the function name from an ExportDefaultDeclaration's declaration, +/// if it is a named FunctionDeclaration. +fn get_fn_decl_name_from_export_default(stmt: &Statement) -> Option { + match stmt { + Statement::ExportDefaultDeclaration(ed) => match ed.declaration.as_ref() { + react_compiler_ast::declarations::ExportDefaultDecl::FunctionDeclaration(fd) => { + fd.id.as_ref().map(|id| id.name.clone()) + } + _ => None, + }, + _ => None, + } +} + +/// Extract a CompiledFunctionNode from a statement (for building the +/// "original" side of the gating expression). 
+fn extract_function_node_from_stmt( + stmt: &Statement, +) -> Result { + match stmt { + Statement::FunctionDeclaration(fd) => { + Ok(CompiledFunctionNode::FunctionDeclaration(fd.clone())) + } + Statement::ExpressionStatement(es) => match es.expression.as_ref() { + Expression::ArrowFunctionExpression(arrow) => { + Ok(CompiledFunctionNode::ArrowFunctionExpression(arrow.clone())) + } + Expression::FunctionExpression(fe) => { + Ok(CompiledFunctionNode::FunctionExpression(fe.clone())) + } + _ => Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected function expression in expression statement for gating", + None, + )), + }, + Statement::ExportDefaultDeclaration(ed) => match ed.declaration.as_ref() { + react_compiler_ast::declarations::ExportDefaultDecl::FunctionDeclaration(fd) => { + Ok(CompiledFunctionNode::FunctionDeclaration(fd.clone())) + } + react_compiler_ast::declarations::ExportDefaultDecl::Expression(expr) => { + match expr.as_ref() { + Expression::ArrowFunctionExpression(arrow) => { + Ok(CompiledFunctionNode::ArrowFunctionExpression(arrow.clone())) + } + Expression::FunctionExpression(fe) => { + Ok(CompiledFunctionNode::FunctionExpression(fe.clone())) + } + _ => Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected function expression in export default for gating", + None, + )), + } + } + _ => Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected function in export default declaration for gating", + None, + )), + }, + Statement::VariableDeclaration(vd) => { + let init = vd.declarations[0] + .init + .as_ref() + .expect("Expected variable declarator to have an init for gating"); + match init.as_ref() { + Expression::ArrowFunctionExpression(arrow) => { + Ok(CompiledFunctionNode::ArrowFunctionExpression(arrow.clone())) + } + Expression::FunctionExpression(fe) => { + Ok(CompiledFunctionNode::FunctionExpression(fe.clone())) + } + _ => Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected function expression 
in variable declaration for gating", + None, + )), + } + } + _ => Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected statement type for gating rewrite", + None, + )), + } +} + +/// Rename the function declaration at `body[index]` in place. +/// Handles both bare FunctionDeclaration and ExportNamedDeclaration wrapping +/// one. +fn rename_fn_decl_at( + body: &mut [Statement], + index: usize, + new_name: &str, +) -> Result<(), CompilerDiagnostic> { + match &mut body[index] { + Statement::FunctionDeclaration(fd) => { + fd.id = Some(make_identifier(new_name)); + } + Statement::ExportNamedDeclaration(end) => { + if let Some(decl) = &mut end.declaration { + if let react_compiler_ast::declarations::Declaration::FunctionDeclaration(fd) = + decl.as_mut() + { + fd.id = Some(make_identifier(new_name)); + } + } + } + _ => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected function declaration to rename", + None, + )); + } + } + Ok(()) +} diff --git a/crates/react_compiler/src/entrypoint/imports.rs b/crates/react_compiler/src/entrypoint/imports.rs new file mode 100644 index 000000000000..ac0b8614191c --- /dev/null +++ b/crates/react_compiler/src/entrypoint/imports.rs @@ -0,0 +1,498 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+use std::collections::{HashMap, HashSet};
+
+use react_compiler_ast::{
+    common::BaseNode,
+    declarations::{
+        ImportDeclaration, ImportKind, ImportSpecifier, ImportSpecifierData, ModuleExportName,
+    },
+    expressions::{CallExpression, Expression, Identifier},
+    literals::StringLiteral,
+    patterns::{ObjectPattern, ObjectPatternProp, ObjectPatternProperty, PatternLike},
+    scope::ScopeInfo,
+    statements::{Statement, VariableDeclaration, VariableDeclarationKind, VariableDeclarator},
+    Program, SourceType,
+};
+use react_compiler_diagnostics::{
+    CompilerError, CompilerErrorDetail, ErrorCategory, Position, SourceLocation,
+};
+
+use super::{
+    compile_result::{DebugLogEntry, LoggerEvent, OrderedLogItem},
+    plugin_options::{CompilerTarget, PluginOptions},
+    suppression::SuppressionRange,
+};
+use crate::timing::TimingData;
+
+/// An import specifier tracked by ProgramContext.
+/// Corresponds to NonLocalImportSpecifier in the TS compiler.
+#[derive(Debug, Clone)]
+pub struct NonLocalImportSpecifier {
+    /// Local binding name used for the import (generated via
+    /// `ProgramContext::new_uid`).
+    pub name: String,
+    /// Module specifier the binding is imported from.
+    pub module: String,
+    /// Name of the export that is imported from `module`.
+    pub imported: String,
+}
+
+// NOTE(review): several type annotations below appear to have lost their
+// generic arguments in this listing (`Option`, `Vec`, `HashSet`, `HashMap>`).
+// Presumably an extraction artifact; confirm against the real file.
+/// Context for the program being compiled.
+/// Tracks compiled functions, generated names, and import requirements.
+/// Equivalent to ProgramContext class in Imports.ts.
+pub struct ProgramContext {
+    /// Plugin options this context was created with.
+    pub opts: PluginOptions,
+    /// File name supplied at construction.
+    pub filename: Option,
+    /// The source filename from the parser's sourceFilename option.
+    /// This is the filename stored on AST node `loc.filename` fields,
+    /// which may differ from `filename` (e.g., no path prefix).
+    source_filename: Option,
+    /// Source text supplied at construction.
+    pub code: Option,
+    /// Runtime module name, derived from `opts.target` in `new`.
+    pub react_runtime_module: String,
+    /// Suppression ranges supplied at construction.
+    pub suppressions: Vec,
+    /// Module-scope opt-out flag supplied at construction.
+    pub has_module_scope_opt_out: bool,
+    /// Events recorded through `log_event`, in emission order.
+    pub events: Vec,
+    /// Unified ordered log that interleaves events and debug entries
+    /// in the order they were emitted during compilation.
+    pub ordered_log: Vec,
+
+    // Pre-resolved import local names for codegen. All three start as `None`.
+    pub instrument_fn_name: Option,
+    pub instrument_gating_name: Option,
+    pub hook_guard_name: Option,
+
+    // Variable renames from lowering, to be applied back to the Babel AST
+    pub renames: Vec,
+
+    /// Timing data for profiling. Accumulates across all function compilations.
+    pub timing: TimingData,
+
+    /// Whether debug logging is enabled (HIR formatting after each pass).
+    pub debug_enabled: bool,
+
+    // Internal state
+    // Start positions of functions already compiled (see `mark_compiled`).
+    already_compiled: HashSet,
+    // Names that generated identifiers must not collide with.
+    known_referenced_names: HashSet,
+    // Generated imports, keyed by module and then by imported specifier.
+    imports: HashMap>,
+}
+
+impl ProgramContext {
+    /// Create a context for one program.
+    ///
+    /// `react_runtime_module` is derived from `opts.target`, the timing data
+    /// from `opts.profiling`, and `debug_enabled` from `opts.debug`. All logs,
+    /// renames, name sets and imports start empty, and `source_filename`
+    /// starts as `None`.
+    pub fn new(
+        opts: PluginOptions,
+        filename: Option,
+        code: Option,
+        suppressions: Vec,
+        has_module_scope_opt_out: bool,
+    ) -> Self {
+        // Read these before `opts` is moved into the struct below.
+        let react_runtime_module = get_react_compiler_runtime_module(&opts.target);
+        let profiling = opts.profiling;
+        let debug_enabled = opts.debug;
+        Self {
+            opts,
+            filename,
+            source_filename: None,
+            code,
+            react_runtime_module,
+            suppressions,
+            has_module_scope_opt_out,
+            events: Vec::new(),
+            ordered_log: Vec::new(),
+            instrument_fn_name: None,
+            instrument_gating_name: None,
+            hook_guard_name: None,
+            renames: Vec::new(),
+            timing: TimingData::new(profiling),
+            debug_enabled,
+            already_compiled: HashSet::new(),
+            known_referenced_names: HashSet::new(),
+            imports: HashMap::new(),
+        }
+    }
+
+    /// Set the source filename (from AST node loc.filename).
+    ///
+    /// The value is only stored while no source filename is set; once a
+    /// `Some` value has been stored, later calls leave it unchanged.
+    pub fn set_source_filename(&mut self, filename: Option) {
+        if self.source_filename.is_none() {
+            self.source_filename = filename;
+        }
+    }
+
+    /// Get the source filename for logger events.
+    /// Returns a clone of the stored value.
+    pub fn source_filename(&self) -> Option {
+        self.source_filename.clone()
+    }
+
+    /// Check if a function at the given start position has already been
+    /// compiled. This is a workaround for Babel not consistently respecting
+    /// skip().
+    pub fn is_already_compiled(&self, start: u32) -> bool {
+        self.already_compiled.contains(&start)
+    }
+
+    /// Record that the function starting at `start` has been compiled.
+    pub fn mark_compiled(&mut self, start: u32) {
+        self.already_compiled.insert(start);
+    }
+
+    /// Seed the set of known names from the bindings recorded in `scope`.
+    ///
+    /// Every binding in `scope.bindings` is registered, not only those of the
+    /// program scope, so generated names cannot clash with any binding in the
+    /// file. This mirrors Babel's generateUid(), which checks all scopes.
+    /// Call this once after construction.
+    pub fn init_from_scope(&mut self, scope: &ScopeInfo) {
+        let binding_names = scope.bindings.iter().map(|binding| binding.name.clone());
+        self.known_referenced_names.extend(binding_names);
+    }
+
+    /// Report whether `name` is already among the known referenced names.
+    pub fn has_reference(&self, name: &str) -> bool {
+        self.known_referenced_names.contains(name)
+    }
+
+    /// Produce a name that does not clash with any known reference and
+    /// register it so later calls avoid it too.
+    ///
+    /// * Hook names (`use*`) are never prefixed, because hook-name-based type
+    ///   inference depends on the callee's spelling. On a clash they get a
+    ///   numeric suffix instead: `name_0`, `name_1`, ...
+    /// * Any other name is returned unchanged when it is free.
+    /// * Otherwise leading underscores are stripped and the result follows
+    ///   the Babel generateUid scheme: `_base`, `_base2`, `_base3`, ...
+    pub fn new_uid(&mut self, name: &str) -> String {
+        let uid = if is_hook_name(name) {
+            let mut candidate = name.to_string();
+            let mut suffix = 0;
+            while self.known_referenced_names.contains(&candidate) {
+                candidate = format!("{}_{}", name, suffix);
+                suffix += 1;
+            }
+            candidate
+        } else if !self.known_referenced_names.contains(name) {
+            name.to_string()
+        } else {
+            // e.g. "_c" -> base "c" -> "_c", "_c2", "_c3", ...
+            //      "foo" -> "_foo", "_foo2", "_foo3", ...
+            let base = name.trim_start_matches('_');
+            let mut candidate = format!("_{}", base);
+            let mut counter = 2;
+            while self.known_referenced_names.contains(&candidate) {
+                candidate = format!("_{}{}", base, counter);
+                counter += 1;
+            }
+            candidate
+        };
+
+        // All three paths end by reserving the chosen name.
+        self.known_referenced_names.insert(uid.clone());
+        uid
+    }
+
+    /// Import the memo cache function (`c`) from the compiler runtime module,
+    /// using `_c` as the preferred local name.
+    pub fn add_memo_cache_import(&mut self) -> NonLocalImportSpecifier {
+        let runtime_module = self.react_runtime_module.clone();
+        self.add_import_specifier(&runtime_module, "c", Some("_c"))
+    }
+
+    /// Request an import of `specifier` from `module`.
+    ///
+    /// If the same module/specifier pair was requested before, the earlier
+    /// entry is returned unchanged. Otherwise a fresh local name is generated
+    /// from `name_hint` (or from `specifier` when no hint is given) and the
+    /// new entry is recorded.
+    pub fn add_import_specifier(
+        &mut self,
+        module: &str,
+        specifier: &str,
+        name_hint: Option<&str>,
+    ) -> NonLocalImportSpecifier {
+        let already_imported = self
+            .imports
+            .get(module)
+            .and_then(|by_specifier| by_specifier.get(specifier));
+        if let Some(found) = already_imported {
+            return found.clone();
+        }
+
+        let local_name = self.new_uid(name_hint.unwrap_or(specifier));
+        let created = NonLocalImportSpecifier {
+            name: local_name,
+            module: module.to_string(),
+            imported: specifier.to_string(),
+        };
+
+        let by_specifier = self.imports.entry(module.to_string()).or_default();
+        by_specifier.insert(specifier.to_string(), created.clone());
+
+        created
+    }
+
+    /// Reserve `name` so that future uid generation will not produce it.
+    pub fn add_new_reference(&mut self, name: String) {
+        self.known_referenced_names.insert(name);
+    }
+
+    /// Record a compilation event in both the event list and the ordered log.
+    pub fn log_event(&mut self, event: LoggerEvent) {
+        let ordered_entry = OrderedLogItem::Event {
+            event: event.clone(),
+        };
+        self.ordered_log.push(ordered_entry);
+        self.events.push(event);
+    }
+
+    /// Log a debug entry (for debugLogIRs support).
+    pub fn log_debug(&mut self, entry: DebugLogEntry) {
+        // Debug entries go only to the ordered log, not to `events`.
+        self.ordered_log.push(OrderedLogItem::Debug { entry });
+    }
+
+    /// Check if there are any pending imports to add to the program.
+    pub fn has_pending_imports(&self) -> bool {
+        !self.imports.is_empty()
+    }
+
+    /// Get an immutable view of the generated imports.
+    // NOTE(review): the return type reads `&HashMap>` in this listing; its
+    // generic arguments were presumably lost in extraction. Confirm against
+    // the real file.
+    pub fn imports(&self) -> &HashMap> {
+        &self.imports
+    }
+}
+
+/// Check for blocklisted import modules.
+/// Returns a CompilerError if any blocklisted imports are found.
+///
+/// Only `import` declarations that appear directly in `program.body` are
+/// inspected. Each match adds one `ErrorCategory::Todo` detail carrying the
+/// declaration's source location. Returns `None` when `blocklisted` is `None`
+/// or empty, or when nothing matches.
+// NOTE(review): `&Option>` and `-> Option` below look like they lost their
+// generic arguments in extraction; confirm against the real file.
+pub fn validate_restricted_imports(
+    program: &Program,
+    blocklisted: &Option>,
+) -> Option {
+    // No blocklist, or an empty one: nothing to check.
+    let blocklisted = match blocklisted {
+        Some(b) if !b.is_empty() => b,
+        _ => return None,
+    };
+    // Borrowed set for constant-time membership tests in the loop below.
+    let restricted: HashSet<&str> = blocklisted.iter().map(|s| s.as_str()).collect();
+    let mut error = CompilerError::new();
+
+    for stmt in &program.body {
+        if let Statement::ImportDeclaration(import) = stmt {
+            if restricted.contains(import.source.value.as_str()) {
+                let mut detail = CompilerErrorDetail::new(
+                    ErrorCategory::Todo,
+                    "Bailing out due to blocklisted import",
+                )
+                .with_description(format!("Import from module {}", import.source.value));
+                // Copy the AST location, when present, into the diagnostics
+                // crate's location type field by field.
+                detail.loc = import.base.loc.as_ref().map(|loc| SourceLocation {
+                    start: Position {
+                        line: loc.start.line,
+                        column: loc.start.column,
+                        index: loc.start.index,
+                    },
+                    end: Position {
+                        line: loc.end.line,
+                        column: loc.end.column,
+                        index: loc.end.index,
+                    },
+                });
+                error.push_error_detail(detail);
+            }
+        }
+    }
+
+    if error.has_any_errors() {
+        Some(error)
+    } else {
+        None
+    }
+}
+
+/// Insert import declarations into the program body.
+/// Handles both ESM imports and CommonJS require.
+///
+/// For existing imports of the same module (non-namespaced, value imports),
+/// new specifiers are merged into the existing declaration. Otherwise,
+/// new import/require statements are prepended to the program body.
+pub fn add_imports_to_program(program: &mut Program, context: &ProgramContext) { + if context.imports.is_empty() { + return; + } + + // Collect existing non-namespaced imports by module name + let existing_import_indices: HashMap = program + .body + .iter() + .enumerate() + .filter_map(|(idx, stmt)| { + if let Statement::ImportDeclaration(import) = stmt { + if is_non_namespaced_import(import) { + return Some((import.source.value.clone(), idx)); + } + } + None + }) + .collect(); + + let mut stmts: Vec = Vec::new(); + let mut sorted_modules: Vec<_> = context.imports.iter().collect(); + sorted_modules.sort_by(|(a, _), (b, _)| a.to_lowercase().cmp(&b.to_lowercase())); + + for (module_name, imports_map) in sorted_modules { + let sorted_imports = { + let mut sorted: Vec<_> = imports_map.values().collect(); + sorted.sort_by_key(|s| &s.imported); + sorted + }; + + let import_specifiers: Vec = sorted_imports + .iter() + .map(|spec| make_import_specifier(spec)) + .collect(); + + // If an existing import of this module exists, merge into it + if let Some(&idx) = existing_import_indices.get(module_name.as_str()) { + if let Statement::ImportDeclaration(ref mut import) = program.body[idx] { + import.specifiers.extend(import_specifiers); + } + } else if matches!(program.source_type, SourceType::Module) { + // ESM: import { ... } from 'module' + stmts.push(Statement::ImportDeclaration(ImportDeclaration { + base: BaseNode::typed("ImportDeclaration"), + specifiers: import_specifiers, + source: StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: module_name.clone(), + }, + import_kind: None, + assertions: None, + attributes: None, + })); + } else { + // CommonJS: const { imported: local, ... 
} = require('module') + let properties: Vec = sorted_imports + .iter() + .map(|spec| { + ObjectPatternProperty::ObjectProperty(ObjectPatternProp { + base: BaseNode::typed("ObjectProperty"), + key: Box::new(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: spec.imported.clone(), + type_annotation: None, + optional: None, + decorators: None, + })), + value: Box::new(PatternLike::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: spec.name.clone(), + type_annotation: None, + optional: None, + decorators: None, + })), + computed: false, + shorthand: false, + decorators: None, + method: None, + }) + }) + .collect(); + + stmts.push(Statement::VariableDeclaration(VariableDeclaration { + base: BaseNode::typed("VariableDeclaration"), + kind: VariableDeclarationKind::Const, + declarations: vec![VariableDeclarator { + base: BaseNode::typed("VariableDeclarator"), + id: PatternLike::ObjectPattern(ObjectPattern { + base: BaseNode::typed("ObjectPattern"), + properties, + type_annotation: None, + decorators: None, + }), + init: Some(Box::new(Expression::CallExpression(CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: "require".to_string(), + type_annotation: None, + optional: None, + decorators: None, + })), + arguments: vec![Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: module_name.clone(), + })], + type_parameters: None, + type_arguments: None, + optional: None, + }))), + definite: None, + }], + declare: None, + })); + } + } + + // Prepend new import statements to the program body + if !stmts.is_empty() { + let mut new_body = stmts; + new_body.append(&mut program.body); + program.body = new_body; + } +} + +/// Create an ImportSpecifier AST node from a NonLocalImportSpecifier. 
+fn make_import_specifier(spec: &NonLocalImportSpecifier) -> ImportSpecifier { + ImportSpecifier::ImportSpecifier(ImportSpecifierData { + base: BaseNode::typed("ImportSpecifier"), + local: Identifier { + base: BaseNode::typed("Identifier"), + name: spec.name.clone(), + type_annotation: None, + optional: None, + decorators: None, + }, + imported: ModuleExportName::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: spec.imported.clone(), + type_annotation: None, + optional: None, + decorators: None, + }), + import_kind: None, + }) +} + +/// Check if an import declaration is a non-namespaced value import. +/// Matches `import { ... } from 'module'` but NOT: +/// - `import * as Foo from 'module'` (namespace) +/// - `import type { Foo } from 'module'` (type import) +/// - `import typeof { Foo } from 'module'` (typeof import) +fn is_non_namespaced_import(import: &ImportDeclaration) -> bool { + import + .specifiers + .iter() + .all(|s| matches!(s, ImportSpecifier::ImportSpecifier(_))) + && import + .import_kind + .as_ref() + .map_or(true, |k| matches!(k, ImportKind::Value)) +} + +/// Check if a name follows the React hook naming convention (use[A-Z0-9]...). +fn is_hook_name(name: &str) -> bool { + let bytes = name.as_bytes(); + bytes.len() >= 4 + && bytes[0] == b'u' + && bytes[1] == b's' + && bytes[2] == b'e' + && bytes + .get(3) + .map_or(false, |c| c.is_ascii_uppercase() || c.is_ascii_digit()) +} + +/// Get the runtime module name based on the compiler target. +pub fn get_react_compiler_runtime_module(target: &CompilerTarget) -> String { + match target { + CompilerTarget::Version(v) if v == "19" => "react/compiler-runtime".to_string(), + CompilerTarget::Version(v) if v == "17" || v == "18" => { + "react-compiler-runtime".to_string() + } + CompilerTarget::MetaInternal { runtime_module, .. 
} => runtime_module.clone(), + // Default to React 19 runtime for unrecognized versions + CompilerTarget::Version(_) => "react/compiler-runtime".to_string(), + } +} diff --git a/crates/react_compiler/src/entrypoint/mod.rs b/crates/react_compiler/src/entrypoint/mod.rs new file mode 100644 index 000000000000..4cb694371248 --- /dev/null +++ b/crates/react_compiler/src/entrypoint/mod.rs @@ -0,0 +1,12 @@ +pub mod compile_result; +pub mod gating; +pub mod imports; +pub mod pipeline; +pub mod plugin_options; +pub mod program; +pub mod suppression; +pub mod validate_source_locations; + +pub use compile_result::*; +pub use plugin_options::*; +pub use program::*; diff --git a/crates/react_compiler/src/entrypoint/pipeline.rs b/crates/react_compiler/src/entrypoint/pipeline.rs new file mode 100644 index 000000000000..86fdf5110851 --- /dev/null +++ b/crates/react_compiler/src/entrypoint/pipeline.rs @@ -0,0 +1,1893 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Compilation pipeline for a single function. +//! +//! Analogous to TS `Pipeline.ts` (`compileFn` → `run` → `runWithEnvironment`). +//! Currently runs BuildHIR (lowering) and PruneMaybeThrows. + +use react_compiler_ast::scope::ScopeInfo; +use react_compiler_diagnostics::CompilerError; +use react_compiler_hir::{ + environment::{Environment, OutputMode}, + environment_config::EnvironmentConfig, + ReactFunctionType, +}; +use react_compiler_lowering::FunctionNode; + +use super::{ + compile_result::{ + CodegenFunction, CompilerErrorDetailInfo, CompilerErrorItemInfo, DebugLogEntry, + LoggerPosition, LoggerSourceLocation, OutlinedFunction, + }, + imports::ProgramContext, + plugin_options::CompilerOutputMode, +}; +use crate::debug_print; + +/// Run the compilation pipeline on a single function. 
+/// +/// Currently: creates an Environment, runs BuildHIR (lowering), and produces +/// debug output via the context. Returns a CodegenFunction with zeroed memo +/// stats on success (codegen is not yet implemented). +pub fn compile_fn( + func: &FunctionNode<'_>, + fn_name: Option<&str>, + scope_info: &ScopeInfo, + fn_type: ReactFunctionType, + mode: CompilerOutputMode, + env_config: &EnvironmentConfig, + context: &mut ProgramContext, +) -> Result { + let mut env = Environment::with_config(env_config.clone()); + env.fn_type = fn_type; + env.output_mode = match mode { + CompilerOutputMode::Ssr => OutputMode::Ssr, + CompilerOutputMode::Client => OutputMode::Client, + CompilerOutputMode::Lint => OutputMode::Lint, + }; + env.code = context.code.clone(); + env.filename = context.filename.clone(); + env.instrument_fn_name = context.instrument_fn_name.clone(); + env.instrument_gating_name = context.instrument_gating_name.clone(); + env.hook_guard_name = context.hook_guard_name.clone(); + + context.timing.start("lower"); + let mut hir = react_compiler_lowering::lower(func, fn_name, scope_info, &mut env)?; + context.timing.stop(); + + // Collect any renames from lowering and pass to context + if !env.renames.is_empty() { + context.renames.extend(env.renames.drain(..)); + } + + // Check for Invariant errors after lowering, before logging HIR. + // In TS, Invariant errors throw from recordError(), aborting lower() before + // the HIR entry is logged. The thrown error contains ONLY the Invariant error, + // not other recorded (non-Invariant) errors. 
+ if env.has_invariant_errors() { + return Err(env.take_invariant_errors()); + } + + if context.debug_enabled { + context.timing.start("debug_print:HIR"); + let debug_hir = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("HIR", debug_hir)); + context.timing.stop(); + } + + context.timing.start("PruneMaybeThrows"); + react_compiler_optimization::prune_maybe_throws(&mut hir, &mut env.functions)?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PruneMaybeThrows"); + let debug_prune = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("PruneMaybeThrows", debug_prune)); + context.timing.stop(); + } + + context.timing.start("ValidateContextVariableLValues"); + react_compiler_validation::validate_context_variable_lvalues(&hir, &mut env)?; + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateContextVariableLValues", + "ok".to_string(), + )); + } + context.timing.stop(); + + context.timing.start("ValidateUseMemo"); + let void_memo_errors = react_compiler_validation::validate_use_memo(&hir, &mut env); + log_errors_as_events(&void_memo_errors, context); + if context.debug_enabled { + context.log_debug(DebugLogEntry::new("ValidateUseMemo", "ok".to_string())); + } + context.timing.stop(); + + context.timing.start("DropManualMemoization"); + react_compiler_optimization::drop_manual_memoization(&mut hir, &mut env)?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:DropManualMemoization"); + let debug_drop_memo = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("DropManualMemoization", debug_drop_memo)); + context.timing.stop(); + } + + context + .timing + .start("InlineImmediatelyInvokedFunctionExpressions"); + react_compiler_optimization::inline_immediately_invoked_function_expressions( + &mut hir, &mut env, + ); + context.timing.stop(); + + if context.debug_enabled { + context + .timing + 
.start("debug_print:InlineImmediatelyInvokedFunctionExpressions"); + let debug_inline_iifes = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "InlineImmediatelyInvokedFunctionExpressions", + debug_inline_iifes, + )); + context.timing.stop(); + } + + context.timing.start("MergeConsecutiveBlocks"); + react_compiler_optimization::merge_consecutive_blocks::merge_consecutive_blocks( + &mut hir, + &mut env.functions, + ); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:MergeConsecutiveBlocks"); + let debug_merge = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("MergeConsecutiveBlocks", debug_merge)); + context.timing.stop(); + } + + // TODO: port assertConsistentIdentifiers + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "AssertConsistentIdentifiers", + "ok".to_string(), + )); + } + // TODO: port assertTerminalSuccessorsExist + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "AssertTerminalSuccessorsExist", + "ok".to_string(), + )); + } + + context.timing.start("EnterSSA"); + react_compiler_ssa::enter_ssa(&mut hir, &mut env).map_err(|diag| { + let loc = diag.primary_location().cloned(); + let mut err = CompilerError::new(); + err.push_error_detail(react_compiler_diagnostics::CompilerErrorDetail { + category: diag.category, + reason: diag.reason, + description: diag.description, + loc, + suggestions: diag.suggestions, + }); + err + })?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:SSA"); + let debug_ssa = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("SSA", debug_ssa)); + context.timing.stop(); + } + + context.timing.start("EliminateRedundantPhi"); + react_compiler_ssa::eliminate_redundant_phi(&mut hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:EliminateRedundantPhi"); + let debug_eliminate_phi 
= debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "EliminateRedundantPhi", + debug_eliminate_phi, + )); + context.timing.stop(); + } + + // TODO: port assertConsistentIdentifiers + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "AssertConsistentIdentifiers", + "ok".to_string(), + )); + } + + context.timing.start("ConstantPropagation"); + react_compiler_optimization::constant_propagation(&mut hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:ConstantPropagation"); + let debug_const_prop = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("ConstantPropagation", debug_const_prop)); + context.timing.stop(); + } + + context.timing.start("InferTypes"); + react_compiler_typeinference::infer_types(&mut hir, &mut env)?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:InferTypes"); + let debug_infer_types = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("InferTypes", debug_infer_types)); + context.timing.stop(); + } + + if env.enable_validations() { + if env.config.validate_hooks_usage { + context.timing.start("ValidateHooksUsage"); + react_compiler_validation::validate_hooks_usage(&hir, &mut env)?; + if context.debug_enabled { + context.log_debug(DebugLogEntry::new("ValidateHooksUsage", "ok".to_string())); + } + context.timing.stop(); + } + + if env.config.validate_no_capitalized_calls.is_some() { + context.timing.start("ValidateNoCapitalizedCalls"); + react_compiler_validation::validate_no_capitalized_calls(&hir, &mut env)?; + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateNoCapitalizedCalls", + "ok".to_string(), + )); + } + context.timing.stop(); + } + } + + context.timing.start("OptimizePropsMethodCalls"); + react_compiler_optimization::optimize_props_method_calls(&mut hir, &env); + context.timing.stop(); + + if context.debug_enabled { + 
context.timing.start("debug_print:OptimizePropsMethodCalls"); + let debug_optimize_props = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "OptimizePropsMethodCalls", + debug_optimize_props, + )); + context.timing.stop(); + } + + context.timing.start("AnalyseFunctions"); + let mut inner_logs: Vec = Vec::new(); + let debug_inner = context.debug_enabled; + let analyse_result = react_compiler_inference::analyse_functions( + &mut hir, + &mut env, + &mut |inner_func, inner_env| { + if debug_inner { + inner_logs.push(debug_print::debug_hir(inner_func, inner_env)); + } + }, + ); + context.timing.stop(); + + // Always flush inner logs before propagating errors + if context.debug_enabled { + for inner_log in inner_logs { + context.log_debug(DebugLogEntry::new("AnalyseFunction (inner)", inner_log)); + } + } + + analyse_result?; + + if env.has_invariant_errors() { + return Err(env.take_invariant_errors()); + } + + if context.debug_enabled { + context.timing.start("debug_print:AnalyseFunctions"); + let debug_analyse_functions = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "AnalyseFunctions", + debug_analyse_functions, + )); + context.timing.stop(); + } + + context.timing.start("InferMutationAliasingEffects"); + react_compiler_inference::infer_mutation_aliasing_effects(&mut hir, &mut env, false)?; + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:InferMutationAliasingEffects"); + let debug_infer_effects = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "InferMutationAliasingEffects", + debug_infer_effects, + )); + context.timing.stop(); + } + + if env.output_mode == OutputMode::Ssr { + context.timing.start("OptimizeForSSR"); + react_compiler_optimization::optimize_for_ssr(&mut hir, &env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:OptimizeForSSR"); + let debug_ssr = 
debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("OptimizeForSSR", debug_ssr)); + context.timing.stop(); + } + } + + context.timing.start("DeadCodeElimination"); + react_compiler_optimization::dead_code_elimination(&mut hir, &env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:DeadCodeElimination"); + let debug_dce = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("DeadCodeElimination", debug_dce)); + context.timing.stop(); + } + + context.timing.start("PruneMaybeThrows2"); + react_compiler_optimization::prune_maybe_throws(&mut hir, &mut env.functions)?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PruneMaybeThrows2"); + let debug_prune2 = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("PruneMaybeThrows", debug_prune2)); + context.timing.stop(); + } + + context.timing.start("InferMutationAliasingRanges"); + react_compiler_inference::infer_mutation_aliasing_ranges(&mut hir, &mut env, false)?; + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:InferMutationAliasingRanges"); + let debug_infer_ranges = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "InferMutationAliasingRanges", + debug_infer_ranges, + )); + context.timing.stop(); + } + + if env.enable_validations() { + context + .timing + .start("ValidateLocalsNotReassignedAfterRender"); + react_compiler_validation::validate_locals_not_reassigned_after_render(&hir, &mut env); + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateLocalsNotReassignedAfterRender", + "ok".to_string(), + )); + } + context.timing.stop(); + + if env.config.validate_ref_access_during_render { + context.timing.start("ValidateNoRefAccessInRender"); + react_compiler_validation::validate_no_ref_access_in_render(&hir, &mut env); + if context.debug_enabled { + 
context.log_debug(DebugLogEntry::new( + "ValidateNoRefAccessInRender", + "ok".to_string(), + )); + } + context.timing.stop(); + } + + if env.config.validate_no_set_state_in_render { + context.timing.start("ValidateNoSetStateInRender"); + react_compiler_validation::validate_no_set_state_in_render(&hir, &mut env)?; + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateNoSetStateInRender", + "ok".to_string(), + )); + } + context.timing.stop(); + } + + if env.config.validate_no_derived_computations_in_effects_exp + && env.output_mode == OutputMode::Lint + { + context + .timing + .start("ValidateNoDerivedComputationsInEffects"); + let errors = + react_compiler_validation::validate_no_derived_computations_in_effects_exp( + &hir, &env, + )?; + log_errors_as_events(&errors, context); + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateNoDerivedComputationsInEffects", + "ok".to_string(), + )); + } + context.timing.stop(); + } else if env.config.validate_no_derived_computations_in_effects { + context + .timing + .start("ValidateNoDerivedComputationsInEffects"); + react_compiler_validation::validate_no_derived_computations_in_effects(&hir, &mut env)?; + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateNoDerivedComputationsInEffects", + "ok".to_string(), + )); + } + context.timing.stop(); + } + + if env.config.validate_no_set_state_in_effects && env.output_mode == OutputMode::Lint { + context.timing.start("ValidateNoSetStateInEffects"); + let errors = react_compiler_validation::validate_no_set_state_in_effects(&hir, &env)?; + log_errors_as_events(&errors, context); + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateNoSetStateInEffects", + "ok".to_string(), + )); + } + context.timing.stop(); + } + + if env.config.validate_no_jsx_in_try_statements && env.output_mode == OutputMode::Lint { + context.timing.start("ValidateNoJSXInTryStatement"); + let errors = 
react_compiler_validation::validate_no_jsx_in_try_statement(&hir); + log_errors_as_events(&errors, context); + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateNoJSXInTryStatement", + "ok".to_string(), + )); + } + context.timing.stop(); + } + + context + .timing + .start("ValidateNoFreezingKnownMutableFunctions"); + react_compiler_validation::validate_no_freezing_known_mutable_functions(&hir, &mut env); + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateNoFreezingKnownMutableFunctions", + "ok".to_string(), + )); + } + context.timing.stop(); + } + + context.timing.start("InferReactivePlaces"); + react_compiler_inference::infer_reactive_places(&mut hir, &mut env)?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:InferReactivePlaces"); + let debug_reactive_places = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "InferReactivePlaces", + debug_reactive_places, + )); + context.timing.stop(); + } + + if env.enable_validations() { + context.timing.start("ValidateExhaustiveDependencies"); + react_compiler_validation::validate_exhaustive_dependencies(&mut hir, &mut env)?; + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateExhaustiveDependencies", + "ok".to_string(), + )); + } + context.timing.stop(); + } + + context + .timing + .start("RewriteInstructionKindsBasedOnReassignment"); + react_compiler_ssa::rewrite_instruction_kinds_based_on_reassignment(&mut hir, &env)?; + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:RewriteInstructionKindsBasedOnReassignment"); + let debug_rewrite = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "RewriteInstructionKindsBasedOnReassignment", + debug_rewrite, + )); + context.timing.stop(); + } + + if env.enable_validations() + && env.config.validate_static_components + && env.output_mode == OutputMode::Lint 
+ { + context.timing.start("ValidateStaticComponents"); + let errors = react_compiler_validation::validate_static_components(&hir); + log_errors_as_events(&errors, context); + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidateStaticComponents", + "ok".to_string(), + )); + } + context.timing.stop(); + } + + if env.enable_memoization() { + context.timing.start("InferReactiveScopeVariables"); + react_compiler_inference::infer_reactive_scope_variables(&mut hir, &mut env)?; + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:InferReactiveScopeVariables"); + let debug_infer_scopes = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "InferReactiveScopeVariables", + debug_infer_scopes, + )); + context.timing.stop(); + } + } + + context + .timing + .start("MemoizeFbtAndMacroOperandsInSameScope"); + let fbt_operands = + react_compiler_inference::memoize_fbt_and_macro_operands_in_same_scope(&hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:MemoizeFbtAndMacroOperandsInSameScope"); + let debug_fbt = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "MemoizeFbtAndMacroOperandsInSameScope", + debug_fbt, + )); + context.timing.stop(); + } + + if env.config.enable_jsx_outlining { + context.timing.start("OutlineJsx"); + react_compiler_optimization::outline_jsx(&mut hir, &mut env); + context.timing.stop(); + } + + if env.config.enable_name_anonymous_functions { + context.timing.start("NameAnonymousFunctions"); + react_compiler_optimization::name_anonymous_functions(&mut hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:NameAnonymousFunctions"); + let debug_name_anon = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "NameAnonymousFunctions", + debug_name_anon, + )); + context.timing.stop(); + } + } + + if 
env.config.enable_function_outlining { + context.timing.start("OutlineFunctions"); + react_compiler_optimization::outline_functions(&mut hir, &mut env, &fbt_operands); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:OutlineFunctions"); + let debug_outline = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("OutlineFunctions", debug_outline)); + context.timing.stop(); + } + } + + context.timing.start("AlignMethodCallScopes"); + react_compiler_inference::align_method_call_scopes(&mut hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:AlignMethodCallScopes"); + let debug_align = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new("AlignMethodCallScopes", debug_align)); + context.timing.stop(); + } + + context.timing.start("AlignObjectMethodScopes"); + react_compiler_inference::align_object_method_scopes(&mut hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:AlignObjectMethodScopes"); + let debug_align_obj = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "AlignObjectMethodScopes", + debug_align_obj, + )); + context.timing.stop(); + } + + context.timing.start("PruneUnusedLabelsHIR"); + react_compiler_optimization::prune_unused_labels_hir(&mut hir); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PruneUnusedLabelsHIR"); + let debug_prune_labels = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "PruneUnusedLabelsHIR", + debug_prune_labels, + )); + context.timing.stop(); + } + + context.timing.start("AlignReactiveScopesToBlockScopesHIR"); + react_compiler_inference::align_reactive_scopes_to_block_scopes_hir(&mut hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:AlignReactiveScopesToBlockScopesHIR"); + 
let debug_align_block_scopes = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "AlignReactiveScopesToBlockScopesHIR", + debug_align_block_scopes, + )); + context.timing.stop(); + } + + context.timing.start("MergeOverlappingReactiveScopesHIR"); + react_compiler_inference::merge_overlapping_reactive_scopes_hir(&mut hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:MergeOverlappingReactiveScopesHIR"); + let debug_merge_overlapping = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "MergeOverlappingReactiveScopesHIR", + debug_merge_overlapping, + )); + context.timing.stop(); + } + + // TODO: port assertValidBlockNesting + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "AssertValidBlockNesting", + "ok".to_string(), + )); + } + + context.timing.start("BuildReactiveScopeTerminalsHIR"); + react_compiler_inference::build_reactive_scope_terminals_hir(&mut hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:BuildReactiveScopeTerminalsHIR"); + let debug_build_scope_terminals = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "BuildReactiveScopeTerminalsHIR", + debug_build_scope_terminals, + )); + context.timing.stop(); + } + + // TODO: port assertValidBlockNesting + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "AssertValidBlockNesting", + "ok".to_string(), + )); + } + + context.timing.start("FlattenReactiveLoopsHIR"); + react_compiler_inference::flatten_reactive_loops_hir(&mut hir); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:FlattenReactiveLoopsHIR"); + let debug_flatten_loops = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "FlattenReactiveLoopsHIR", + debug_flatten_loops, + )); + context.timing.stop(); + } + + 
context.timing.start("FlattenScopesWithHooksOrUseHIR"); + react_compiler_inference::flatten_scopes_with_hooks_or_use_hir(&mut hir, &env)?; + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:FlattenScopesWithHooksOrUseHIR"); + let debug_flatten_hooks = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "FlattenScopesWithHooksOrUseHIR", + debug_flatten_hooks, + )); + context.timing.stop(); + } + + // TODO: port assertTerminalSuccessorsExist + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "AssertTerminalSuccessorsExist", + "ok".to_string(), + )); + } + // TODO: port assertTerminalPredsExist + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "AssertTerminalPredsExist", + "ok".to_string(), + )); + } + + context.timing.start("PropagateScopeDependenciesHIR"); + react_compiler_inference::propagate_scope_dependencies_hir(&mut hir, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:PropagateScopeDependenciesHIR"); + let debug_propagate_deps = debug_print::debug_hir(&hir, &env); + context.log_debug(DebugLogEntry::new( + "PropagateScopeDependenciesHIR", + debug_propagate_deps, + )); + context.timing.stop(); + } + + context.timing.start("BuildReactiveFunction"); + let mut reactive_fn = react_compiler_reactive_scopes::build_reactive_function(&hir, &env)?; + context.timing.stop(); + + let hir_formatter = |fmt: &mut react_compiler_hir::print::PrintFormatter, + func: &react_compiler_hir::HirFunction| { + debug_print::format_hir_function_into(fmt, func); + }; + + if context.debug_enabled { + context.timing.start("debug_print:BuildReactiveFunction"); + let debug_reactive = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new("BuildReactiveFunction", debug_reactive)); + 
context.timing.stop(); + } + + context.timing.start("AssertWellFormedBreakTargets"); + react_compiler_reactive_scopes::assert_well_formed_break_targets(&reactive_fn, &env); + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "AssertWellFormedBreakTargets", + "ok".to_string(), + )); + } + context.timing.stop(); + + context.timing.start("PruneUnusedLabels"); + react_compiler_reactive_scopes::prune_unused_labels(&mut reactive_fn, &env)?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PruneUnusedLabels"); + let debug_prune_labels_reactive = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new( + "PruneUnusedLabels", + debug_prune_labels_reactive, + )); + context.timing.stop(); + } + + context.timing.start("AssertScopeInstructionsWithinScopes"); + react_compiler_reactive_scopes::assert_scope_instructions_within_scopes(&reactive_fn, &env)?; + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "AssertScopeInstructionsWithinScopes", + "ok".to_string(), + )); + } + context.timing.stop(); + + context.timing.start("PruneNonEscapingScopes"); + react_compiler_reactive_scopes::prune_non_escaping_scopes(&mut reactive_fn, &mut env)?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PruneNonEscapingScopes"); + let debug = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new("PruneNonEscapingScopes", debug)); + context.timing.stop(); + } + + context.timing.start("PruneNonReactiveDependencies"); + react_compiler_reactive_scopes::prune_non_reactive_dependencies(&mut reactive_fn, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context + .timing + 
.start("debug_print:PruneNonReactiveDependencies"); + let debug_prune_non_reactive = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new( + "PruneNonReactiveDependencies", + debug_prune_non_reactive, + )); + context.timing.stop(); + } + + context.timing.start("PruneUnusedScopes"); + react_compiler_reactive_scopes::prune_unused_scopes(&mut reactive_fn, &env)?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PruneUnusedScopes"); + let debug_prune_unused_scopes = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new( + "PruneUnusedScopes", + debug_prune_unused_scopes, + )); + context.timing.stop(); + } + + context + .timing + .start("MergeReactiveScopesThatInvalidateTogether"); + react_compiler_reactive_scopes::merge_reactive_scopes_that_invalidate_together( + &mut reactive_fn, + &mut env, + )?; + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:MergeReactiveScopesThatInvalidateTogether"); + let debug = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new( + "MergeReactiveScopesThatInvalidateTogether", + debug, + )); + context.timing.stop(); + } + + context.timing.start("PruneAlwaysInvalidatingScopes"); + react_compiler_reactive_scopes::prune_always_invalidating_scopes(&mut reactive_fn, &env)?; + context.timing.stop(); + + if context.debug_enabled { + context + .timing + .start("debug_print:PruneAlwaysInvalidatingScopes"); + let debug_prune_always_inv = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, 
Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new( + "PruneAlwaysInvalidatingScopes", + debug_prune_always_inv, + )); + context.timing.stop(); + } + + context.timing.start("PropagateEarlyReturns"); + react_compiler_reactive_scopes::propagate_early_returns(&mut reactive_fn, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PropagateEarlyReturns"); + let debug = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new("PropagateEarlyReturns", debug)); + context.timing.stop(); + } + + context.timing.start("PruneUnusedLValues"); + react_compiler_reactive_scopes::prune_unused_lvalues(&mut reactive_fn, &env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PruneUnusedLValues"); + let debug_prune_lvalues = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new( + "PruneUnusedLValues", + debug_prune_lvalues, + )); + context.timing.stop(); + } + + context.timing.start("PromoteUsedTemporaries"); + react_compiler_reactive_scopes::promote_used_temporaries(&mut reactive_fn, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PromoteUsedTemporaries"); + let debug = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new("PromoteUsedTemporaries", debug)); + context.timing.stop(); + } + + context + .timing + .start("ExtractScopeDeclarationsFromDestructuring"); + react_compiler_reactive_scopes::extract_scope_declarations_from_destructuring( + &mut reactive_fn, + &mut env, + )?; + context.timing.stop(); + + if context.debug_enabled { + 
context + .timing + .start("debug_print:ExtractScopeDeclarationsFromDestructuring"); + let debug = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new( + "ExtractScopeDeclarationsFromDestructuring", + debug, + )); + context.timing.stop(); + } + + context.timing.start("StabilizeBlockIds"); + react_compiler_reactive_scopes::stabilize_block_ids(&mut reactive_fn, &mut env); + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:StabilizeBlockIds"); + let debug_stabilize = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new("StabilizeBlockIds", debug_stabilize)); + context.timing.stop(); + } + + context.timing.start("RenameVariables"); + let unique_identifiers = + react_compiler_reactive_scopes::rename_variables(&mut reactive_fn, &mut env); + context.timing.stop(); + + for name in &unique_identifiers { + context.add_new_reference(name.clone()); + } + + if context.debug_enabled { + context.timing.start("debug_print:RenameVariables"); + let debug = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new("RenameVariables", debug)); + context.timing.stop(); + } + + context.timing.start("PruneHoistedContexts"); + react_compiler_reactive_scopes::prune_hoisted_contexts(&mut reactive_fn, &mut env)?; + context.timing.stop(); + + if context.debug_enabled { + context.timing.start("debug_print:PruneHoistedContexts"); + let debug = react_compiler_reactive_scopes::print_reactive_function::debug_reactive_function_with_formatter( + &reactive_fn, &env, Some(&hir_formatter), + ); + context.log_debug(DebugLogEntry::new("PruneHoistedContexts", debug)); + 
context.timing.stop(); + } + + if env.config.enable_preserve_existing_memoization_guarantees + || env.config.validate_preserve_existing_memoization_guarantees + { + context.timing.start("ValidatePreservedManualMemoization"); + react_compiler_validation::validate_preserved_manual_memoization(&reactive_fn, &mut env); + if context.debug_enabled { + context.log_debug(DebugLogEntry::new( + "ValidatePreservedManualMemoization", + "ok".to_string(), + )); + } + context.timing.stop(); + } + + context.timing.start("codegen"); + let codegen_result = react_compiler_reactive_scopes::codegen_function( + &reactive_fn, + &mut env, + unique_identifiers, + fbt_operands, + )?; + context.timing.stop(); + + // Register the memo cache import as a side effect of codegen, matching TS + // behavior where addMemoCacheImport() is called during + // codegenReactiveFunction. This must happen BEFORE the env.has_errors() + // check so the import persists even when the pipeline returns Err (e.g., + // when validation errors are accumulated but codegen succeeded). + if codegen_result.memo_slots_used > 0 { + context.add_memo_cache_import(); + } + + if env.config.validate_source_locations { + super::validate_source_locations::validate_source_locations( + func, + &codegen_result, + &mut env, + ); + } + + // Simulate unexpected exception for testing (matches TS Pipeline.ts) + if env.config.throw_unknown_exception_testonly { + let mut err = CompilerError::new(); + err.push_error_detail(react_compiler_diagnostics::CompilerErrorDetail { + category: react_compiler_diagnostics::ErrorCategory::Invariant, + reason: "unexpected error".to_string(), + description: None, + loc: None, + suggestions: None, + }); + return Err(err); + } + + // Check for accumulated errors at the end of the pipeline + // (matches TS Pipeline.ts: env.hasErrors() → Err at the end) + if env.has_errors() { + return Err(env.take_errors()); + } + + // Re-compile outlined functions through the full pipeline. 
+ // This mirrors TS behavior where outlined functions from JSX outlining + // are pushed back onto the compilation queue and compiled as components. + let mut compiled_outlined: Vec = Vec::new(); + for o in codegen_result.outlined { + let outlined_codegen = CodegenFunction { + loc: o.func.loc, + id: o.func.id, + name_hint: o.func.name_hint, + params: o.func.params, + body: o.func.body, + generator: o.func.generator, + is_async: o.func.is_async, + memo_slots_used: o.func.memo_slots_used, + memo_blocks: o.func.memo_blocks, + memo_values: o.func.memo_values, + pruned_memo_blocks: o.func.pruned_memo_blocks, + pruned_memo_values: o.func.pruned_memo_values, + outlined: Vec::new(), + }; + if let Some(fn_type) = o.fn_type { + let fn_name = outlined_codegen.id.as_ref().map(|id| id.name.clone()); + match compile_outlined_fn( + outlined_codegen, + fn_name.as_deref(), + fn_type, + mode, + env_config, + context, + ) { + Ok(compiled) => { + compiled_outlined.push(OutlinedFunction { + func: compiled, + fn_type: Some(fn_type), + }); + } + Err(_err) => { + // If re-compilation fails, skip the outlined function + } + } + } else { + compiled_outlined.push(OutlinedFunction { + func: outlined_codegen, + fn_type: o.fn_type, + }); + } + } + + Ok(CodegenFunction { + loc: codegen_result.loc, + id: codegen_result.id, + name_hint: codegen_result.name_hint, + params: codegen_result.params, + body: codegen_result.body, + generator: codegen_result.generator, + is_async: codegen_result.is_async, + memo_slots_used: codegen_result.memo_slots_used, + memo_blocks: codegen_result.memo_blocks, + memo_values: codegen_result.memo_values, + pruned_memo_blocks: codegen_result.pruned_memo_blocks, + pruned_memo_values: codegen_result.pruned_memo_values, + outlined: compiled_outlined, + }) +} + +/// Compile an outlined function's codegen AST through the full pipeline. 
+/// +/// Creates a fresh Environment, builds a synthetic ScopeInfo with unique fake +/// positions for identifier resolution, lowers from AST to HIR, then runs +/// the full compilation pipeline. This mirrors the TS behavior where outlined +/// functions are inserted into the program AST and re-compiled from scratch. +pub fn compile_outlined_fn( + mut codegen_fn: CodegenFunction, + fn_name: Option<&str>, + fn_type: ReactFunctionType, + mode: CompilerOutputMode, + env_config: &EnvironmentConfig, + context: &mut ProgramContext, +) -> Result { + let mut env = Environment::with_config(env_config.clone()); + env.fn_type = fn_type; + env.output_mode = match mode { + CompilerOutputMode::Ssr => OutputMode::Ssr, + CompilerOutputMode::Client => OutputMode::Client, + CompilerOutputMode::Lint => OutputMode::Lint, + }; + + // Build a FunctionDeclaration from the codegen output + let mut outlined_decl = react_compiler_ast::statements::FunctionDeclaration { + base: react_compiler_ast::common::BaseNode::typed("FunctionDeclaration"), + id: codegen_fn.id.take(), + params: std::mem::take(&mut codegen_fn.params), + body: std::mem::replace( + &mut codegen_fn.body, + react_compiler_ast::statements::BlockStatement { + base: react_compiler_ast::common::BaseNode::typed("BlockStatement"), + body: Vec::new(), + directives: Vec::new(), + }, + ), + generator: codegen_fn.generator, + is_async: codegen_fn.is_async, + declare: None, + return_type: None, + type_parameters: None, + predicate: None, + component_declaration: false, + hook_declaration: false, + }; + + // Build scope info by assigning fake positions to all identifiers + let scope_info = build_outlined_scope_info(&mut outlined_decl); + + let func_node = react_compiler_lowering::FunctionNode::FunctionDeclaration(&outlined_decl); + let mut hir = react_compiler_lowering::lower(&func_node, fn_name, &scope_info, &mut env)?; + + if env.has_invariant_errors() { + return Err(env.take_invariant_errors()); + } + + run_pipeline_passes(&mut hir, 
&mut env, context) +} + +/// Build a ScopeInfo for an outlined function declaration by assigning unique +/// fake positions to all Identifier nodes and building the binding/reference +/// maps. +fn build_outlined_scope_info( + func: &mut react_compiler_ast::statements::FunctionDeclaration, +) -> react_compiler_ast::scope::ScopeInfo { + use std::collections::HashMap; + + use react_compiler_ast::scope::*; + + let mut pos: u32 = 1; // reserve 0 for the function itself + func.base.start = Some(0); + + let mut fn_bindings: HashMap = HashMap::new(); + let mut bindings_list: Vec = Vec::new(); + let mut ref_to_binding: indexmap::IndexMap = indexmap::IndexMap::new(); + + // Helper to add a binding + let _add_binding = + |name: &str, + kind: BindingKind, + p: u32, + fn_bindings: &mut HashMap, + bindings_list: &mut Vec, + ref_to_binding: &mut indexmap::IndexMap| { + if fn_bindings.contains_key(name) { + // Already exists, just add reference + let bid = fn_bindings[name]; + ref_to_binding.insert(p, bid); + return; + } + let binding_id = BindingId(bindings_list.len() as u32); + fn_bindings.insert(name.to_string(), binding_id); + bindings_list.push(BindingData { + id: binding_id, + name: name.to_string(), + kind, + scope: ScopeId(1), + declaration_type: "VariableDeclarator".to_string(), + declaration_start: Some(p), + import: None, + }); + ref_to_binding.insert(p, binding_id); + }; + + // Process params - add as Param bindings + for param in &mut func.params { + outlined_assign_pattern_positions( + param, + &mut pos, + BindingKind::Param, + &mut fn_bindings, + &mut bindings_list, + &mut ref_to_binding, + ); + } + + // Process body - walk all statements to assign positions and collect variable + // declarations + for stmt in &mut func.body.body { + outlined_assign_stmt_positions( + stmt, + &mut pos, + &mut fn_bindings, + &mut bindings_list, + &mut ref_to_binding, + ); + } + + let program_scope = ScopeData { + id: ScopeId(0), + parent: None, + kind: ScopeKind::Program, + bindings: 
HashMap::new(), + }; + let fn_scope = ScopeData { + id: ScopeId(1), + parent: Some(ScopeId(0)), + kind: ScopeKind::Function, + bindings: fn_bindings, + }; + + let mut node_to_scope: HashMap = HashMap::new(); + node_to_scope.insert(0, ScopeId(1)); + + ScopeInfo { + scopes: vec![program_scope, fn_scope], + bindings: bindings_list, + node_to_scope, + reference_to_binding: ref_to_binding, + program_scope: ScopeId(0), + } +} + +/// Assign positions to identifiers in a pattern and register as bindings. +fn outlined_assign_pattern_positions( + pattern: &mut react_compiler_ast::patterns::PatternLike, + pos: &mut u32, + kind: react_compiler_ast::scope::BindingKind, + fn_bindings: &mut std::collections::HashMap, + bindings_list: &mut Vec, + ref_to_binding: &mut indexmap::IndexMap, +) { + use react_compiler_ast::{patterns::PatternLike, scope::*}; + + match pattern { + PatternLike::Identifier(id) => { + let p = *pos; + *pos += 1; + id.base.start = Some(p); + // Add as a binding + if !fn_bindings.contains_key(&id.name) { + let binding_id = BindingId(bindings_list.len() as u32); + fn_bindings.insert(id.name.clone(), binding_id); + bindings_list.push(BindingData { + id: binding_id, + name: id.name.clone(), + kind: kind.clone(), + scope: ScopeId(1), + declaration_type: "VariableDeclarator".to_string(), + declaration_start: Some(p), + import: None, + }); + ref_to_binding.insert(p, binding_id); + } else { + let bid = fn_bindings[&id.name]; + ref_to_binding.insert(p, bid); + } + } + PatternLike::ObjectPattern(obj) => { + for prop in &mut obj.properties { + match prop { + react_compiler_ast::patterns::ObjectPatternProperty::ObjectProperty( + p_inner, + ) => { + outlined_assign_pattern_positions( + &mut p_inner.value, + pos, + kind.clone(), + fn_bindings, + bindings_list, + ref_to_binding, + ); + } + react_compiler_ast::patterns::ObjectPatternProperty::RestElement(r) => { + outlined_assign_pattern_positions( + &mut r.argument, + pos, + kind.clone(), + fn_bindings, + bindings_list, + 
ref_to_binding, + ); + } + } + } + } + PatternLike::ArrayPattern(arr) => { + for elem in arr.elements.iter_mut().flatten() { + outlined_assign_pattern_positions( + elem, + pos, + kind.clone(), + fn_bindings, + bindings_list, + ref_to_binding, + ); + } + } + PatternLike::AssignmentPattern(assign) => { + outlined_assign_pattern_positions( + &mut assign.left, + pos, + kind.clone(), + fn_bindings, + bindings_list, + ref_to_binding, + ); + } + PatternLike::RestElement(rest) => { + outlined_assign_pattern_positions( + &mut rest.argument, + pos, + kind.clone(), + fn_bindings, + bindings_list, + ref_to_binding, + ); + } + _ => {} + } +} + +/// Assign positions to identifiers in a statement body. +fn outlined_assign_stmt_positions( + stmt: &mut react_compiler_ast::statements::Statement, + pos: &mut u32, + fn_bindings: &mut std::collections::HashMap, + bindings_list: &mut Vec, + ref_to_binding: &mut indexmap::IndexMap, +) { + use react_compiler_ast::statements::Statement; + + match stmt { + Statement::VariableDeclaration(decl) => { + for declarator in &mut decl.declarations { + // Process init first (references) + if let Some(init) = &mut declarator.init { + outlined_assign_expr_positions(init, pos, fn_bindings, ref_to_binding); + } + // Process pattern (declarations) + outlined_assign_pattern_positions( + &mut declarator.id, + pos, + react_compiler_ast::scope::BindingKind::Let, + fn_bindings, + bindings_list, + ref_to_binding, + ); + } + } + Statement::ReturnStatement(ret) => { + if let Some(arg) = &mut ret.argument { + outlined_assign_expr_positions(arg, pos, fn_bindings, ref_to_binding); + } + } + Statement::ExpressionStatement(expr_stmt) => { + outlined_assign_expr_positions( + &mut expr_stmt.expression, + pos, + fn_bindings, + ref_to_binding, + ); + } + _ => {} + } +} + +/// Assign positions to identifiers in an expression. 
+fn outlined_assign_expr_positions( + expr: &mut react_compiler_ast::expressions::Expression, + pos: &mut u32, + fn_bindings: &std::collections::HashMap, + ref_to_binding: &mut indexmap::IndexMap, +) { + use react_compiler_ast::expressions::*; + + match expr { + Expression::Identifier(id) => { + let p = *pos; + *pos += 1; + id.base.start = Some(p); + if let Some(&bid) = fn_bindings.get(&id.name) { + ref_to_binding.insert(p, bid); + } + } + Expression::JSXElement(jsx) => { + // Opening tag + outlined_assign_jsx_name_positions( + &mut jsx.opening_element.name, + pos, + fn_bindings, + ref_to_binding, + ); + for attr in &mut jsx.opening_element.attributes { + match attr { + react_compiler_ast::jsx::JSXAttributeItem::JSXAttribute(a) => { + if let Some(val) = &mut a.value { + outlined_assign_jsx_val_positions( + val, + pos, + fn_bindings, + ref_to_binding, + ); + } + } + react_compiler_ast::jsx::JSXAttributeItem::JSXSpreadAttribute(s) => { + outlined_assign_expr_positions( + &mut s.argument, + pos, + fn_bindings, + ref_to_binding, + ); + } + } + } + for child in &mut jsx.children { + outlined_assign_jsx_child_positions(child, pos, fn_bindings, ref_to_binding); + } + } + Expression::JSXFragment(frag) => { + for child in &mut frag.children { + outlined_assign_jsx_child_positions(child, pos, fn_bindings, ref_to_binding); + } + } + _ => {} + } +} + +fn outlined_assign_jsx_name_positions( + name: &mut react_compiler_ast::jsx::JSXElementName, + pos: &mut u32, + fn_bindings: &std::collections::HashMap, + ref_to_binding: &mut indexmap::IndexMap, +) { + match name { + react_compiler_ast::jsx::JSXElementName::JSXIdentifier(id) => { + let p = *pos; + *pos += 1; + id.base.start = Some(p); + if let Some(&bid) = fn_bindings.get(&id.name) { + ref_to_binding.insert(p, bid); + } + } + react_compiler_ast::jsx::JSXElementName::JSXMemberExpression(m) => { + outlined_assign_jsx_member_positions(m, pos, fn_bindings, ref_to_binding); + } + _ => {} + } +} + +fn 
outlined_assign_jsx_member_positions( + member: &mut react_compiler_ast::jsx::JSXMemberExpression, + pos: &mut u32, + fn_bindings: &std::collections::HashMap, + ref_to_binding: &mut indexmap::IndexMap, +) { + match &mut *member.object { + react_compiler_ast::jsx::JSXMemberExprObject::JSXIdentifier(id) => { + let p = *pos; + *pos += 1; + id.base.start = Some(p); + if let Some(&bid) = fn_bindings.get(&id.name) { + ref_to_binding.insert(p, bid); + } + } + react_compiler_ast::jsx::JSXMemberExprObject::JSXMemberExpression(inner) => { + outlined_assign_jsx_member_positions(inner, pos, fn_bindings, ref_to_binding); + } + } +} + +fn outlined_assign_jsx_val_positions( + val: &mut react_compiler_ast::jsx::JSXAttributeValue, + pos: &mut u32, + fn_bindings: &std::collections::HashMap, + ref_to_binding: &mut indexmap::IndexMap, +) { + match val { + react_compiler_ast::jsx::JSXAttributeValue::JSXExpressionContainer(c) => { + if let react_compiler_ast::jsx::JSXExpressionContainerExpr::Expression(e) = + &mut c.expression + { + outlined_assign_expr_positions(e, pos, fn_bindings, ref_to_binding); + } + } + react_compiler_ast::jsx::JSXAttributeValue::JSXElement(el) => { + let mut expr = react_compiler_ast::expressions::Expression::JSXElement(el.clone()); + outlined_assign_expr_positions(&mut expr, pos, fn_bindings, ref_to_binding); + if let react_compiler_ast::expressions::Expression::JSXElement(new_el) = expr { + **el = *new_el; + } + } + _ => {} + } +} + +fn outlined_assign_jsx_child_positions( + child: &mut react_compiler_ast::jsx::JSXChild, + pos: &mut u32, + fn_bindings: &std::collections::HashMap, + ref_to_binding: &mut indexmap::IndexMap, +) { + match child { + react_compiler_ast::jsx::JSXChild::JSXExpressionContainer(c) => { + if let react_compiler_ast::jsx::JSXExpressionContainerExpr::Expression(e) = + &mut c.expression + { + outlined_assign_expr_positions(e, pos, fn_bindings, ref_to_binding); + } + } + react_compiler_ast::jsx::JSXChild::JSXElement(el) => { + let mut expr = 
+ react_compiler_ast::expressions::Expression::JSXElement(Box::new(*el.clone())); + outlined_assign_expr_positions(&mut expr, pos, fn_bindings, ref_to_binding); + if let react_compiler_ast::expressions::Expression::JSXElement(new_el) = expr { + **el = *new_el; + } + } + react_compiler_ast::jsx::JSXChild::JSXFragment(frag) => { + for inner in &mut frag.children { + outlined_assign_jsx_child_positions(inner, pos, fn_bindings, ref_to_binding); + } + } + _ => {} + } +} +// end of outlined function helpers + +/// Run the compilation pipeline passes on an HIR function (everything after +/// lowering). +/// +/// This is extracted from `compile_fn` to allow reuse for outlined functions. +/// Returns the compiled CodegenFunction on success. +fn run_pipeline_passes( + hir: &mut react_compiler_hir::HirFunction, + env: &mut Environment, + context: &mut ProgramContext, +) -> Result { + react_compiler_optimization::prune_maybe_throws(hir, &mut env.functions)?; + + react_compiler_optimization::drop_manual_memoization(hir, env)?; + + react_compiler_optimization::inline_immediately_invoked_function_expressions(hir, env); + + react_compiler_optimization::merge_consecutive_blocks::merge_consecutive_blocks( + hir, + &mut env.functions, + ); + + react_compiler_ssa::enter_ssa(hir, env).map_err(|diag| { + let loc = diag.primary_location().cloned(); + let mut err = CompilerError::new(); + err.push_error_detail(react_compiler_diagnostics::CompilerErrorDetail { + category: diag.category, + reason: diag.reason, + description: diag.description, + loc, + suggestions: diag.suggestions, + }); + err + })?; + + react_compiler_ssa::eliminate_redundant_phi(hir, env); + + react_compiler_optimization::constant_propagation(hir, env); + + react_compiler_typeinference::infer_types(hir, env)?; + + if env.enable_validations() { + if env.config.validate_hooks_usage { + react_compiler_validation::validate_hooks_usage(hir, env)?; + } + } + + react_compiler_optimization::optimize_props_method_calls(hir, env); 
+ + react_compiler_inference::analyse_functions(hir, env, &mut |_inner_func, _inner_env| {})?; + + if env.has_invariant_errors() { + return Err(env.take_invariant_errors()); + } + + react_compiler_inference::infer_mutation_aliasing_effects(hir, env, false)?; + + if env.output_mode == OutputMode::Ssr { + react_compiler_optimization::optimize_for_ssr(hir, env); + } + + react_compiler_optimization::dead_code_elimination(hir, env); + + react_compiler_optimization::prune_maybe_throws(hir, &mut env.functions)?; + + react_compiler_inference::infer_mutation_aliasing_ranges(hir, env, false)?; + + if env.enable_validations() { + react_compiler_validation::validate_locals_not_reassigned_after_render(hir, env); + + if env.config.validate_ref_access_during_render { + react_compiler_validation::validate_no_ref_access_in_render(hir, env); + } + + if env.config.validate_no_set_state_in_render { + react_compiler_validation::validate_no_set_state_in_render(hir, env)?; + } + + react_compiler_validation::validate_no_freezing_known_mutable_functions(hir, env); + } + + react_compiler_inference::infer_reactive_places(hir, env)?; + + if env.enable_validations() { + react_compiler_validation::validate_exhaustive_dependencies(hir, env)?; + } + + react_compiler_ssa::rewrite_instruction_kinds_based_on_reassignment(hir, env)?; + + if env.enable_memoization() { + react_compiler_inference::infer_reactive_scope_variables(hir, env)?; + } + + let fbt_operands = + react_compiler_inference::memoize_fbt_and_macro_operands_in_same_scope(hir, env); + + // Don't run outline_jsx on outlined functions (they're already outlined) + + if env.config.enable_name_anonymous_functions { + react_compiler_optimization::name_anonymous_functions(hir, env); + } + + if env.config.enable_function_outlining { + react_compiler_optimization::outline_functions(hir, env, &fbt_operands); + } + + react_compiler_inference::align_method_call_scopes(hir, env); + react_compiler_inference::align_object_method_scopes(hir, env); + + 
react_compiler_optimization::prune_unused_labels_hir(hir); + + react_compiler_inference::align_reactive_scopes_to_block_scopes_hir(hir, env); + react_compiler_inference::merge_overlapping_reactive_scopes_hir(hir, env); + + react_compiler_inference::build_reactive_scope_terminals_hir(hir, env); + react_compiler_inference::flatten_reactive_loops_hir(hir); + react_compiler_inference::flatten_scopes_with_hooks_or_use_hir(hir, env)?; + react_compiler_inference::propagate_scope_dependencies_hir(hir, env); + let mut reactive_fn = react_compiler_reactive_scopes::build_reactive_function(hir, env)?; + + react_compiler_reactive_scopes::assert_well_formed_break_targets(&reactive_fn, env); + + react_compiler_reactive_scopes::prune_unused_labels(&mut reactive_fn, env)?; + + react_compiler_reactive_scopes::assert_scope_instructions_within_scopes(&reactive_fn, env)?; + + react_compiler_reactive_scopes::prune_non_escaping_scopes(&mut reactive_fn, env)?; + react_compiler_reactive_scopes::prune_non_reactive_dependencies(&mut reactive_fn, env); + react_compiler_reactive_scopes::prune_unused_scopes(&mut reactive_fn, env)?; + react_compiler_reactive_scopes::merge_reactive_scopes_that_invalidate_together( + &mut reactive_fn, + env, + )?; + react_compiler_reactive_scopes::prune_always_invalidating_scopes(&mut reactive_fn, env)?; + react_compiler_reactive_scopes::propagate_early_returns(&mut reactive_fn, env); + react_compiler_reactive_scopes::prune_unused_lvalues(&mut reactive_fn, env); + react_compiler_reactive_scopes::promote_used_temporaries(&mut reactive_fn, env); + react_compiler_reactive_scopes::extract_scope_declarations_from_destructuring( + &mut reactive_fn, + env, + )?; + react_compiler_reactive_scopes::stabilize_block_ids(&mut reactive_fn, env); + + let unique_identifiers = + react_compiler_reactive_scopes::rename_variables(&mut reactive_fn, env); + for name in &unique_identifiers { + context.add_new_reference(name.clone()); + } + + 
react_compiler_reactive_scopes::prune_hoisted_contexts(&mut reactive_fn, env)?; + + if env.config.enable_preserve_existing_memoization_guarantees + || env.config.validate_preserve_existing_memoization_guarantees + { + react_compiler_validation::validate_preserved_manual_memoization(&reactive_fn, env); + } + + let codegen_result = react_compiler_reactive_scopes::codegen_function( + &reactive_fn, + env, + unique_identifiers, + fbt_operands, + )?; + + Ok(CodegenFunction { + loc: codegen_result.loc, + id: codegen_result.id, + name_hint: codegen_result.name_hint, + params: codegen_result.params, + body: codegen_result.body, + generator: codegen_result.generator, + is_async: codegen_result.is_async, + memo_slots_used: codegen_result.memo_slots_used, + memo_blocks: codegen_result.memo_blocks, + memo_values: codegen_result.memo_values, + pruned_memo_blocks: codegen_result.pruned_memo_blocks, + pruned_memo_values: codegen_result.pruned_memo_values, + outlined: codegen_result + .outlined + .into_iter() + .map(|o| OutlinedFunction { + func: CodegenFunction { + loc: o.func.loc, + id: o.func.id, + name_hint: o.func.name_hint, + params: o.func.params, + body: o.func.body, + generator: o.func.generator, + is_async: o.func.is_async, + memo_slots_used: o.func.memo_slots_used, + memo_blocks: o.func.memo_blocks, + memo_values: o.func.memo_values, + pruned_memo_blocks: o.func.pruned_memo_blocks, + pruned_memo_values: o.func.pruned_memo_values, + outlined: Vec::new(), + }, + fn_type: o.fn_type, + }) + .collect(), + }) +} + +/// Log CompilerError diagnostics as CompileError events, matching TS +/// `env.logErrors()` behavior. These are logged for telemetry/lint output but +/// not accumulated as compile errors. +fn log_errors_as_events(errors: &CompilerError, context: &mut ProgramContext) { + // Use the source_filename from the AST (set by parser's sourceFilename option). + // This is stored on the Environment during lowering. 
+ let source_filename = context.source_filename(); + for detail in &errors.details { + let detail_info = match detail { + react_compiler_diagnostics::CompilerErrorOrDiagnostic::Diagnostic(d) => { + let items: Option> = { + let v: Vec = d + .details + .iter() + .map(|item| match item { + react_compiler_diagnostics::CompilerDiagnosticDetail::Error { + loc, + message, + identifier_name, + } => CompilerErrorItemInfo { + kind: "error".to_string(), + loc: loc.as_ref().map(|l| LoggerSourceLocation { + start: LoggerPosition { + line: l.start.line, + column: l.start.column, + index: l.start.index, + }, + end: LoggerPosition { + line: l.end.line, + column: l.end.column, + index: l.end.index, + }, + filename: source_filename.clone(), + identifier_name: identifier_name.clone(), + }), + message: message.clone(), + }, + react_compiler_diagnostics::CompilerDiagnosticDetail::Hint { + message, + } => CompilerErrorItemInfo { + kind: "hint".to_string(), + loc: None, + message: Some(message.clone()), + }, + }) + .collect(); + if v.is_empty() { + None + } else { + Some(v) + } + }; + CompilerErrorDetailInfo { + category: format!("{:?}", d.category), + reason: d.reason.clone(), + description: d.description.clone(), + severity: format!("{:?}", d.logged_severity()), + suggestions: None, + details: items, + loc: None, + } + } + react_compiler_diagnostics::CompilerErrorOrDiagnostic::ErrorDetail(d) => { + CompilerErrorDetailInfo { + category: format!("{:?}", d.category), + reason: d.reason.clone(), + description: d.description.clone(), + severity: format!("{:?}", d.logged_severity()), + suggestions: None, + details: None, + loc: None, + } + } + }; + context.log_event(super::compile_result::LoggerEvent::CompileError { + fn_loc: None, + detail: detail_info, + }); + } +} diff --git a/crates/react_compiler/src/entrypoint/plugin_options.rs b/crates/react_compiler/src/entrypoint/plugin_options.rs new file mode 100644 index 000000000000..d335699a309c --- /dev/null +++ 
b/crates/react_compiler/src/entrypoint/plugin_options.rs @@ -0,0 +1,120 @@ +use react_compiler_hir::environment_config::EnvironmentConfig; +use serde::{Deserialize, Serialize}; + +/// Target configuration for the compiler +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum CompilerTarget { + /// Standard React version target + Version(String), // "17", "18", "19" + /// Meta-internal target with custom runtime module + MetaInternal { + kind: String, // "donotuse_meta_internal" + #[serde(rename = "runtimeModule")] + runtime_module: String, + }, +} + +/// Gating configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GatingConfig { + pub source: String, + #[serde(rename = "importSpecifierName")] + pub import_specifier_name: String, +} + +/// Dynamic gating configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DynamicGatingConfig { + pub source: String, +} + +/// Serializable plugin options, pre-resolved by the JS shim. +/// JS-only values (sources function, logger, etc.) are resolved before +/// being sent to Rust. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PluginOptions { + // Pre-resolved by JS + pub should_compile: bool, + pub enable_reanimated: bool, + pub is_dev: bool, + pub filename: Option, + + // Pass-through options + #[serde(default = "default_compilation_mode")] + pub compilation_mode: String, + #[serde(default = "default_panic_threshold")] + pub panic_threshold: String, + #[serde(default = "default_target")] + pub target: CompilerTarget, + #[serde(default)] + pub gating: Option, + #[serde(default)] + pub dynamic_gating: Option, + #[serde(default)] + pub no_emit: bool, + #[serde(default)] + pub output_mode: Option, + #[serde(default)] + pub eslint_suppression_rules: Option>, + #[serde(default = "default_true")] + pub flow_suppressions: bool, + #[serde(default)] + pub ignore_use_no_forget: bool, + #[serde(default)] + pub custom_opt_out_directives: Option>, + #[serde(default)] + pub environment: EnvironmentConfig, + + /// Source code of the file being compiled (passed from Babel plugin for + /// fast refresh hash). + #[serde(default, rename = "__sourceCode")] + pub source_code: Option, + + /// Enable profiling timing data collection. + #[serde(default, rename = "__profiling")] + pub profiling: bool, + + /// Enable debug logging (HIR formatting after each pass). + /// Only set to true when a logger with debugLogIRs is configured on the JS + /// side. + #[serde(default, rename = "__debug")] + pub debug: bool, +} + +fn default_compilation_mode() -> String { + "infer".to_string() +} + +fn default_panic_threshold() -> String { + "none".to_string() +} + +fn default_target() -> CompilerTarget { + CompilerTarget::Version("19".to_string()) +} + +fn default_true() -> bool { + true +} + +/// Output mode for the compiler, derived from PluginOptions. +/// Matches the TS `compilerOutputMode` logic in Program.ts. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CompilerOutputMode { + Ssr, + Client, + Lint, +} + +impl CompilerOutputMode { + pub fn from_opts(opts: &PluginOptions) -> Self { + match opts.output_mode.as_deref() { + Some("ssr") => Self::Ssr, + Some("lint") => Self::Lint, + _ if opts.no_emit => Self::Lint, + _ => Self::Client, + } + } +} diff --git a/crates/react_compiler/src/entrypoint/program.rs b/crates/react_compiler/src/entrypoint/program.rs new file mode 100644 index 000000000000..e8abe2b6f0bd --- /dev/null +++ b/crates/react_compiler/src/entrypoint/program.rs @@ -0,0 +1,3794 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Main entrypoint for the React Compiler. +//! +//! This module is a port of Program.ts from the TypeScript compiler. It +//! orchestrates the compilation of a program by: +//! 1. Checking if compilation should be skipped +//! 2. Validating restricted imports +//! 3. Finding program-level suppressions +//! 4. Discovering functions to compile (components, hooks) +//! 5. Processing each function through the compilation pipeline +//! 6. 
Applying compiled functions back to the AST + +use std::collections::{HashMap, HashSet}; + +use react_compiler_ast::{ + common::BaseNode, + declarations::{ + Declaration, ExportDefaultDecl, ExportDefaultDeclaration, ImportSpecifier, ModuleExportName, + }, + expressions::*, + patterns::PatternLike, + scope::{ScopeId, ScopeInfo}, + statements::*, + visitor::{walk_program_mut, AstWalker, MutVisitor, VisitResult, Visitor}, + File, Program, +}; +use react_compiler_diagnostics::{ + CompilerError, CompilerErrorDetail, CompilerErrorOrDiagnostic, ErrorCategory, SourceLocation, +}; +use react_compiler_hir::{environment_config::EnvironmentConfig, ReactFunctionType}; +use react_compiler_lowering::FunctionNode; +use regex::Regex; + +use super::{ + compile_result::{ + BindingRenameInfo, CodegenFunction, CompileResult, CompilerErrorDetailInfo, + CompilerErrorInfo, CompilerErrorItemInfo, DebugLogEntry, LoggerEvent, LoggerPosition, + LoggerSourceLocation, LoggerSuggestionInfo, LoggerSuggestionOp, OrderedLogItem, + }, + imports::{ + add_imports_to_program, get_react_compiler_runtime_module, validate_restricted_imports, + ProgramContext, + }, + pipeline, + plugin_options::{CompilerOutputMode, GatingConfig, PluginOptions}, + suppression::{ + filter_suppressions_that_affect_function, find_program_suppressions, + suppressions_to_compiler_error, SuppressionRange, + }, +}; + +// ----------------------------------------------------------------------- +// Constants +// ----------------------------------------------------------------------- + +const DEFAULT_ESLINT_SUPPRESSIONS: &[&str] = + &["react-hooks/exhaustive-deps", "react-hooks/rules-of-hooks"]; + +/// Directives that opt a function into memoization +const OPT_IN_DIRECTIVES: &[&str] = &["use forget", "use memo"]; + +/// Directives that opt a function out of memoization +const OPT_OUT_DIRECTIVES: &[&str] = &["use no forget", "use no memo"]; + +// ----------------------------------------------------------------------- +// Internal types 
+// ----------------------------------------------------------------------- + +/// A function found in the program that should be compiled +#[allow(dead_code)] +struct CompileSource<'a> { + kind: CompileSourceKind, + fn_node: FunctionNode<'a>, + /// Location of this function in the AST for logging + fn_name: Option, + fn_loc: Option, + /// Original AST source location (with index and filename) for logger + /// events. + fn_ast_loc: Option, + fn_start: Option, + fn_end: Option, + fn_type: ReactFunctionType, + /// Directives from the function body (for opt-in/opt-out checks) + body_directives: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CompileSourceKind { + Original, + #[allow(dead_code)] + Outlined, +} + +// ----------------------------------------------------------------------- +// Directive helpers +// ----------------------------------------------------------------------- + +/// Check if any opt-in directive is present in the given directives. +/// Returns the first matching directive, or None. +/// +/// Also checks for dynamic gating directives (`use memo if(...)`) +fn try_find_directive_enabling_memoization<'a>( + directives: &'a [Directive], + opts: &PluginOptions, +) -> Result, CompilerError> { + // Check standard opt-in directives + let opt_in = directives + .iter() + .find(|d| OPT_IN_DIRECTIVES.contains(&d.value.value.as_str())); + if let Some(directive) = opt_in { + return Ok(Some(directive)); + } + + // Check dynamic gating directives + match find_directives_dynamic_gating(directives, opts) { + Ok(Some(result)) => Ok(Some(result.directive)), + Ok(None) => Ok(None), + Err(e) => Err(e), + } +} + +/// Check if any opt-out directive is present in the given directives. 
+fn find_directive_disabling_memoization<'a>( + directives: &'a [Directive], + opts: &PluginOptions, +) -> Option<&'a Directive> { + if let Some(ref custom_directives) = opts.custom_opt_out_directives { + directives + .iter() + .find(|d| custom_directives.contains(&d.value.value)) + } else { + directives + .iter() + .find(|d| OPT_OUT_DIRECTIVES.contains(&d.value.value.as_str())) + } +} + +/// Result of a dynamic gating directive parse. +struct DynamicGatingResult<'a> { + #[allow(dead_code)] + directive: &'a Directive, + gating: GatingConfig, +} + +/// Check for dynamic gating directives like `use memo if(identifier)`. +/// Returns the directive and gating config if found, or an error if malformed. +fn find_directives_dynamic_gating<'a>( + directives: &'a [Directive], + opts: &PluginOptions, +) -> Result>, CompilerError> { + let dynamic_gating = match &opts.dynamic_gating { + Some(dg) => dg, + None => return Ok(None), + }; + + let pattern = Regex::new(r"^use memo if\(([^\)]*)\)$").expect("Invalid dynamic gating regex"); + + let mut errors: Vec = Vec::new(); + let mut matches: Vec<(&'a Directive, String)> = Vec::new(); + + for directive in directives { + if let Some(caps) = pattern.captures(&directive.value.value) { + if let Some(m) = caps.get(1) { + let ident = m.as_str(); + if is_valid_identifier(ident) { + matches.push((directive, ident.to_string())); + } else { + let mut detail = CompilerErrorDetail::new( + ErrorCategory::Gating, + "Dynamic gating directive is not a valid JavaScript identifier", + ) + .with_description(format!("Found '{}'", directive.value.value)); + detail.loc = directive.base.loc.as_ref().map(convert_loc); + errors.push(detail); + } + } + } + } + + if !errors.is_empty() { + let mut err = CompilerError::new(); + for e in errors { + err.push_error_detail(e); + } + return Err(err); + } + + if matches.len() > 1 { + let names: Vec = matches.iter().map(|(d, _)| d.value.value.clone()).collect(); + let mut err = CompilerError::new(); + let mut detail = 
CompilerErrorDetail::new( + ErrorCategory::Gating, + "Multiple dynamic gating directives found", + ) + .with_description(format!( + "Expected a single directive but found [{}]", + names.join(", ") + )); + detail.loc = matches[0].0.base.loc.as_ref().map(convert_loc); + err.push_error_detail(detail); + return Err(err); + } + + if matches.len() == 1 { + Ok(Some(DynamicGatingResult { + directive: matches[0].0, + gating: GatingConfig { + source: dynamic_gating.source.clone(), + import_specifier_name: matches[0].1.clone(), + }, + })) + } else { + Ok(None) + } +} + +/// Simple check for valid JavaScript identifier (alphanumeric + underscore + $, +/// starting with letter/$/_ ) Also rejects reserved words like `true`, `false`, +/// `null`, etc. +fn is_valid_identifier(s: &str) -> bool { + if s.is_empty() { + return false; + } + let mut chars = s.chars(); + let first = chars.next().unwrap(); + if !first.is_alphabetic() && first != '_' && first != '$' { + return false; + } + if !chars.all(|c| c.is_alphanumeric() || c == '_' || c == '$') { + return false; + } + // Check for reserved words (matching Babel's t.isValidIdentifier) + !matches!( + s, + "break" + | "case" + | "catch" + | "continue" + | "debugger" + | "default" + | "do" + | "else" + | "finally" + | "for" + | "function" + | "if" + | "in" + | "instanceof" + | "new" + | "return" + | "switch" + | "this" + | "throw" + | "try" + | "typeof" + | "var" + | "void" + | "while" + | "with" + | "class" + | "const" + | "enum" + | "export" + | "extends" + | "import" + | "super" + | "implements" + | "interface" + | "let" + | "package" + | "private" + | "protected" + | "public" + | "static" + | "yield" + | "null" + | "true" + | "false" + | "delete" + ) +} + +// ----------------------------------------------------------------------- +// Name helpers +// ----------------------------------------------------------------------- + +/// Check if a string follows the React hook naming convention (use[A-Z0-9]...). 
+fn is_hook_name(s: &str) -> bool { + let bytes = s.as_bytes(); + bytes.len() >= 4 + && bytes[0] == b'u' + && bytes[1] == b's' + && bytes[2] == b'e' + && bytes + .get(3) + .map_or(false, |c| c.is_ascii_uppercase() || c.is_ascii_digit()) +} + +/// Check if a name looks like a React component (starts with uppercase letter). +fn is_component_name(name: &str) -> bool { + name.chars() + .next() + .map_or(false, |c| c.is_ascii_uppercase()) +} + +/// Check if an expression is a hook call (identifier with hook name, or +/// member expression `PascalCase.useHook`). +fn expr_is_hook(expr: &Expression) -> bool { + match expr { + Expression::Identifier(id) => is_hook_name(&id.name), + Expression::MemberExpression(member) => { + if member.computed { + return false; + } + // Property must be a hook name + if !expr_is_hook(&member.property) { + return false; + } + // Object must be a PascalCase identifier + if let Expression::Identifier(obj) = member.object.as_ref() { + obj.name + .chars() + .next() + .map_or(false, |c| c.is_ascii_uppercase()) + } else { + false + } + } + _ => false, + } +} + +/// Check if an expression is a React API call (e.g., `forwardRef` or +/// `React.forwardRef`). +#[allow(dead_code)] +fn is_react_api(expr: &Expression, function_name: &str) -> bool { + match expr { + Expression::Identifier(id) => id.name == function_name, + Expression::MemberExpression(member) => { + if let Expression::Identifier(obj) = member.object.as_ref() { + if obj.name == "React" { + if let Expression::Identifier(prop) = member.property.as_ref() { + return prop.name == function_name; + } + } + } + false + } + _ => false, + } +} + +/// Get the inferred function name from a function's context. +/// +/// For FunctionDeclaration: uses the `id` field. +/// For FunctionExpression/ArrowFunctionExpression: infers from parent context +/// (VariableDeclarator, etc.) which is passed explicitly since we don't have +/// Babel paths. 
+fn get_function_name_from_id(id: Option<&Identifier>) -> Option { + id.map(|id| id.name.clone()) +} + +// ----------------------------------------------------------------------- +// AST traversal helpers +// ----------------------------------------------------------------------- + +/// Check if an expression is a "non-node" return value (indicating the function +/// is not a React component). This matches the TS `isNonNode` function. +fn is_non_node(expr: &Expression) -> bool { + matches!( + expr, + Expression::ObjectExpression(_) + | Expression::ArrowFunctionExpression(_) + | Expression::FunctionExpression(_) + | Expression::BigIntLiteral(_) + | Expression::ClassExpression(_) + | Expression::NewExpression(_) + ) +} + +/// Recursively check if a function body returns a non-React-node value. +/// Walks all return statements in the function (not in nested functions). +fn returns_non_node_in_stmts(stmts: &[Statement]) -> bool { + for stmt in stmts { + if returns_non_node_in_stmt(stmt) { + return true; + } + } + false +} + +fn returns_non_node_in_stmt(stmt: &Statement) -> bool { + match stmt { + Statement::ReturnStatement(ret) => { + if let Some(ref arg) = ret.argument { + return is_non_node(arg); + } + false + } + Statement::BlockStatement(block) => returns_non_node_in_stmts(&block.body), + Statement::IfStatement(if_stmt) => { + returns_non_node_in_stmt(&if_stmt.consequent) + || if_stmt + .alternate + .as_ref() + .map_or(false, |alt| returns_non_node_in_stmt(alt)) + } + Statement::ForStatement(for_stmt) => returns_non_node_in_stmt(&for_stmt.body), + Statement::WhileStatement(while_stmt) => returns_non_node_in_stmt(&while_stmt.body), + Statement::DoWhileStatement(do_while) => returns_non_node_in_stmt(&do_while.body), + Statement::ForInStatement(for_in) => returns_non_node_in_stmt(&for_in.body), + Statement::ForOfStatement(for_of) => returns_non_node_in_stmt(&for_of.body), + Statement::SwitchStatement(switch) => { + for case in &switch.cases { + if 
returns_non_node_in_stmts(&case.consequent) { + return true; + } + } + false + } + Statement::TryStatement(try_stmt) => { + if returns_non_node_in_stmts(&try_stmt.block.body) { + return true; + } + if let Some(ref handler) = try_stmt.handler { + if returns_non_node_in_stmts(&handler.body.body) { + return true; + } + } + if let Some(ref finalizer) = try_stmt.finalizer { + if returns_non_node_in_stmts(&finalizer.body) { + return true; + } + } + false + } + Statement::LabeledStatement(labeled) => returns_non_node_in_stmt(&labeled.body), + Statement::WithStatement(with) => returns_non_node_in_stmt(&with.body), + // Skip nested function/class declarations -- they have their own returns + Statement::FunctionDeclaration(_) | Statement::ClassDeclaration(_) => false, + _ => false, + } +} + +/// Check if a function returns non-node values. +/// For arrow functions with expression body, checks the expression directly. +/// For block bodies, walks the statements. +fn returns_non_node_fn(params: &[PatternLike], body: &FunctionBody) -> bool { + let _ = params; + match body { + FunctionBody::Block(block) => returns_non_node_in_stmts(&block.body), + FunctionBody::Expression(expr) => is_non_node(expr), + } +} + +/// Check if a function body calls hooks or creates JSX. 
+/// Traverses the function body (not nested functions) looking for: +/// - CallExpression where callee is a hook +/// - JSXElement or JSXFragment +fn calls_hooks_or_creates_jsx_in_stmts(stmts: &[Statement]) -> bool { + for stmt in stmts { + if calls_hooks_or_creates_jsx_in_stmt(stmt) { + return true; + } + } + false +} + +fn calls_hooks_or_creates_jsx_in_stmt(stmt: &Statement) -> bool { + match stmt { + Statement::ExpressionStatement(expr_stmt) => { + calls_hooks_or_creates_jsx_in_expr(&expr_stmt.expression) + } + Statement::ReturnStatement(ret) => { + if let Some(ref arg) = ret.argument { + calls_hooks_or_creates_jsx_in_expr(arg) + } else { + false + } + } + Statement::VariableDeclaration(var_decl) => { + for decl in &var_decl.declarations { + if let Some(ref init) = decl.init { + if calls_hooks_or_creates_jsx_in_expr(init) { + return true; + } + } + } + false + } + Statement::BlockStatement(block) => calls_hooks_or_creates_jsx_in_stmts(&block.body), + Statement::IfStatement(if_stmt) => { + calls_hooks_or_creates_jsx_in_expr(&if_stmt.test) + || calls_hooks_or_creates_jsx_in_stmt(&if_stmt.consequent) + || if_stmt + .alternate + .as_ref() + .map_or(false, |alt| calls_hooks_or_creates_jsx_in_stmt(alt)) + } + Statement::ForStatement(for_stmt) => { + if let Some(ref init) = for_stmt.init { + match init.as_ref() { + ForInit::Expression(expr) => { + if calls_hooks_or_creates_jsx_in_expr(expr) { + return true; + } + } + ForInit::VariableDeclaration(var_decl) => { + for decl in &var_decl.declarations { + if let Some(ref init) = decl.init { + if calls_hooks_or_creates_jsx_in_expr(init) { + return true; + } + } + } + } + } + } + if let Some(ref test) = for_stmt.test { + if calls_hooks_or_creates_jsx_in_expr(test) { + return true; + } + } + if let Some(ref update) = for_stmt.update { + if calls_hooks_or_creates_jsx_in_expr(update) { + return true; + } + } + calls_hooks_or_creates_jsx_in_stmt(&for_stmt.body) + } + Statement::WhileStatement(while_stmt) => { + 
calls_hooks_or_creates_jsx_in_expr(&while_stmt.test) + || calls_hooks_or_creates_jsx_in_stmt(&while_stmt.body) + } + Statement::DoWhileStatement(do_while) => { + calls_hooks_or_creates_jsx_in_stmt(&do_while.body) + || calls_hooks_or_creates_jsx_in_expr(&do_while.test) + } + Statement::ForInStatement(for_in) => { + calls_hooks_or_creates_jsx_in_expr(&for_in.right) + || calls_hooks_or_creates_jsx_in_stmt(&for_in.body) + } + Statement::ForOfStatement(for_of) => { + calls_hooks_or_creates_jsx_in_expr(&for_of.right) + || calls_hooks_or_creates_jsx_in_stmt(&for_of.body) + } + Statement::SwitchStatement(switch) => { + if calls_hooks_or_creates_jsx_in_expr(&switch.discriminant) { + return true; + } + for case in &switch.cases { + if let Some(ref test) = case.test { + if calls_hooks_or_creates_jsx_in_expr(test) { + return true; + } + } + if calls_hooks_or_creates_jsx_in_stmts(&case.consequent) { + return true; + } + } + false + } + Statement::ThrowStatement(throw) => calls_hooks_or_creates_jsx_in_expr(&throw.argument), + Statement::TryStatement(try_stmt) => { + if calls_hooks_or_creates_jsx_in_stmts(&try_stmt.block.body) { + return true; + } + if let Some(ref handler) = try_stmt.handler { + if calls_hooks_or_creates_jsx_in_stmts(&handler.body.body) { + return true; + } + } + if let Some(ref finalizer) = try_stmt.finalizer { + if calls_hooks_or_creates_jsx_in_stmts(&finalizer.body) { + return true; + } + } + false + } + Statement::LabeledStatement(labeled) => calls_hooks_or_creates_jsx_in_stmt(&labeled.body), + Statement::WithStatement(with) => { + calls_hooks_or_creates_jsx_in_expr(&with.object) + || calls_hooks_or_creates_jsx_in_stmt(&with.body) + } + // Skip nested function/class declarations + Statement::FunctionDeclaration(_) | Statement::ClassDeclaration(_) => false, + _ => false, + } +} + +fn calls_hooks_or_creates_jsx_in_expr(expr: &Expression) -> bool { + match expr { + // JSX creates + Expression::JSXElement(_) | Expression::JSXFragment(_) => true, + + // Hook 
calls + Expression::CallExpression(call) => { + if expr_is_hook(&call.callee) { + return true; + } + // Also check arguments for JSX/hooks (but not nested functions) + if calls_hooks_or_creates_jsx_in_expr(&call.callee) { + return true; + } + for arg in &call.arguments { + // Skip function arguments -- they are nested functions + if matches!( + arg, + Expression::ArrowFunctionExpression(_) | Expression::FunctionExpression(_) + ) { + continue; + } + if calls_hooks_or_creates_jsx_in_expr(arg) { + return true; + } + } + false + } + Expression::OptionalCallExpression(call) => { + // Note: OptionalCallExpression is NOT treated as a hook call for + // the purpose of determining function type. The TS code only checks + // regular CallExpression nodes in callsHooksOrCreatesJsx. + // We still recurse into the callee and arguments to find other + // hook calls or JSX. + if calls_hooks_or_creates_jsx_in_expr(&call.callee) { + return true; + } + for arg in &call.arguments { + if matches!( + arg, + Expression::ArrowFunctionExpression(_) | Expression::FunctionExpression(_) + ) { + continue; + } + if calls_hooks_or_creates_jsx_in_expr(arg) { + return true; + } + } + false + } + + // Binary/logical + Expression::BinaryExpression(bin) => { + calls_hooks_or_creates_jsx_in_expr(&bin.left) + || calls_hooks_or_creates_jsx_in_expr(&bin.right) + } + Expression::LogicalExpression(log) => { + calls_hooks_or_creates_jsx_in_expr(&log.left) + || calls_hooks_or_creates_jsx_in_expr(&log.right) + } + Expression::ConditionalExpression(cond) => { + calls_hooks_or_creates_jsx_in_expr(&cond.test) + || calls_hooks_or_creates_jsx_in_expr(&cond.consequent) + || calls_hooks_or_creates_jsx_in_expr(&cond.alternate) + } + Expression::AssignmentExpression(assign) => { + calls_hooks_or_creates_jsx_in_expr(&assign.right) + } + Expression::SequenceExpression(seq) => seq + .expressions + .iter() + .any(|e| calls_hooks_or_creates_jsx_in_expr(e)), + Expression::UnaryExpression(unary) => 
calls_hooks_or_creates_jsx_in_expr(&unary.argument), + Expression::UpdateExpression(update) => { + calls_hooks_or_creates_jsx_in_expr(&update.argument) + } + Expression::MemberExpression(member) => { + calls_hooks_or_creates_jsx_in_expr(&member.object) + || calls_hooks_or_creates_jsx_in_expr(&member.property) + } + Expression::OptionalMemberExpression(member) => { + calls_hooks_or_creates_jsx_in_expr(&member.object) + || calls_hooks_or_creates_jsx_in_expr(&member.property) + } + Expression::SpreadElement(spread) => calls_hooks_or_creates_jsx_in_expr(&spread.argument), + Expression::AwaitExpression(await_expr) => { + calls_hooks_or_creates_jsx_in_expr(&await_expr.argument) + } + Expression::YieldExpression(yield_expr) => yield_expr + .argument + .as_ref() + .map_or(false, |arg| calls_hooks_or_creates_jsx_in_expr(arg)), + Expression::TaggedTemplateExpression(tagged) => { + calls_hooks_or_creates_jsx_in_expr(&tagged.tag) + } + Expression::TemplateLiteral(tl) => tl + .expressions + .iter() + .any(|e| calls_hooks_or_creates_jsx_in_expr(e)), + Expression::ArrayExpression(arr) => arr.elements.iter().any(|e| { + e.as_ref() + .map_or(false, |e| calls_hooks_or_creates_jsx_in_expr(e)) + }), + Expression::ObjectExpression(obj) => obj.properties.iter().any(|prop| match prop { + ObjectExpressionProperty::ObjectProperty(p) => { + calls_hooks_or_creates_jsx_in_expr(&p.value) + } + ObjectExpressionProperty::SpreadElement(s) => { + calls_hooks_or_creates_jsx_in_expr(&s.argument) + } + // ObjectMethod: traverse into its body to find hooks/JSX. + // This matches the TS behavior where Babel's traverse enters + // ObjectMethod (only FunctionDeclaration, FunctionExpression, + // and ArrowFunctionExpression are skipped). 
+ ObjectExpressionProperty::ObjectMethod(m) => { + calls_hooks_or_creates_jsx_in_stmts(&m.body.body) + } + }), + Expression::ParenthesizedExpression(paren) => { + calls_hooks_or_creates_jsx_in_expr(&paren.expression) + } + Expression::TSAsExpression(ts) => calls_hooks_or_creates_jsx_in_expr(&ts.expression), + Expression::TSSatisfiesExpression(ts) => calls_hooks_or_creates_jsx_in_expr(&ts.expression), + Expression::TSNonNullExpression(ts) => calls_hooks_or_creates_jsx_in_expr(&ts.expression), + Expression::TSTypeAssertion(ts) => calls_hooks_or_creates_jsx_in_expr(&ts.expression), + Expression::TSInstantiationExpression(ts) => { + calls_hooks_or_creates_jsx_in_expr(&ts.expression) + } + Expression::TypeCastExpression(tc) => calls_hooks_or_creates_jsx_in_expr(&tc.expression), + Expression::NewExpression(new) => { + if calls_hooks_or_creates_jsx_in_expr(&new.callee) { + return true; + } + new.arguments.iter().any(|a| { + if matches!( + a, + Expression::ArrowFunctionExpression(_) | Expression::FunctionExpression(_) + ) { + return false; + } + calls_hooks_or_creates_jsx_in_expr(a) + }) + } + + // Skip nested functions + Expression::ArrowFunctionExpression(_) | Expression::FunctionExpression(_) => false, + + // Leaf expressions + _ => false, + } +} + +/// Check if a function body calls hooks or creates JSX. +fn calls_hooks_or_creates_jsx(body: &FunctionBody) -> bool { + match body { + FunctionBody::Block(block) => calls_hooks_or_creates_jsx_in_stmts(&block.body), + FunctionBody::Expression(expr) => calls_hooks_or_creates_jsx_in_expr(expr), + } +} + +/// Check if the function parameters are valid for a React component. +/// Components can have 0 params, 1 param (props), or 2 params (props + ref). +/// Check if a parameter's type annotation is valid for a React component prop. +/// Returns false for primitive type annotations that indicate this is NOT a +/// component. 
+fn is_valid_props_annotation(param: &PatternLike) -> bool { + let type_annotation = match param { + PatternLike::Identifier(id) => id.type_annotation.as_deref(), + PatternLike::ObjectPattern(op) => op.type_annotation.as_deref(), + PatternLike::ArrayPattern(ap) => ap.type_annotation.as_deref(), + PatternLike::AssignmentPattern(ap) => ap.type_annotation.as_deref(), + PatternLike::RestElement(re) => re.type_annotation.as_deref(), + PatternLike::MemberExpression(_) => None, + }; + let annot = match type_annotation { + Some(val) => val, + None => return true, // No annotation = valid + }; + let annot_type = match annot.get("type").and_then(|v| v.as_str()) { + Some(t) => t, + None => return true, + }; + match annot_type { + "TSTypeAnnotation" => { + let inner_type = annot + .get("typeAnnotation") + .and_then(|v| v.get("type")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + !matches!( + inner_type, + "TSArrayType" + | "TSBigIntKeyword" + | "TSBooleanKeyword" + | "TSConstructorType" + | "TSFunctionType" + | "TSLiteralType" + | "TSNeverKeyword" + | "TSNumberKeyword" + | "TSStringKeyword" + | "TSSymbolKeyword" + | "TSTupleType" + ) + } + "TypeAnnotation" => { + let inner_type = annot + .get("typeAnnotation") + .and_then(|v| v.get("type")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + !matches!( + inner_type, + "ArrayTypeAnnotation" + | "BooleanLiteralTypeAnnotation" + | "BooleanTypeAnnotation" + | "EmptyTypeAnnotation" + | "FunctionTypeAnnotation" + | "NullLiteralTypeAnnotation" + | "NumberLiteralTypeAnnotation" + | "NumberTypeAnnotation" + | "StringLiteralTypeAnnotation" + | "StringTypeAnnotation" + | "SymbolTypeAnnotation" + | "ThisTypeAnnotation" + | "TupleTypeAnnotation" + ) + } + "Noop" => true, + _ => true, + } +} + +fn is_valid_component_params(params: &[PatternLike]) -> bool { + if params.is_empty() { + return true; + } + if params.len() > 2 { + return false; + } + // First param cannot be a rest element + if matches!(params[0], PatternLike::RestElement(_)) { + 
return false; + } + // Check type annotation on first param + if !is_valid_props_annotation(¶ms[0]) { + return false; + } + if params.len() == 1 { + return true; + } + // If second param exists, it should look like a ref + if let PatternLike::Identifier(ref id) = params[1] { + id.name.contains("ref") || id.name.contains("Ref") + } else { + false + } +} + +// ----------------------------------------------------------------------- +// Unified function body type for traversal +// ----------------------------------------------------------------------- + +/// Abstraction over function body types to simplify traversal code +enum FunctionBody<'a> { + Block(&'a BlockStatement), + Expression(&'a Expression), +} + +// ----------------------------------------------------------------------- +// Function type detection +// ----------------------------------------------------------------------- + +/// Determine the React function type for a function, given the compilation mode +/// and the function's name and context. +/// +/// This is the Rust equivalent of `getReactFunctionType` in Program.ts. +fn get_react_function_type( + name: Option<&str>, + params: &[PatternLike], + body: &FunctionBody, + body_directives: &[Directive], + is_declaration: bool, + parent_callee_name: Option<&str>, + opts: &PluginOptions, + is_component_declaration: bool, + is_hook_declaration: bool, +) -> Option { + // Check for opt-in directives in the function body + if let FunctionBody::Block(_) = body { + let opt_in = try_find_directive_enabling_memoization(body_directives, opts); + if let Ok(Some(_)) = opt_in { + // If there's an opt-in directive, use name heuristics but fall back to Other + return Some( + get_component_or_hook_like(name, params, body, parent_callee_name) + .unwrap_or(ReactFunctionType::Other), + ); + } + } + + // Component and hook declarations are known components/hooks + // (Flow `component Foo() { ... }` and `hook useFoo() { ... 
}` syntax, + // detected via __componentDeclaration / __hookDeclaration from the Hermes + // parser) + let component_syntax_type = if is_declaration { + if is_component_declaration { + Some(ReactFunctionType::Component) + } else if is_hook_declaration { + Some(ReactFunctionType::Hook) + } else { + None + } + } else { + None + }; + + match opts.compilation_mode.as_str() { + "annotation" => { + // opt-ins were checked above + None + } + "infer" => { + // Check if this is a component or hook-like function + component_syntax_type + .or_else(|| get_component_or_hook_like(name, params, body, parent_callee_name)) + } + "syntax" => { + // In syntax mode, only compile declared components/hooks + component_syntax_type + } + "all" => Some( + get_component_or_hook_like(name, params, body, parent_callee_name) + .unwrap_or(ReactFunctionType::Other), + ), + _ => None, + } +} + +/// Determine if a function looks like a React component or hook based on +/// naming conventions and code patterns. +/// +/// Adapted from the ESLint rule at +/// https://github.com/facebook/react/blob/main/packages/eslint-plugin-react-hooks/src/RulesOfHooks.js +fn get_component_or_hook_like( + name: Option<&str>, + params: &[PatternLike], + body: &FunctionBody, + parent_callee_name: Option<&str>, +) -> Option { + if let Some(fn_name) = name { + if is_component_name(fn_name) { + // Check if it actually looks like a component + let is_component = calls_hooks_or_creates_jsx(body) + && is_valid_component_params(params) + && !returns_non_node_fn(params, body); + return if is_component { + Some(ReactFunctionType::Component) + } else { + None + }; + } else if is_hook_name(fn_name) { + // Hooks have hook invocations or JSX, but can take any # of arguments + return if calls_hooks_or_creates_jsx(body) { + Some(ReactFunctionType::Hook) + } else { + None + }; + } + } + + // For unnamed functions, check if they are forwardRef/memo callbacks + if let Some(callee_name) = parent_callee_name { + if callee_name == 
"forwardRef" || callee_name == "memo" { + return if calls_hooks_or_creates_jsx(body) { + Some(ReactFunctionType::Component) + } else { + None + }; + } + } + + None +} + +/// Extract the callee name from a CallExpression if it's a React API call +/// (forwardRef, memo, React.forwardRef, React.memo). +fn get_callee_name_if_react_api(callee: &Expression) -> Option<&str> { + match callee { + Expression::Identifier(id) => { + if id.name == "forwardRef" || id.name == "memo" { + Some(&id.name) + } else { + None + } + } + Expression::MemberExpression(member) => { + if let Expression::Identifier(obj) = member.object.as_ref() { + if obj.name == "React" { + if let Expression::Identifier(prop) = member.property.as_ref() { + if prop.name == "forwardRef" || prop.name == "memo" { + return Some(&prop.name); + } + } + } + } + None + } + _ => None, + } +} + +// ----------------------------------------------------------------------- +// SourceLocation conversion +// ----------------------------------------------------------------------- + +/// Convert an AST SourceLocation to a diagnostics SourceLocation +fn convert_loc(loc: &react_compiler_ast::common::SourceLocation) -> SourceLocation { + SourceLocation { + start: react_compiler_diagnostics::Position { + line: loc.start.line, + column: loc.start.column, + index: loc.start.index, + }, + end: react_compiler_diagnostics::Position { + line: loc.end.line, + column: loc.end.column, + index: loc.end.index, + }, + } +} + +fn base_node_loc(base: &BaseNode) -> Option { + base.loc.as_ref().map(convert_loc) +} + +// ----------------------------------------------------------------------- +// Error handling +// ----------------------------------------------------------------------- + +/// Convert CompilerDiagnostic details into serializable CompilerErrorItemInfo +/// items. 
+fn diagnostic_details_to_items( + d: &react_compiler_diagnostics::CompilerDiagnostic, + filename: Option<&str>, +) -> Option> { + let items: Vec = d + .details + .iter() + .map(|item| match item { + react_compiler_diagnostics::CompilerDiagnosticDetail::Error { + loc, + message, + identifier_name, + } => CompilerErrorItemInfo { + kind: "error".to_string(), + loc: loc.as_ref().map(|l| { + let mut logger_loc = diag_loc_to_logger_loc(l, filename); + logger_loc.identifier_name = identifier_name.clone(); + logger_loc + }), + message: message.clone(), + }, + react_compiler_diagnostics::CompilerDiagnosticDetail::Hint { message } => { + CompilerErrorItemInfo { + kind: "hint".to_string(), + loc: None, + message: Some(message.clone()), + } + } + }) + .collect(); + if items.is_empty() { + None + } else { + Some(items) + } +} + +/// Convert an optional AST SourceLocation to a LoggerSourceLocation with +/// filename. +fn to_logger_loc( + ast_loc: Option<&react_compiler_ast::common::SourceLocation>, + filename: Option<&str>, +) -> Option { + ast_loc.map(|loc| LoggerSourceLocation { + start: LoggerPosition { + line: loc.start.line, + column: loc.start.column, + index: loc.start.index, + }, + end: LoggerPosition { + line: loc.end.line, + column: loc.end.column, + index: loc.end.index, + }, + filename: filename.map(|s| s.to_string()), + identifier_name: loc.identifier_name.clone(), + }) +} + +/// Convert a diagnostics SourceLocation to a LoggerSourceLocation with +/// filename. +fn diag_loc_to_logger_loc(loc: &SourceLocation, filename: Option<&str>) -> LoggerSourceLocation { + LoggerSourceLocation { + start: LoggerPosition { + line: loc.start.line, + column: loc.start.column, + index: loc.start.index, + }, + end: LoggerPosition { + line: loc.end.line, + column: loc.end.column, + index: loc.end.index, + }, + filename: filename.map(|s| s.to_string()), + identifier_name: None, + } +} + +/// Convert diagnostic suggestions to logger suggestion infos. 
+fn suggestions_to_logger( + suggestions: &Option>, +) -> Option> { + suggestions.as_ref().map(|suggestions| { + suggestions + .iter() + .map(|s| { + let op = match s.op { + react_compiler_diagnostics::CompilerSuggestionOperation::InsertBefore => { + LoggerSuggestionOp::InsertBefore + } + react_compiler_diagnostics::CompilerSuggestionOperation::InsertAfter => { + LoggerSuggestionOp::InsertAfter + } + react_compiler_diagnostics::CompilerSuggestionOperation::Remove => { + LoggerSuggestionOp::Remove + } + react_compiler_diagnostics::CompilerSuggestionOperation::Replace => { + LoggerSuggestionOp::Replace + } + }; + LoggerSuggestionInfo { + description: s.description.clone(), + op, + range: s.range, + text: s.text.clone(), + } + }) + .collect() + }) +} + +/// Log an error as LoggerEvent(s) directly onto the ProgramContext. +fn log_error( + err: &CompilerError, + fn_ast_loc: Option<&react_compiler_ast::common::SourceLocation>, + context: &mut ProgramContext, +) { + // Use the filename from the AST node's loc (set by parser's sourceFilename + // option), not from plugin options (which may have a different prefix like + // '/'). + let source_filename = fn_ast_loc.and_then(|loc| loc.filename.as_deref()); + let fn_loc = to_logger_loc(fn_ast_loc, source_filename); + + // Detect simulated unknown exception (throwUnknownException__testonly). + // In TS, non-CompilerError exceptions are logged as PipelineError with the + // error message as data. Emit the same event shape. 
+ let is_simulated_unknown = err.details.len() == 1 + && err.details.iter().all(|d| match d { + CompilerErrorOrDiagnostic::ErrorDetail(d) => { + d.category == ErrorCategory::Invariant && d.reason == "unexpected error" + } + _ => false, + }); + if is_simulated_unknown { + context.log_event(LoggerEvent::PipelineError { + fn_loc: fn_loc.clone(), + data: "Error: unexpected error".to_string(), + }); + return; + } + + for detail in &err.details { + let detail_info = match detail { + CompilerErrorOrDiagnostic::Diagnostic(d) => CompilerErrorDetailInfo { + category: format!("{:?}", d.category), + reason: d.reason.clone(), + description: d.description.clone(), + severity: format!("{:?}", d.logged_severity()), + suggestions: suggestions_to_logger(&d.suggestions), + details: diagnostic_details_to_items(d, source_filename), + loc: None, + }, + CompilerErrorOrDiagnostic::ErrorDetail(d) => CompilerErrorDetailInfo { + category: format!("{:?}", d.category), + reason: d.reason.clone(), + description: d.description.clone(), + severity: format!("{:?}", d.logged_severity()), + suggestions: suggestions_to_logger(&d.suggestions), + details: None, + loc: d + .loc + .as_ref() + .map(|l| diag_loc_to_logger_loc(l, source_filename)), + }, + }; + // Use CompileErrorWithLoc when fn_loc is present to match TS field ordering + if let Some(ref loc) = fn_loc { + context.log_event(LoggerEvent::CompileErrorWithLoc { + fn_loc: loc.clone(), + detail: detail_info, + }); + } else { + context.log_event(LoggerEvent::CompileError { + fn_loc: None, + detail: detail_info, + }); + } + } +} + +/// Handle an error according to the panicThreshold setting. +/// Returns Some(CompileResult::Error) if the error should be surfaced as fatal, +/// otherwise returns None (error was logged only). 
+fn handle_error( + err: &CompilerError, + fn_ast_loc: Option<&react_compiler_ast::common::SourceLocation>, + context: &mut ProgramContext, +) -> Option { + // Log the error + log_error(err, fn_ast_loc, context); + + let should_panic = match context.opts.panic_threshold.as_str() { + "all_errors" => true, + "critical_errors" => err.has_errors(), + _ => false, + }; + + // Config errors always cause a panic + let is_config_error = err.details.iter().any(|d| match d { + CompilerErrorOrDiagnostic::Diagnostic(d) => d.category == ErrorCategory::Config, + CompilerErrorOrDiagnostic::ErrorDetail(d) => d.category == ErrorCategory::Config, + }); + + if should_panic || is_config_error { + let source_fn = context.source_filename(); + let mut error_info = compiler_error_to_info(err, source_fn.as_deref()); + + // Detect simulated unknown exception (throwUnknownException__testonly). + // In the TS compiler, this throws a plain Error('unexpected error'), not + // a CompilerError. Set rawMessage so the JS side throws with the raw + // message instead of formatting through formatCompilerError(). + let is_simulated_unknown = err.details.len() == 1 + && err.details.iter().all(|d| match d { + CompilerErrorOrDiagnostic::ErrorDetail(d) => { + d.category == ErrorCategory::Invariant && d.reason == "unexpected error" + } + _ => false, + }); + if is_simulated_unknown { + error_info.raw_message = Some("unexpected error".to_string()); + } + + // Pre-format the error message in Rust when possible, so the JS + // shim can use it directly instead of calling formatCompilerError(). 
+ if error_info.raw_message.is_none() { + if let Some(ref source) = context.code { + error_info.formatted_message = Some( + react_compiler_diagnostics::code_frame::format_compiler_error( + err, + source, + source_fn.as_deref(), + ), + ); + } + } + + Some(CompileResult::Error { + error: error_info, + events: context.events.clone(), + ordered_log: context.ordered_log.clone(), + timing: Vec::new(), + }) + } else { + None + } +} + +/// Convert a diagnostics CompilerError to a serializable CompilerErrorInfo. +fn compiler_error_to_info(err: &CompilerError, filename: Option<&str>) -> CompilerErrorInfo { + let details: Vec = err + .details + .iter() + .map(|d| match d { + CompilerErrorOrDiagnostic::Diagnostic(d) => CompilerErrorDetailInfo { + category: format!("{:?}", d.category), + reason: d.reason.clone(), + description: d.description.clone(), + severity: format!("{:?}", d.severity()), + suggestions: suggestions_to_logger(&d.suggestions), + details: diagnostic_details_to_items(d, filename), + loc: None, + }, + CompilerErrorOrDiagnostic::ErrorDetail(d) => CompilerErrorDetailInfo { + category: format!("{:?}", d.category), + reason: d.reason.clone(), + description: d.description.clone(), + severity: format!("{:?}", d.severity()), + suggestions: suggestions_to_logger(&d.suggestions), + details: None, + loc: d.loc.as_ref().map(|l| diag_loc_to_logger_loc(l, filename)), + }, + }) + .collect(); + + let (reason, description) = details + .first() + .map(|d| (d.reason.clone(), d.description.clone())) + .unwrap_or_else(|| ("Unknown error".to_string(), None)); + + CompilerErrorInfo { + reason, + description, + details, + raw_message: None, + formatted_message: None, + } +} + +// ----------------------------------------------------------------------- +// Compilation pipeline stubs +// ----------------------------------------------------------------------- + +/// Attempt to compile a single function. +/// +/// Returns `CodegenFunction` on success or `CompilerError` on failure. 
+/// Debug log entries are accumulated on `context.debug_logs`. +fn try_compile_function( + source: &CompileSource<'_>, + scope_info: &ScopeInfo, + output_mode: CompilerOutputMode, + env_config: &EnvironmentConfig, + context: &mut ProgramContext, +) -> Result { + // Check for suppressions that affect this function + if let (Some(start), Some(end)) = (source.fn_start, source.fn_end) { + let affecting = filter_suppressions_that_affect_function(&context.suppressions, start, end); + if !affecting.is_empty() { + let owned: Vec = affecting.into_iter().cloned().collect(); + let mut err = suppressions_to_compiler_error(&owned); + // Suppression errors are returned (not thrown), so they should NOT + // trigger CompileUnexpectedThrow. + err.is_thrown = false; + return Err(err); + } + } + + // Run the compilation pipeline + pipeline::compile_fn( + &source.fn_node, + source.fn_name.as_deref(), + scope_info, + source.fn_type, + output_mode, + env_config, + context, + ) +} + +/// Process a single function: check directives, attempt compilation, handle +/// results. +/// +/// Returns `Ok(Some(codegen_fn))` when the function was compiled and should be +/// applied, `Ok(None)` when the function was skipped or lint-only, +/// or `Err(CompileResult)` if a fatal error should short-circuit the program. 
+fn process_fn( + source: &CompileSource<'_>, + scope_info: &ScopeInfo, + output_mode: CompilerOutputMode, + env_config: &EnvironmentConfig, + context: &mut ProgramContext, +) -> Result, CompileResult> { + // Parse directives from the function body + let opt_in_result = + try_find_directive_enabling_memoization(&source.body_directives, &context.opts); + let opt_out = find_directive_disabling_memoization(&source.body_directives, &context.opts); + + // If parsing opt-in directive fails, handle the error and skip + let opt_in = match opt_in_result { + Ok(d) => d, + Err(err) => { + // Apply panic threshold logic (same as compilation errors) + if let Some(result) = handle_error(&err, source.fn_ast_loc.as_ref(), context) { + return Err(result); + } + return Ok(None); + } + }; + + // Attempt compilation + let compile_result = try_compile_function(source, scope_info, output_mode, env_config, context); + + match compile_result { + Err(err) => { + // Emit CompileUnexpectedThrow for errors that were "thrown" from a pass + // (not accumulated via env.record_error) and have all non-Invariant details. + // Matches TS tryCompileFunction() catch block behavior. 
+ if err.is_thrown && err.is_all_non_invariant() { + let source_filename = source + .fn_ast_loc + .as_ref() + .and_then(|loc| loc.filename.as_deref()); + context.log_event(LoggerEvent::CompileUnexpectedThrow { + fn_loc: to_logger_loc(source.fn_ast_loc.as_ref(), source_filename), + data: err.to_string_for_event(), + }); + } + + if opt_out.is_some() { + // If there's an opt-out, just log the error (don't escalate) + log_error(&err, source.fn_ast_loc.as_ref(), context); + } else { + // Apply panic threshold logic + if let Some(result) = handle_error(&err, source.fn_ast_loc.as_ref(), context) { + return Err(result); + } + } + Ok(None) + } + Ok(codegen_fn) => { + // Check opt-out + if !context.opts.ignore_use_no_forget && opt_out.is_some() { + let opt_out_value = &opt_out.unwrap().value.value; + let source_filename = source + .fn_ast_loc + .as_ref() + .and_then(|loc| loc.filename.as_deref()); + context.log_event(LoggerEvent::CompileSkip { + fn_loc: to_logger_loc(source.fn_ast_loc.as_ref(), source_filename), + reason: format!("Skipped due to '{}' directive.", opt_out_value), + loc: opt_out.and_then(|d| to_logger_loc(d.base.loc.as_ref(), source_filename)), + }); + // Even though the function is skipped, register the memo cache import + // if the compiled function had memo slots. This matches TS behavior where + // addMemoCacheImport() is called during codegen as a side effect that + // persists even when the function is later skipped. 
+ if codegen_fn.memo_slots_used > 0 { + context.add_memo_cache_import(); + } + return Ok(None); + } + + // Log success with memo stats from CodegenFunction + let source_filename = source + .fn_ast_loc + .as_ref() + .and_then(|loc| loc.filename.as_deref()); + context.log_event(LoggerEvent::CompileSuccess { + fn_loc: to_logger_loc(source.fn_ast_loc.as_ref(), source_filename), + fn_name: codegen_fn.id.as_ref().map(|id| id.name.clone()), + memo_slots: codegen_fn.memo_slots_used, + memo_blocks: codegen_fn.memo_blocks, + memo_values: codegen_fn.memo_values, + pruned_memo_blocks: codegen_fn.pruned_memo_blocks, + pruned_memo_values: codegen_fn.pruned_memo_values, + }); + + // Check module scope opt-out + if context.has_module_scope_opt_out { + return Ok(None); + } + + // Check output mode — lint mode doesn't apply compiled functions + if output_mode == CompilerOutputMode::Lint { + return Ok(None); + } + + // Check annotation mode + if context.opts.compilation_mode == "annotation" && opt_in.is_none() { + return Ok(None); + } + + Ok(Some(codegen_fn)) + } + } +} + +// ----------------------------------------------------------------------- +// Import checking +// ----------------------------------------------------------------------- + +/// Check if the program already has a `c` import from the React Compiler +/// runtime module. If so, the file was already compiled and should be skipped. 
+fn has_memo_cache_function_import(program: &Program, module_name: &str) -> bool { + for stmt in &program.body { + if let Statement::ImportDeclaration(import) = stmt { + if import.source.value == module_name { + for specifier in &import.specifiers { + if let ImportSpecifier::ImportSpecifier(data) = specifier { + let imported_name = match &data.imported { + ModuleExportName::Identifier(id) => &id.name, + ModuleExportName::StringLiteral(s) => &s.value, + }; + if imported_name == "c" { + return true; + } + } + } + } + } + } + false +} + +/// Check if compilation should be skipped for this program. +fn should_skip_compilation(program: &Program, options: &PluginOptions) -> bool { + let runtime_module = get_react_compiler_runtime_module(&options.target); + has_memo_cache_function_import(program, &runtime_module) +} + +// ----------------------------------------------------------------------- +// Function discovery +// ----------------------------------------------------------------------- + +/// Information about an expression that might be a function to compile +struct FunctionInfo<'a> { + name: Option, + fn_node: FunctionNode<'a>, + params: &'a [PatternLike], + body: FunctionBody<'a>, + body_directives: Vec, + base: &'a BaseNode, + parent_callee_name: Option, + /// True if the node has `__componentDeclaration` set by the Hermes parser + /// (Flow component syntax) + is_component_declaration: bool, + /// True if the node has `__hookDeclaration` set by the Hermes parser (Flow + /// hook syntax) + is_hook_declaration: bool, +} + +/// Extract function info from a FunctionDeclaration +fn fn_info_from_decl(decl: &FunctionDeclaration) -> FunctionInfo<'_> { + FunctionInfo { + name: get_function_name_from_id(decl.id.as_ref()), + fn_node: FunctionNode::FunctionDeclaration(decl), + params: &decl.params, + body: FunctionBody::Block(&decl.body), + body_directives: decl.body.directives.clone(), + base: &decl.base, + parent_callee_name: None, + is_component_declaration: 
decl.component_declaration, + is_hook_declaration: decl.hook_declaration, + } +} + +/// Extract function info from a FunctionExpression +fn fn_info_from_func_expr<'a>( + expr: &'a FunctionExpression, + inferred_name: Option, + parent_callee_name: Option, +) -> FunctionInfo<'a> { + FunctionInfo { + name: expr.id.as_ref().map(|id| id.name.clone()).or(inferred_name), + fn_node: FunctionNode::FunctionExpression(expr), + params: &expr.params, + body: FunctionBody::Block(&expr.body), + body_directives: expr.body.directives.clone(), + base: &expr.base, + parent_callee_name, + is_component_declaration: false, + is_hook_declaration: false, + } +} + +/// Extract function info from an ArrowFunctionExpression +fn fn_info_from_arrow<'a>( + expr: &'a ArrowFunctionExpression, + inferred_name: Option, + parent_callee_name: Option, +) -> FunctionInfo<'a> { + let (body, directives) = match expr.body.as_ref() { + ArrowFunctionBody::BlockStatement(block) => { + (FunctionBody::Block(block), block.directives.clone()) + } + ArrowFunctionBody::Expression(e) => (FunctionBody::Expression(e), Vec::new()), + }; + FunctionInfo { + name: inferred_name, + fn_node: FunctionNode::ArrowFunctionExpression(expr), + params: &expr.params, + body, + body_directives: directives, + base: &expr.base, + parent_callee_name, + is_component_declaration: false, + is_hook_declaration: false, + } +} + +/// Try to create a CompileSource from function info +fn try_make_compile_source<'a>( + info: FunctionInfo<'a>, + opts: &PluginOptions, + context: &mut ProgramContext, +) -> Option> { + // Skip if already compiled + if let Some(start) = info.base.start { + if context.is_already_compiled(start) { + return None; + } + } + + let fn_type = get_react_function_type( + info.name.as_deref(), + info.params, + &info.body, + &info.body_directives, + info.is_component_declaration || info.is_hook_declaration, + info.parent_callee_name.as_deref(), + opts, + info.is_component_declaration, + info.is_hook_declaration, + )?; + + // 
Mark as compiled + if let Some(start) = info.base.start { + context.mark_compiled(start); + } + + Some(CompileSource { + kind: CompileSourceKind::Original, + fn_node: info.fn_node, + fn_name: info.name, + fn_loc: base_node_loc(info.base), + fn_ast_loc: info.base.loc.clone(), + fn_start: info.base.start, + fn_end: info.base.end, + fn_type, + body_directives: info.body_directives, + }) +} + +/// Get the variable declarator name (for inferring function names from `const +/// Foo = () => {}`) +fn get_declarator_name(decl: &VariableDeclarator) -> Option { + match &decl.id { + PatternLike::Identifier(id) => Some(id.name.clone()), + _ => None, + } +} + +// ----------------------------------------------------------------------- +// FunctionDiscoveryVisitor — uses AstWalker to find compilable functions +// ----------------------------------------------------------------------- + +/// Visitor that discovers functions to compile, matching the TypeScript +/// compiler's Babel `program.traverse` behavior. +/// +/// Uses the `AstWalker` with `traverse_function_bodies` returning `false` +/// so we don't recurse into function bodies (similar to Babel's `fn.skip()`). +/// +/// Tracks parent context via: +/// - `current_declarator_name`: set by `enter_variable_declarator`, used to +/// infer function names from `const Foo = () => {}`. +/// - `parent_callee_stack`: set by `enter_call_expression`, used to detect +/// forwardRef/memo wrappers around function expressions. +/// +/// In 'all' mode, uses `scope_stack.len() > 1` to reject functions that are +/// not at program scope. The walker pushes the program scope first, then +/// nested scopes for for/switch/etc. — so `len() > 1` means the function +/// is inside a nested scope (not at program level), matching Babel's +/// `fn.scope.getProgramParent() !== fn.scope.parent` check. 
+struct FunctionDiscoveryVisitor<'a, 'ast> { + opts: &'a PluginOptions, + context: &'a mut ProgramContext, + queue: Vec>, + /// The inferred name from the current VariableDeclarator, if any. + current_declarator_name: Option, + /// Stack tracking callee names of enclosing CallExpressions. + /// `Some(name)` when the callee is a React API (forwardRef/memo), + /// `None` for other calls. + parent_callee_stack: Vec>, + /// Depth counter for loop expression positions (while.test, for-in.right, + /// etc.). When > 0, functions are treated as non-program-scope in 'all' + /// mode. + loop_expression_depth: usize, +} + +impl<'a, 'ast> FunctionDiscoveryVisitor<'a, 'ast> { + fn new(opts: &'a PluginOptions, context: &'a mut ProgramContext) -> Self { + Self { + opts, + context, + queue: Vec::new(), + current_declarator_name: None, + parent_callee_stack: Vec::new(), + loop_expression_depth: 0, + } + } + + /// Check if in 'all' mode and the function is inside a nested scope. + /// The walker pushes the function's own scope BEFORE calling enter hooks, + /// so scope_stack = [program, ...parents, function_scope]. A top-level + /// function has len=2 (program + function). Anything deeper means it's + /// inside a nested scope (for/switch/etc.) and should be rejected. + /// Also rejects functions found in loop expression positions (while.test, + /// for-in.right, etc.) where Babel treats the scope as non-program. + fn is_rejected_by_scope_check(&self, scope_stack: &[ScopeId]) -> bool { + self.opts.compilation_mode == "all" + && (scope_stack.len() > 2 || self.loop_expression_depth > 0) + } + + /// Get the current parent callee name (forwardRef/memo) if any. 
+ fn current_parent_callee(&self) -> Option { + self.parent_callee_stack.last().and_then(|opt| opt.clone()) + } +} + +impl<'a, 'ast> Visitor<'ast> for FunctionDiscoveryVisitor<'a, 'ast> { + fn traverse_function_bodies(&self) -> bool { + false // Don't recurse into function bodies (like Babel's fn.skip()) + } + + fn enter_loop_expression(&mut self) { + self.loop_expression_depth += 1; + } + + fn leave_loop_expression(&mut self) { + self.loop_expression_depth -= 1; + } + + fn enter_variable_declarator( + &mut self, + node: &'ast VariableDeclarator, + _scope_stack: &[ScopeId], + ) { + // Only infer the declarator name when the init is a direct function + // expression, arrow, or call expression (for forwardRef/memo wrappers). + // TS checks `path.parentPath.isVariableDeclarator()` which only matches + // when the function IS the init, not when it's nested inside an object, + // array, or other expression. + if let Some(ref init) = node.init { + match init.as_ref() { + Expression::FunctionExpression(_) + | Expression::ArrowFunctionExpression(_) + | Expression::CallExpression(_) => { + self.current_declarator_name = get_declarator_name(node); + } + _ => {} + } + } + } + + fn leave_variable_declarator( + &mut self, + _node: &'ast VariableDeclarator, + _scope_stack: &[ScopeId], + ) { + self.current_declarator_name = None; + } + + fn enter_call_expression(&mut self, node: &'ast CallExpression, _scope_stack: &[ScopeId]) { + let callee_name = get_callee_name_if_react_api(&node.callee).map(|s| s.to_string()); + // In TS, the declarator name only flows through forwardRef/memo calls + // (path.parentPath.isCallExpression() checks the callee). For any other + // call expression, clear the name so nested functions don't inherit it. 
+ if callee_name.is_none() { + self.current_declarator_name = None; + } + self.parent_callee_stack.push(callee_name); + } + + fn leave_call_expression(&mut self, _node: &'ast CallExpression, _scope_stack: &[ScopeId]) { + let was_react_api = self + .parent_callee_stack + .pop() + .and_then(|name| name) + .is_some(); + // After a forwardRef/memo call finishes, clear the declarator name. + // The name is only valid within the call's arguments — if a function + // inside consumed it via .take(), great; if not, it shouldn't leak + // to sibling or subsequent expressions. + if was_react_api { + self.current_declarator_name = None; + } + } + + fn enter_function_declaration( + &mut self, + node: &'ast FunctionDeclaration, + scope_stack: &[ScopeId], + ) { + if self.is_rejected_by_scope_check(scope_stack) { + return; + } + let info = fn_info_from_decl(node); + if let Some(source) = try_make_compile_source(info, self.opts, self.context) { + self.queue.push(source); + } + } + + fn enter_function_expression( + &mut self, + node: &'ast FunctionExpression, + scope_stack: &[ScopeId], + ) { + if self.is_rejected_by_scope_check(scope_stack) { + return; + } + let inferred_name = node + .id + .as_ref() + .map(|id| id.name.clone()) + .or_else(|| self.current_declarator_name.take()); + let parent_callee = self.current_parent_callee(); + let info = fn_info_from_func_expr(node, inferred_name, parent_callee); + if let Some(source) = try_make_compile_source(info, self.opts, self.context) { + self.queue.push(source); + } + } + + fn enter_arrow_function_expression( + &mut self, + node: &'ast ArrowFunctionExpression, + scope_stack: &[ScopeId], + ) { + if self.is_rejected_by_scope_check(scope_stack) { + return; + } + let inferred_name = self.current_declarator_name.take(); + let parent_callee = self.current_parent_callee(); + let info = fn_info_from_arrow(node, inferred_name, parent_callee); + if let Some(source) = try_make_compile_source(info, self.opts, self.context) { + 
self.queue.push(source); + } + } +} + +/// Find all functions in the program that should be compiled. +/// +/// Uses the `AstWalker` with a `FunctionDiscoveryVisitor` to traverse +/// the entire program, discovering functions at any depth. The visitor +/// uses `traverse_function_bodies() -> false` to skip recursing into +/// function bodies (matching Babel's `fn.skip()` behavior). +/// +/// The visitor tracks parent context (VariableDeclarator names for +/// `const Foo = () => {}`, CallExpression callees for forwardRef/memo +/// wrappers) via enter/leave hooks. +/// +/// Skips classes and their contents (the walker does not recurse into +/// class bodies). +fn find_functions_to_compile<'a>( + program: &'a Program, + opts: &PluginOptions, + context: &mut ProgramContext, + scope: &ScopeInfo, +) -> Vec> { + let mut visitor = FunctionDiscoveryVisitor::new(opts, context); + let mut walker = AstWalker::new(scope); + walker.walk_program(&mut visitor, program); + visitor.queue +} + +// ----------------------------------------------------------------------- +// Main entry point +// ----------------------------------------------------------------------- + +/// A successfully compiled function, ready to be applied to the AST. +struct CompiledFunction<'a> { + #[allow(dead_code)] + kind: CompileSourceKind, + #[allow(dead_code)] + source: &'a CompileSource<'a>, + codegen_fn: CodegenFunction, +} + +/// The type of the original function node, used to determine what kind of +/// replacement node to create. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum OriginalFnKind { + FunctionDeclaration, + FunctionExpression, + ArrowFunctionExpression, +} + +/// Owned representation of a compiled function for AST replacement. +/// Does not borrow from the original program, so we can mutate the AST. +struct CompiledFnForReplacement { + /// Start position of the original function, used to find it in the AST. + fn_start: Option, + /// The kind of the original function node. 
+ original_kind: OriginalFnKind, + /// The compiled codegen output. + codegen_fn: CodegenFunction, + /// Whether this is an original function (vs outlined). Gating only applies + /// to original. + #[allow(dead_code)] + source_kind: CompileSourceKind, + /// The function name, if any. + fn_name: Option, + /// Gating configuration (from dynamic gating or plugin options). + gating: Option, +} + +/// Check if a compiled function is referenced before its declaration at the top +/// level. This is needed for the gating rewrite: hoisted function declarations +/// that are referenced before their declaration site need a special gating +/// pattern. +fn get_functions_referenced_before_declaration( + program: &Program, + compiled_fns: &[CompiledFnForReplacement], +) -> HashSet { + // Collect function names and their start positions for compiled + // FunctionDeclarations + let mut fn_names: HashMap = HashMap::new(); + for compiled in compiled_fns { + if compiled.original_kind == OriginalFnKind::FunctionDeclaration { + if let Some(ref name) = compiled.fn_name { + if let Some(start) = compiled.fn_start { + fn_names.insert(name.clone(), start); + } + } + } + } + + if fn_names.is_empty() { + return HashSet::new(); + } + + let mut referenced_before_decl: HashSet = HashSet::new(); + + // Walk through program body in order. For each statement, check if it + // references any of the function names before the function's declaration. 
+ for stmt in &program.body { + // Check if this statement IS one of the function declarations + if let Statement::FunctionDeclaration(f) = stmt { + if let Some(ref id) = f.id { + fn_names.remove(&id.name); + } + } + // For all remaining tracked names, check if the statement references them + // at the top level (not inside nested functions) + for (name, start) in &fn_names { + if stmt_references_identifier_at_top_level(stmt, name) { + referenced_before_decl.insert(*start); + } + } + } + + referenced_before_decl +} + +/// Check if a statement references an identifier at the top level (not inside +/// nested functions). +fn stmt_references_identifier_at_top_level(stmt: &Statement, name: &str) -> bool { + match stmt { + Statement::FunctionDeclaration(_) => { + // Don't look inside function declarations (they create their own scope) + false + } + Statement::ExportDefaultDeclaration(export) => match export.declaration.as_ref() { + ExportDefaultDecl::Expression(e) => expr_references_identifier_at_top_level(e, name), + _ => false, + }, + Statement::ExportNamedDeclaration(export) => { + if let Some(ref decl) = export.declaration { + match decl.as_ref() { + Declaration::VariableDeclaration(var_decl) => { + var_decl.declarations.iter().any(|d| { + d.init + .as_ref() + .map_or(false, |e| expr_references_identifier_at_top_level(e, name)) + }) + } + _ => false, + } + } else { + // export { Name } - check specifiers + export.specifiers.iter().any(|s| { + if let react_compiler_ast::declarations::ExportSpecifier::ExportSpecifier( + spec, + ) = s + { + match &spec.local { + ModuleExportName::Identifier(id) => id.name == name, + _ => false, + } + } else { + false + } + }) + } + } + Statement::VariableDeclaration(var_decl) => var_decl.declarations.iter().any(|d| { + d.init + .as_ref() + .map_or(false, |e| expr_references_identifier_at_top_level(e, name)) + }), + Statement::ExpressionStatement(expr_stmt) => { + expr_references_identifier_at_top_level(&expr_stmt.expression, name) + } 
+ Statement::ReturnStatement(ret) => ret + .argument + .as_ref() + .map_or(false, |e| expr_references_identifier_at_top_level(e, name)), + _ => false, + } +} + +/// Check if an expression references an identifier at the top level. +fn expr_references_identifier_at_top_level(expr: &Expression, name: &str) -> bool { + match expr { + Expression::Identifier(id) => id.name == name, + Expression::CallExpression(call) => { + expr_references_identifier_at_top_level(&call.callee, name) + || call + .arguments + .iter() + .any(|a| expr_references_identifier_at_top_level(a, name)) + } + Expression::MemberExpression(member) => { + expr_references_identifier_at_top_level(&member.object, name) + } + Expression::ConditionalExpression(cond) => { + expr_references_identifier_at_top_level(&cond.test, name) + || expr_references_identifier_at_top_level(&cond.consequent, name) + || expr_references_identifier_at_top_level(&cond.alternate, name) + } + Expression::BinaryExpression(bin) => { + expr_references_identifier_at_top_level(&bin.left, name) + || expr_references_identifier_at_top_level(&bin.right, name) + } + Expression::LogicalExpression(log) => { + expr_references_identifier_at_top_level(&log.left, name) + || expr_references_identifier_at_top_level(&log.right, name) + } + // Don't recurse into function expressions/arrows (they create their own scope) + Expression::FunctionExpression(_) | Expression::ArrowFunctionExpression(_) => false, + _ => false, + } +} + +/// Build a function expression from a codegen function (compiled output). 
+fn build_compiled_function_expression(codegen: &CodegenFunction) -> Expression { + Expression::FunctionExpression(FunctionExpression { + base: BaseNode::typed("FunctionExpression"), + id: codegen.id.clone(), + params: codegen.params.clone(), + body: codegen.body.clone(), + generator: codegen.generator, + is_async: codegen.is_async, + return_type: None, + type_parameters: None, + }) +} + +/// Build a function expression that preserves the original function's +/// structure. For FunctionDeclarations, converts to FunctionExpression. +/// For ArrowFunctionExpressions, keeps as-is. +fn clone_original_fn_as_expression(stmt: &Statement, start: u32) -> Option { + match stmt { + Statement::FunctionDeclaration(f) => { + if f.base.start == Some(start) { + return Some(Expression::FunctionExpression(FunctionExpression { + base: BaseNode::typed("FunctionExpression"), + id: f.id.clone(), + params: f.params.clone(), + body: f.body.clone(), + generator: f.generator, + is_async: f.is_async, + return_type: None, + type_parameters: None, + })); + } + None + } + Statement::VariableDeclaration(var_decl) => { + for d in &var_decl.declarations { + if let Some(ref init) = d.init { + if let Some(e) = clone_original_expr_as_expression(init, start) { + return Some(e); + } + } + } + None + } + Statement::ExportDefaultDeclaration(export) => match export.declaration.as_ref() { + ExportDefaultDecl::FunctionDeclaration(f) => { + if f.base.start == Some(start) { + return Some(Expression::FunctionExpression(FunctionExpression { + base: BaseNode::typed("FunctionExpression"), + id: f.id.clone(), + params: f.params.clone(), + body: f.body.clone(), + generator: f.generator, + is_async: f.is_async, + return_type: None, + type_parameters: None, + })); + } + None + } + ExportDefaultDecl::Expression(e) => clone_original_expr_as_expression(e, start), + _ => None, + }, + Statement::ExportNamedDeclaration(export) => { + if let Some(ref decl) = export.declaration { + match decl.as_ref() { + 
Declaration::FunctionDeclaration(f) => { + if f.base.start == Some(start) { + return Some(Expression::FunctionExpression(FunctionExpression { + base: BaseNode::typed("FunctionExpression"), + id: f.id.clone(), + params: f.params.clone(), + body: f.body.clone(), + generator: f.generator, + is_async: f.is_async, + return_type: None, + type_parameters: None, + })); + } + None + } + Declaration::VariableDeclaration(var_decl) => { + for d in &var_decl.declarations { + if let Some(ref init) = d.init { + if let Some(e) = clone_original_expr_as_expression(init, start) { + return Some(e); + } + } + } + None + } + _ => None, + } + } else { + None + } + } + Statement::ExpressionStatement(expr_stmt) => { + clone_original_expr_as_expression(&expr_stmt.expression, start) + } + // Recurse into block-containing statements + Statement::BlockStatement(block) => { + for s in &block.body { + if let Some(e) = clone_original_fn_as_expression(s, start) { + return Some(e); + } + } + None + } + Statement::IfStatement(if_stmt) => { + if let Some(e) = clone_original_expr_as_expression(&if_stmt.test, start) { + return Some(e); + } + if let Some(e) = clone_original_fn_as_expression(&if_stmt.consequent, start) { + return Some(e); + } + if let Some(ref alt) = if_stmt.alternate { + if let Some(e) = clone_original_fn_as_expression(alt, start) { + return Some(e); + } + } + None + } + Statement::TryStatement(try_stmt) => { + for s in &try_stmt.block.body { + if let Some(e) = clone_original_fn_as_expression(s, start) { + return Some(e); + } + } + if let Some(ref handler) = try_stmt.handler { + for s in &handler.body.body { + if let Some(e) = clone_original_fn_as_expression(s, start) { + return Some(e); + } + } + } + if let Some(ref finalizer) = try_stmt.finalizer { + for s in &finalizer.body { + if let Some(e) = clone_original_fn_as_expression(s, start) { + return Some(e); + } + } + } + None + } + Statement::SwitchStatement(switch_stmt) => { + if let Some(e) = 
clone_original_expr_as_expression(&switch_stmt.discriminant, start) { + return Some(e); + } + for case in &switch_stmt.cases { + for s in &case.consequent { + if let Some(e) = clone_original_fn_as_expression(s, start) { + return Some(e); + } + } + } + None + } + Statement::LabeledStatement(labeled) => { + clone_original_fn_as_expression(&labeled.body, start) + } + Statement::ForStatement(for_stmt) => { + if let Some(ref init) = for_stmt.init { + match init.as_ref() { + ForInit::VariableDeclaration(var_decl) => { + for d in &var_decl.declarations { + if let Some(ref init_expr) = d.init { + if let Some(e) = clone_original_expr_as_expression(init_expr, start) + { + return Some(e); + } + } + } + } + ForInit::Expression(expr) => { + if let Some(e) = clone_original_expr_as_expression(expr, start) { + return Some(e); + } + } + } + } + if let Some(ref test) = for_stmt.test { + if let Some(e) = clone_original_expr_as_expression(test, start) { + return Some(e); + } + } + if let Some(ref update) = for_stmt.update { + if let Some(e) = clone_original_expr_as_expression(update, start) { + return Some(e); + } + } + clone_original_fn_as_expression(&for_stmt.body, start) + } + Statement::WhileStatement(while_stmt) => { + if let Some(e) = clone_original_expr_as_expression(&while_stmt.test, start) { + return Some(e); + } + clone_original_fn_as_expression(&while_stmt.body, start) + } + Statement::DoWhileStatement(do_while) => { + if let Some(e) = clone_original_expr_as_expression(&do_while.test, start) { + return Some(e); + } + clone_original_fn_as_expression(&do_while.body, start) + } + Statement::ForInStatement(for_in) => { + if let Some(e) = clone_original_expr_as_expression(&for_in.right, start) { + return Some(e); + } + clone_original_fn_as_expression(&for_in.body, start) + } + Statement::ForOfStatement(for_of) => { + if let Some(e) = clone_original_expr_as_expression(&for_of.right, start) { + return Some(e); + } + clone_original_fn_as_expression(&for_of.body, start) + } + 
Statement::WithStatement(with_stmt) => { + if let Some(e) = clone_original_expr_as_expression(&with_stmt.object, start) { + return Some(e); + } + clone_original_fn_as_expression(&with_stmt.body, start) + } + Statement::ReturnStatement(ret) => { + if let Some(ref arg) = ret.argument { + clone_original_expr_as_expression(arg, start) + } else { + None + } + } + Statement::ThrowStatement(throw_stmt) => { + clone_original_expr_as_expression(&throw_stmt.argument, start) + } + _ => None, + } +} + +/// Clone an expression node for use as the original (fallback) in gating. +fn clone_original_expr_as_expression(expr: &Expression, start: u32) -> Option { + match expr { + Expression::FunctionExpression(f) => { + if f.base.start == Some(start) { + return Some(Expression::FunctionExpression(f.clone())); + } + None + } + Expression::ArrowFunctionExpression(f) => { + if f.base.start == Some(start) { + return Some(Expression::ArrowFunctionExpression(f.clone())); + } + None + } + Expression::CallExpression(call) => { + for arg in &call.arguments { + if let Some(e) = clone_original_expr_as_expression(arg, start) { + return Some(e); + } + } + None + } + Expression::ObjectExpression(obj) => { + for prop in &obj.properties { + match prop { + ObjectExpressionProperty::ObjectProperty(p) => { + if let Some(e) = clone_original_expr_as_expression(&p.value, start) { + return Some(e); + } + } + ObjectExpressionProperty::SpreadElement(s) => { + if let Some(e) = clone_original_expr_as_expression(&s.argument, start) { + return Some(e); + } + } + _ => {} + } + } + None + } + Expression::ArrayExpression(arr) => { + for elem in arr.elements.iter().flatten() { + if let Some(e) = clone_original_expr_as_expression(elem, start) { + return Some(e); + } + } + None + } + Expression::AssignmentExpression(assign) => { + clone_original_expr_as_expression(&assign.right, start) + } + Expression::SequenceExpression(seq) => { + for e in &seq.expressions { + if let Some(e) = clone_original_expr_as_expression(e, 
start) { + return Some(e); + } + } + None + } + Expression::ConditionalExpression(cond) => { + if let Some(e) = clone_original_expr_as_expression(&cond.consequent, start) { + return Some(e); + } + clone_original_expr_as_expression(&cond.alternate, start) + } + Expression::ParenthesizedExpression(paren) => { + clone_original_expr_as_expression(&paren.expression, start) + } + _ => None, + } +} + +/// Build a compiled arrow/function expression from a codegen function, +/// matching the original expression kind. +fn build_compiled_expression_matching_kind( + codegen: &CodegenFunction, + original_kind: OriginalFnKind, +) -> Expression { + match original_kind { + OriginalFnKind::ArrowFunctionExpression => { + Expression::ArrowFunctionExpression(ArrowFunctionExpression { + base: BaseNode::typed("ArrowFunctionExpression"), + params: codegen.params.clone(), + body: Box::new(ArrowFunctionBody::BlockStatement(codegen.body.clone())), + id: None, + generator: codegen.generator, + is_async: codegen.is_async, + expression: Some(false), + return_type: None, + type_parameters: None, + predicate: None, + }) + } + _ => build_compiled_function_expression(codegen), + } +} + +/// Apply compiled functions back to the AST by replacing original function +/// nodes with their compiled versions, inserting outlined functions, and adding +/// imports. 
+fn apply_compiled_functions( + compiled_fns: &[CompiledFnForReplacement], + program: &mut Program, + context: &mut ProgramContext, +) { + if compiled_fns.is_empty() { + return; + } + + // Check if any compiled functions have gating enabled + let has_gating = compiled_fns.iter().any(|cf| cf.gating.is_some()); + + // If gating is enabled, determine which functions are referenced before + // declaration + let referenced_before_decl = if has_gating { + get_functions_referenced_before_declaration(program, compiled_fns) + } else { + HashSet::new() + }; + + // For gated functions, we need to clone the original function expressions + // BEFORE we start mutating the AST. + let original_expressions: Vec> = if has_gating { + compiled_fns + .iter() + .map(|compiled| { + if compiled.gating.is_some() { + if let Some(start) = compiled.fn_start { + for stmt in program.body.iter() { + if let Some(expr) = clone_original_fn_as_expression(stmt, start) { + return Some(expr); + } + } + } + None + } else { + None + } + }) + .collect() + } else { + compiled_fns.iter().map(|_| None).collect() + }; + + // Collect outlined functions to insert (as FunctionDeclarations). 
+ // For FunctionDeclarations: insert right after the parent (matching TS + // insertAfter behavior) For FunctionExpression/ArrowFunctionExpression: + // append at end of program body (matching TS pushContainer behavior) + let mut outlined_decls: Vec<(Option, OriginalFnKind, FunctionDeclaration)> = Vec::new(); + + // Replace each compiled function in the AST + for (idx, compiled) in compiled_fns.iter().enumerate() { + // Collect outlined functions for this compiled function + for outlined in &compiled.codegen_fn.outlined { + let outlined_decl = FunctionDeclaration { + base: BaseNode::typed("FunctionDeclaration"), + id: outlined.func.id.clone(), + params: outlined.func.params.clone(), + body: outlined.func.body.clone(), + generator: outlined.func.generator, + is_async: outlined.func.is_async, + declare: None, + return_type: None, + type_parameters: None, + predicate: None, + component_declaration: false, + hook_declaration: false, + }; + outlined_decls.push((compiled.fn_start, compiled.original_kind, outlined_decl)); + } + + if let Some(ref gating_config) = compiled.gating { + let is_ref_before_decl = compiled + .fn_start + .map_or(false, |s| referenced_before_decl.contains(&s)); + + if is_ref_before_decl && compiled.original_kind == OriginalFnKind::FunctionDeclaration { + // Use the hoisted function declaration gating pattern + apply_gated_function_hoisted(program, compiled, gating_config, context); + } else { + // Use the conditional expression gating pattern + let original_expr = original_expressions[idx].clone(); + apply_gated_function_conditional( + program, + compiled, + gating_config, + original_expr, + context, + ); + } + } else { + // No gating: replace the function directly (original behavior) + if let Some(start) = compiled.fn_start { + let mut visitor = ReplaceFnVisitor { start, compiled }; + walk_program_mut(&mut visitor, program); + } + } + } + + // Insert outlined function declarations. 
+ let mut insert_decls: Vec<(Option, FunctionDeclaration)> = Vec::new(); + let mut push_decls: Vec = Vec::new(); + + for (parent_start, original_kind, outlined_decl) in outlined_decls { + match original_kind { + OriginalFnKind::FunctionDeclaration => { + insert_decls.push((parent_start, outlined_decl)); + } + OriginalFnKind::FunctionExpression | OriginalFnKind::ArrowFunctionExpression => { + push_decls.push(outlined_decl); + } + } + } + + for (parent_start, outlined_decl) in insert_decls.into_iter() { + let insert_idx = if let Some(start) = parent_start { + program + .body + .iter() + .position(|stmt| stmt_has_fn_at_start(stmt, start)) + .map(|pos| pos + 1) + .unwrap_or(program.body.len()) + } else { + program.body.len() + }; + program + .body + .insert(insert_idx, Statement::FunctionDeclaration(outlined_decl)); + } + + for outlined_decl in push_decls { + program + .body + .push(Statement::FunctionDeclaration(outlined_decl)); + } + + // Register the memo cache import and rename useMemoCache references. + let needs_memo_import = compiled_fns + .iter() + .any(|cf| cf.codegen_fn.memo_slots_used > 0); + if needs_memo_import { + let import_spec = context.add_memo_cache_import(); + let local_name = import_spec.name; + let mut visitor = RenameIdentifierVisitor { + old_name: "useMemoCache", + new_name: &local_name, + }; + walk_program_mut(&mut visitor, program); + } + + // Instrumentation and hook guard imports are pre-registered in compile_program + // before compilation, so they are already in the imports map. No post-hoc + // renaming needed since codegen uses the pre-resolved local names. + + add_imports_to_program(program, context); +} + +/// Apply the conditional expression gating pattern. +/// +/// For function declarations (non-export-default, non-hoisted): +/// `function Foo(props) { ... }` -> `const Foo = gating() ? function Foo(...) +/// { compiled } : function Foo(...) 
{ original };` +/// +/// For export default function with name: +/// `export default function Foo(props) { ... }` -> `const Foo = gating() ? +/// ... : ...; export default Foo;` +/// +/// For export named function: +/// `export function Foo(props) { ... }` -> `export const Foo = gating() ? ... +/// : ...;` +/// +/// For arrow/function expressions: +/// Replace the expression inline with `gating() ? compiled : original` +fn apply_gated_function_conditional( + program: &mut Program, + compiled: &CompiledFnForReplacement, + gating_config: &GatingConfig, + original_expr: Option, + context: &mut ProgramContext, +) { + let start = match compiled.fn_start { + Some(s) => s, + None => return, + }; + + // Add the gating import + let gating_import = context.add_import_specifier( + &gating_config.source, + &gating_config.import_specifier_name, + None, + ); + let gating_callee_name = gating_import.name; + + // Build the compiled expression + let compiled_expr = + build_compiled_expression_matching_kind(&compiled.codegen_fn, compiled.original_kind); + + // Build the original (fallback) expression + let original_expr = match original_expr { + Some(e) => e, + None => return, // shouldn't happen + }; + + // Build: gating() ? compiled : original + let gating_expression = Expression::ConditionalExpression(ConditionalExpression { + base: BaseNode::typed("ConditionalExpression"), + test: Box::new(Expression::CallExpression(CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: gating_callee_name, + type_annotation: None, + optional: None, + decorators: None, + })), + arguments: vec![], + type_parameters: None, + type_arguments: None, + optional: None, + })), + consequent: Box::new(compiled_expr), + alternate: Box::new(original_expr), + }); + + // Find and replace the function in the program body. 
+ // We need to track if this was an export default function with a name, + // because we need to insert `export default Name;` after the replacement. + let mut export_default_name: Option<(usize, String)> = None; + + for (idx, stmt) in program.body.iter().enumerate() { + if let Statement::ExportDefaultDeclaration(export) = stmt { + if let ExportDefaultDecl::FunctionDeclaration(f) = export.declaration.as_ref() { + if f.base.start == Some(start) { + if let Some(ref fn_id) = f.id { + export_default_name = Some((idx, fn_id.name.clone())); + } + } + } + } + } + + let mut visitor = ReplaceWithGatedVisitor { + start, + gating_expression: &gating_expression, + }; + walk_program_mut(&mut visitor, program); + + // If this was an export default function with a name, insert `export default + // Name;` after + if let Some((idx, name)) = export_default_name { + program.body.insert( + idx + 1, + Statement::ExportDefaultDeclaration(ExportDefaultDeclaration { + base: BaseNode::typed("ExportDefaultDeclaration"), + declaration: Box::new(ExportDefaultDecl::Expression(Box::new( + Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name, + type_annotation: None, + optional: None, + decorators: None, + }), + ))), + export_kind: None, + }), + ); + } +} + +/// Visitor that replaces a function with a gated conditional expression. +struct ReplaceWithGatedVisitor<'a> { + start: u32, + gating_expression: &'a Expression, +} + +impl MutVisitor for ReplaceWithGatedVisitor<'_> { + fn visit_statement(&mut self, stmt: &mut Statement) -> VisitResult { + // FunctionDeclaration → replace with `const Foo = gating() ? ... 
: ...;` + if let Statement::FunctionDeclaration(f) = &*stmt { + if f.base.start == Some(self.start) { + let fn_name = f.id.clone().unwrap_or_else(|| Identifier { + base: BaseNode::typed("Identifier"), + name: "anonymous".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }); + let mut base = BaseNode::typed("VariableDeclaration"); + base.leading_comments = f.base.leading_comments.clone(); + base.trailing_comments = f.base.trailing_comments.clone(); + base.inner_comments = f.base.inner_comments.clone(); + *stmt = Statement::VariableDeclaration(VariableDeclaration { + base, + kind: VariableDeclarationKind::Const, + declarations: vec![VariableDeclarator { + base: BaseNode::typed("VariableDeclarator"), + id: PatternLike::Identifier(fn_name), + init: Some(Box::new(self.gating_expression.clone())), + definite: None, + }], + declare: None, + }); + return VisitResult::Stop; + } + } + + // ExportDefaultDeclaration with FunctionDeclaration + if let Statement::ExportDefaultDeclaration(export) = stmt { + let is_fn_decl_match = matches!( + export.declaration.as_ref(), + ExportDefaultDecl::FunctionDeclaration(f) if f.base.start == Some(self.start) + ); + if is_fn_decl_match { + if let ExportDefaultDecl::FunctionDeclaration(f) = export.declaration.as_ref() { + let fn_name = f.id.clone(); + if let Some(fn_id) = fn_name { + let mut base = BaseNode::typed("VariableDeclaration"); + base.leading_comments = export.base.leading_comments.clone(); + base.trailing_comments = export.base.trailing_comments.clone(); + base.inner_comments = export.base.inner_comments.clone(); + *stmt = Statement::VariableDeclaration(VariableDeclaration { + base, + kind: VariableDeclarationKind::Const, + declarations: vec![VariableDeclarator { + base: BaseNode::typed("VariableDeclarator"), + id: PatternLike::Identifier(fn_id), + init: Some(Box::new(self.gating_expression.clone())), + definite: None, + }], + declare: None, + }); + return VisitResult::Stop; + } else { + 
export.declaration = Box::new(ExportDefaultDecl::Expression(Box::new( + self.gating_expression.clone(), + ))); + return VisitResult::Stop; + } + } + } + // Expression case handled by walker recursion into visit_expression + } + + // ExportNamedDeclaration with FunctionDeclaration + if let Statement::ExportNamedDeclaration(export) = stmt { + if let Some(ref mut decl) = export.declaration { + if let Declaration::FunctionDeclaration(f) = decl.as_mut() { + if f.base.start == Some(self.start) { + let fn_name = f.id.clone().unwrap_or_else(|| Identifier { + base: BaseNode::typed("Identifier"), + name: "anonymous".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }); + *decl = Box::new(Declaration::VariableDeclaration(VariableDeclaration { + base: BaseNode::typed("VariableDeclaration"), + kind: VariableDeclarationKind::Const, + declarations: vec![VariableDeclarator { + base: BaseNode::typed("VariableDeclarator"), + id: PatternLike::Identifier(fn_name), + init: Some(Box::new(self.gating_expression.clone())), + definite: None, + }], + declare: None, + })); + return VisitResult::Stop; + } + } + } + } + + VisitResult::Continue + } + + fn visit_expression(&mut self, expr: &mut Expression) -> VisitResult { + match expr { + Expression::FunctionExpression(f) if f.base.start == Some(self.start) => { + *expr = self.gating_expression.clone(); + VisitResult::Stop + } + Expression::ArrowFunctionExpression(f) if f.base.start == Some(self.start) => { + *expr = self.gating_expression.clone(); + VisitResult::Stop + } + _ => VisitResult::Continue, + } + } +} + +/// Apply the hoisted function declaration gating pattern. +/// +/// This is used when a function declaration is referenced before its +/// declaration site. Instead of wrapping in a conditional expression (which +/// would break hoisting), we: +/// 1. Rename the original function to `Foo_unoptimized` +/// 2. Insert a compiled function as `Foo_optimized` +/// 3. 
Insert a `const gating_result = gating()` before +/// 4. Insert a new `function Foo(arg0, ...) { if (gating_result) return +/// Foo_optimized(...); else return Foo_unoptimized(...); }` after +fn apply_gated_function_hoisted( + program: &mut Program, + compiled: &CompiledFnForReplacement, + gating_config: &GatingConfig, + context: &mut ProgramContext, +) { + let start = match compiled.fn_start { + Some(s) => s, + None => return, + }; + + let original_fn_name = match &compiled.fn_name { + Some(name) => name.clone(), + None => return, + }; + + // Add the gating import + let gating_import = context.add_import_specifier( + &gating_config.source, + &gating_config.import_specifier_name, + None, + ); + let gating_callee_name = gating_import.name.clone(); + + // Generate unique names + let gating_result_name = context.new_uid(&format!("{}_result", gating_callee_name)); + let unoptimized_name = context.new_uid(&format!("{}_unoptimized", original_fn_name)); + let optimized_name = context.new_uid(&format!("{}_optimized", original_fn_name)); + + // Find the original function declaration and determine its params + let mut original_params: Vec = Vec::new(); + let mut fn_stmt_idx: Option = None; + + for (idx, stmt) in program.body.iter().enumerate() { + if let Statement::FunctionDeclaration(f) = stmt { + if f.base.start == Some(start) { + original_params = f.params.clone(); + fn_stmt_idx = Some(idx); + break; + } + } + } + + let fn_idx = match fn_stmt_idx { + Some(idx) => idx, + None => return, + }; + + // Rename the original function to `_unoptimized` + if let Statement::FunctionDeclaration(f) = &mut program.body[fn_idx] { + if let Some(ref mut id) = f.id { + id.name = unoptimized_name.clone(); + } + } + + // Build the optimized function declaration (compiled version with renamed id) + let compiled_fn_decl = FunctionDeclaration { + base: BaseNode::typed("FunctionDeclaration"), + id: Some(Identifier { + base: BaseNode::typed("Identifier"), + name: optimized_name.clone(), + 
type_annotation: None, + optional: None, + decorators: None, + }), + params: compiled.codegen_fn.params.clone(), + body: compiled.codegen_fn.body.clone(), + generator: compiled.codegen_fn.generator, + is_async: compiled.codegen_fn.is_async, + declare: None, + return_type: None, + type_parameters: None, + predicate: None, + component_declaration: false, + hook_declaration: false, + }; + + // Build the gating result variable: `const gating_result = gating();` + let gating_result_stmt = Statement::VariableDeclaration(VariableDeclaration { + base: BaseNode::typed("VariableDeclaration"), + kind: VariableDeclarationKind::Const, + declarations: vec![VariableDeclarator { + base: BaseNode::typed("VariableDeclarator"), + id: PatternLike::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: gating_result_name.clone(), + type_annotation: None, + optional: None, + decorators: None, + }), + init: Some(Box::new(Expression::CallExpression(CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: gating_callee_name, + type_annotation: None, + optional: None, + decorators: None, + })), + arguments: vec![], + type_parameters: None, + type_arguments: None, + optional: None, + }))), + definite: None, + }], + declare: None, + }); + + // Build new params and args for the dispatcher function + let num_params = original_params.len(); + let mut new_params: Vec = Vec::new(); + let mut optimized_args: Vec = Vec::new(); + let mut unoptimized_args: Vec = Vec::new(); + + for i in 0..num_params { + let arg_name = format!("arg{}", i); + let is_rest = matches!(&original_params[i], PatternLike::RestElement(_)); + + if is_rest { + new_params.push(PatternLike::RestElement( + react_compiler_ast::patterns::RestElement { + base: BaseNode::typed("RestElement"), + argument: Box::new(PatternLike::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: arg_name.clone(), + 
type_annotation: None, + optional: None, + decorators: None, + })), + type_annotation: None, + decorators: None, + }, + )); + optimized_args.push(Expression::SpreadElement(SpreadElement { + base: BaseNode::typed("SpreadElement"), + argument: Box::new(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: arg_name.clone(), + type_annotation: None, + optional: None, + decorators: None, + })), + })); + unoptimized_args.push(Expression::SpreadElement(SpreadElement { + base: BaseNode::typed("SpreadElement"), + argument: Box::new(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: arg_name, + type_annotation: None, + optional: None, + decorators: None, + })), + })); + } else { + new_params.push(PatternLike::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: arg_name.clone(), + type_annotation: None, + optional: None, + decorators: None, + })); + optimized_args.push(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: arg_name.clone(), + type_annotation: None, + optional: None, + decorators: None, + })); + unoptimized_args.push(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: arg_name, + type_annotation: None, + optional: None, + decorators: None, + })); + } + } + + // Build the dispatcher function: + // function Foo(arg0, ...) 
{ + // if (gating_result) return Foo_optimized(arg0, ...); + // else return Foo_unoptimized(arg0, ...); + // } + let dispatcher_fn = Statement::FunctionDeclaration(FunctionDeclaration { + base: BaseNode::typed("FunctionDeclaration"), + id: Some(Identifier { + base: BaseNode::typed("Identifier"), + name: original_fn_name, + type_annotation: None, + optional: None, + decorators: None, + }), + params: new_params, + body: BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: vec![Statement::IfStatement(IfStatement { + base: BaseNode::typed("IfStatement"), + test: Box::new(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: gating_result_name, + type_annotation: None, + optional: None, + decorators: None, + })), + consequent: Box::new(Statement::ReturnStatement(ReturnStatement { + base: BaseNode::typed("ReturnStatement"), + argument: Some(Box::new(Expression::CallExpression(CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: optimized_name.clone(), + type_annotation: None, + optional: None, + decorators: None, + })), + arguments: optimized_args, + type_parameters: None, + type_arguments: None, + optional: None, + }))), + })), + alternate: Some(Box::new(Statement::ReturnStatement(ReturnStatement { + base: BaseNode::typed("ReturnStatement"), + argument: Some(Box::new(Expression::CallExpression(CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::Identifier(Identifier { + base: BaseNode::typed("Identifier"), + name: unoptimized_name, + type_annotation: None, + optional: None, + decorators: None, + })), + arguments: unoptimized_args, + type_parameters: None, + type_arguments: None, + optional: None, + }))), + }))), + })], + directives: vec![], + }, + generator: false, + is_async: false, + declare: None, + return_type: None, + type_parameters: None, + predicate: None, + 
component_declaration: false, + hook_declaration: false, + }); + + // Insert nodes. The TS code uses insertBefore for the gating result and + // optimized fn, and insertAfter for the dispatcher. The order in the output + // should be: ... (existing statements before fn_idx) ... + // const gating_result = gating(); <- inserted before + // function Foo_optimized() { ... } <- inserted before + // function Foo_unoptimized() { ... } <- the original (renamed) + // function Foo(arg0) { ... } <- inserted after + // ... (existing statements after fn_idx) ... + // + // insertBefore inserts before the target, and insertAfter inserts after. + // We insert in reverse order for insertAfter. + + // Insert dispatcher after the original (now renamed) function + program.body.insert(fn_idx + 1, dispatcher_fn); + + // Insert optimized function before the original + program + .body + .insert(fn_idx, Statement::FunctionDeclaration(compiled_fn_decl)); + + // Insert gating result before the optimized function + program.body.insert(fn_idx, gating_result_stmt); +} + +/// Check if a statement contains a function whose BaseNode.start matches. 
+fn stmt_has_fn_at_start(stmt: &Statement, start: u32) -> bool { + match stmt { + Statement::FunctionDeclaration(f) => f.base.start == Some(start), + Statement::VariableDeclaration(var_decl) => var_decl.declarations.iter().any(|decl| { + if let Some(ref init) = decl.init { + expr_has_fn_at_start(init, start) + } else { + false + } + }), + Statement::ExportDefaultDeclaration(export) => match export.declaration.as_ref() { + ExportDefaultDecl::FunctionDeclaration(f) => f.base.start == Some(start), + ExportDefaultDecl::Expression(e) => expr_has_fn_at_start(e, start), + _ => false, + }, + Statement::ExportNamedDeclaration(export) => { + if let Some(ref decl) = export.declaration { + match decl.as_ref() { + Declaration::FunctionDeclaration(f) => f.base.start == Some(start), + Declaration::VariableDeclaration(var_decl) => { + var_decl.declarations.iter().any(|d| { + if let Some(ref init) = d.init { + expr_has_fn_at_start(init, start) + } else { + false + } + }) + } + _ => false, + } + } else { + false + } + } + Statement::ExpressionStatement(expr_stmt) => { + expr_has_fn_at_start(&expr_stmt.expression, start) + } + // Recurse into block-containing statements + Statement::BlockStatement(block) => { + block.body.iter().any(|s| stmt_has_fn_at_start(s, start)) + } + Statement::IfStatement(if_stmt) => { + expr_has_fn_at_start(&if_stmt.test, start) + || stmt_has_fn_at_start(&if_stmt.consequent, start) + || if_stmt + .alternate + .as_ref() + .map_or(false, |alt| stmt_has_fn_at_start(alt, start)) + } + Statement::TryStatement(try_stmt) => { + try_stmt + .block + .body + .iter() + .any(|s| stmt_has_fn_at_start(s, start)) + || try_stmt.handler.as_ref().map_or(false, |h| { + h.body.body.iter().any(|s| stmt_has_fn_at_start(s, start)) + }) + || try_stmt.finalizer.as_ref().map_or(false, |f| { + f.body.iter().any(|s| stmt_has_fn_at_start(s, start)) + }) + } + Statement::SwitchStatement(switch_stmt) => { + expr_has_fn_at_start(&switch_stmt.discriminant, start) + || 
switch_stmt.cases.iter().any(|case| { + case.consequent + .iter() + .any(|s| stmt_has_fn_at_start(s, start)) + }) + } + Statement::LabeledStatement(labeled) => stmt_has_fn_at_start(&labeled.body, start), + Statement::ForStatement(for_stmt) => { + if let Some(ref init) = for_stmt.init { + match init.as_ref() { + ForInit::VariableDeclaration(var_decl) => { + if var_decl.declarations.iter().any(|d| { + d.init + .as_ref() + .map_or(false, |e| expr_has_fn_at_start(e, start)) + }) { + return true; + } + } + ForInit::Expression(expr) => { + if expr_has_fn_at_start(expr, start) { + return true; + } + } + } + } + if for_stmt + .test + .as_ref() + .map_or(false, |t| expr_has_fn_at_start(t, start)) + { + return true; + } + if for_stmt + .update + .as_ref() + .map_or(false, |u| expr_has_fn_at_start(u, start)) + { + return true; + } + stmt_has_fn_at_start(&for_stmt.body, start) + } + Statement::WhileStatement(while_stmt) => { + expr_has_fn_at_start(&while_stmt.test, start) + || stmt_has_fn_at_start(&while_stmt.body, start) + } + Statement::DoWhileStatement(do_while) => { + expr_has_fn_at_start(&do_while.test, start) + || stmt_has_fn_at_start(&do_while.body, start) + } + Statement::ForInStatement(for_in) => { + expr_has_fn_at_start(&for_in.right, start) || stmt_has_fn_at_start(&for_in.body, start) + } + Statement::ForOfStatement(for_of) => { + expr_has_fn_at_start(&for_of.right, start) || stmt_has_fn_at_start(&for_of.body, start) + } + Statement::WithStatement(with_stmt) => { + expr_has_fn_at_start(&with_stmt.object, start) + || stmt_has_fn_at_start(&with_stmt.body, start) + } + Statement::ReturnStatement(ret) => ret + .argument + .as_ref() + .map_or(false, |arg| expr_has_fn_at_start(arg, start)), + Statement::ThrowStatement(throw_stmt) => expr_has_fn_at_start(&throw_stmt.argument, start), + _ => false, + } +} + +/// Check if an expression contains a function whose BaseNode.start matches. 
+fn expr_has_fn_at_start(expr: &Expression, start: u32) -> bool { + match expr { + Expression::FunctionExpression(f) => f.base.start == Some(start), + Expression::ArrowFunctionExpression(f) => f.base.start == Some(start), + // Check for forwardRef/memo wrappers: the inner function + Expression::CallExpression(call) => call + .arguments + .iter() + .any(|arg| expr_has_fn_at_start(arg, start)), + _ => false, + } +} + +/// Visitor that replaces a compiled function in the AST by matching +/// `base.start`. +struct ReplaceFnVisitor<'a> { + start: u32, + compiled: &'a CompiledFnForReplacement, +} + +impl MutVisitor for ReplaceFnVisitor<'_> { + fn visit_statement(&mut self, stmt: &mut Statement) -> VisitResult { + match stmt { + Statement::FunctionDeclaration(f) if f.base.start == Some(self.start) => { + f.id = self.compiled.codegen_fn.id.clone(); + f.params = self.compiled.codegen_fn.params.clone(); + f.body = self.compiled.codegen_fn.body.clone(); + f.generator = self.compiled.codegen_fn.generator; + f.is_async = self.compiled.codegen_fn.is_async; + f.return_type = None; + f.type_parameters = None; + f.predicate = None; + f.declare = None; + return VisitResult::Stop; + } + Statement::ExportDefaultDeclaration(export) => { + if let ExportDefaultDecl::FunctionDeclaration(f) = export.declaration.as_mut() { + if f.base.start == Some(self.start) { + f.id = self.compiled.codegen_fn.id.clone(); + f.params = self.compiled.codegen_fn.params.clone(); + f.body = self.compiled.codegen_fn.body.clone(); + f.generator = self.compiled.codegen_fn.generator; + f.is_async = self.compiled.codegen_fn.is_async; + f.return_type = None; + f.type_parameters = None; + f.predicate = None; + f.declare = None; + return VisitResult::Stop; + } + } + } + Statement::ExportNamedDeclaration(export) => { + if let Some(ref mut decl) = export.declaration { + if let Declaration::FunctionDeclaration(f) = decl.as_mut() { + if f.base.start == Some(self.start) { + f.id = self.compiled.codegen_fn.id.clone(); + 
f.params = self.compiled.codegen_fn.params.clone(); + f.body = self.compiled.codegen_fn.body.clone(); + f.generator = self.compiled.codegen_fn.generator; + f.is_async = self.compiled.codegen_fn.is_async; + f.return_type = None; + f.type_parameters = None; + f.predicate = None; + f.declare = None; + return VisitResult::Stop; + } + } + } + } + _ => {} + } + VisitResult::Continue + } + + fn visit_expression(&mut self, expr: &mut Expression) -> VisitResult { + match expr { + Expression::FunctionExpression(f) if f.base.start == Some(self.start) => { + f.id = self.compiled.codegen_fn.id.clone(); + f.params = self.compiled.codegen_fn.params.clone(); + f.body = self.compiled.codegen_fn.body.clone(); + f.generator = self.compiled.codegen_fn.generator; + f.is_async = self.compiled.codegen_fn.is_async; + f.return_type = None; + f.type_parameters = None; + VisitResult::Stop + } + Expression::ArrowFunctionExpression(f) if f.base.start == Some(self.start) => { + f.params = self.compiled.codegen_fn.params.clone(); + f.body = Box::new(ArrowFunctionBody::BlockStatement( + self.compiled.codegen_fn.body.clone(), + )); + f.generator = self.compiled.codegen_fn.generator; + f.is_async = self.compiled.codegen_fn.is_async; + f.expression = Some(false); + f.return_type = None; + f.type_parameters = None; + f.predicate = None; + VisitResult::Stop + } + _ => VisitResult::Continue, + } + } +} + +/// Visitor that renames all occurrences of an identifier in expression +/// position. +struct RenameIdentifierVisitor<'a> { + old_name: &'a str, + new_name: &'a str, +} + +impl MutVisitor for RenameIdentifierVisitor<'_> { + fn visit_identifier(&mut self, node: &mut Identifier) -> VisitResult { + if node.name == self.old_name { + node.name = self.new_name.to_string(); + } + VisitResult::Continue + } +} + +/// Main entry point for the React Compiler. +/// +/// Receives a full program AST, scope information (unused for now), and +/// resolved options. 
Returns a CompileResult indicating whether the AST was +/// modified, along with any logger events. +/// +/// This function implements the logic from the TS entrypoint (Program.ts): +/// - shouldSkipCompilation: check for existing runtime imports +/// - validateRestrictedImports: check for blocklisted imports +/// - findProgramSuppressions: find eslint/flow suppression comments +/// - findFunctionsToCompile: traverse program to find components and hooks +/// - processFn: per-function compilation with directive and suppression +/// handling +/// - applyCompiledFunctions: replace original functions with compiled versions +pub fn compile_program(mut file: File, scope: ScopeInfo, options: PluginOptions) -> CompileResult { + // Compute output mode once, up front + let output_mode = CompilerOutputMode::from_opts(&options); + + // Create a temporary context for early-return paths (before full context is set + // up) + let early_events: Vec = Vec::new(); + let mut early_ordered_log: Vec = Vec::new(); + + // Log environment config for debugLogIRs + if options.debug { + early_ordered_log.push(OrderedLogItem::Debug { + entry: DebugLogEntry::new( + "EnvironmentConfig", + serde_json::to_string_pretty(&options.environment).unwrap_or_default(), + ), + }); + } + + // Check if we should compile this file at all (pre-resolved by JS shim) + if !options.should_compile { + return CompileResult::Success { + ast: None, + events: early_events, + ordered_log: early_ordered_log, + renames: Vec::new(), + timing: Vec::new(), + }; + } + + let program = &file.program; + + // Check for existing runtime imports (file already compiled) + if should_skip_compilation(program, &options) { + return CompileResult::Success { + ast: None, + events: early_events, + ordered_log: early_ordered_log, + renames: Vec::new(), + timing: Vec::new(), + }; + } + + // Validate restricted imports from the environment config + let restricted_imports = options.environment.validate_blocklisted_imports.clone(); + + // 
Determine if we should check for eslint suppressions + let validate_exhaustive = options + .environment + .validate_exhaustive_memoization_dependencies; + let validate_hooks = options.environment.validate_hooks_usage; + + let eslint_rules: Option> = if validate_exhaustive && validate_hooks { + // Don't check for ESLint suppressions if both validations are enabled + None + } else { + Some(options.eslint_suppression_rules.clone().unwrap_or_else(|| { + DEFAULT_ESLINT_SUPPRESSIONS + .iter() + .map(|s| s.to_string()) + .collect() + })) + }; + + // Find program-level suppressions from comments + let suppressions = find_program_suppressions( + &file.comments, + eslint_rules.as_deref(), + options.flow_suppressions, + ); + + // Check for module-scope opt-out directive + let has_module_scope_opt_out = + find_directive_disabling_memoization(&program.directives, &options).is_some(); + + // Create program context + let mut context = ProgramContext::new( + options.clone(), + options.filename.clone(), + // Pass the source code for fast refresh hash computation. + options.source_code.clone(), + suppressions, + has_module_scope_opt_out, + ); + + // Extract the source filename from the AST (set by parser's sourceFilename + // option). This is the bare filename (e.g., "foo.ts") without path + // prefixes, which the TS compiler uses in logger event source locations. 
+ let source_filename = program + .base + .loc + .as_ref() + .and_then(|loc| loc.filename.clone()) + .or_else(|| { + // Fallback: try the first statement's loc + program.body.first().and_then(|stmt| { + let base = match stmt { + react_compiler_ast::statements::Statement::ExpressionStatement(s) => &s.base, + react_compiler_ast::statements::Statement::VariableDeclaration(s) => &s.base, + react_compiler_ast::statements::Statement::FunctionDeclaration(s) => &s.base, + _ => return None, + }; + base.loc.as_ref().and_then(|loc| loc.filename.clone()) + }) + }); + context.set_source_filename(source_filename); + + // Initialize known referenced names from scope bindings for UID collision + // detection + context.init_from_scope(&scope); + + // Seed context with early ordered log entries + context.ordered_log.extend(early_ordered_log); + + // Validate restricted imports (needs context for handle_error) + if let Some(err) = validate_restricted_imports(program, &restricted_imports) { + if let Some(result) = handle_error(&err, None, &mut context) { + return result; + } + return CompileResult::Success { + ast: None, + events: context.events, + ordered_log: context.ordered_log, + renames: convert_renames(&context.renames), + timing: Vec::new(), + }; + } + + // Pre-register instrumentation imports to get stable local names. + // These are needed before compilation so codegen can use the correct names. 
+ let instrument_fn_name: Option; + let instrument_gating_name: Option; + let hook_guard_name: Option; + + if let Some(ref instrument_config) = options.environment.enable_emit_instrument_forget { + let fn_spec = context.add_import_specifier( + &instrument_config.fn_.source, + &instrument_config.fn_.import_specifier_name, + None, + ); + instrument_fn_name = Some(fn_spec.name.clone()); + instrument_gating_name = instrument_config.gating.as_ref().map(|g| { + let spec = context.add_import_specifier(&g.source, &g.import_specifier_name, None); + spec.name.clone() + }); + } else { + instrument_fn_name = None; + instrument_gating_name = None; + } + + if let Some(ref hook_guard_config) = options.environment.enable_emit_hook_guards { + let spec = context.add_import_specifier( + &hook_guard_config.source, + &hook_guard_config.import_specifier_name, + None, + ); + hook_guard_name = Some(spec.name.clone()); + } else { + hook_guard_name = None; + } + + // Store pre-resolved names on context for pipeline access + context.instrument_fn_name = instrument_fn_name; + context.instrument_gating_name = instrument_gating_name; + context.hook_guard_name = hook_guard_name; + + // Find all functions to compile + let queue = find_functions_to_compile(program, &options, &mut context, &scope); + + // Clone env_config once for all function compilations (avoids per-function + // clone while satisfying the borrow checker — compile_fn needs &mut context + // + &env_config) + let env_config = options.environment.clone(); + + // Process each function and collect compiled results + let mut compiled_fns: Vec> = Vec::new(); + + for source in &queue { + match process_fn(source, &scope, output_mode, &env_config, &mut context) { + Ok(Some(codegen_fn)) => { + compiled_fns.push(CompiledFunction { + kind: source.kind, + source, + codegen_fn, + }); + } + Ok(None) => { + // Function was skipped or lint-only + } + Err(fatal_result) => { + return fatal_result; + } + } + } + + // Emit CompileSuccess events for 
JSX-outlined functions (fn_type.is_some()). + // In TS, outlined functions from outlineJSX are appended to the compilation + // queue and processed after all original functions, so their events appear + // at the end. Regular outlined functions (from OutlineFunctions pass) don't + // get separate events. + for compiled in &compiled_fns { + for outlined in &compiled.codegen_fn.outlined { + if outlined.fn_type.is_some() { + context.log_event(LoggerEvent::CompileSuccess { + fn_loc: None, + fn_name: outlined.func.id.as_ref().map(|id| id.name.clone()), + memo_slots: outlined.func.memo_slots_used, + memo_blocks: outlined.func.memo_blocks, + memo_values: outlined.func.memo_values, + pruned_memo_blocks: outlined.func.pruned_memo_blocks, + pruned_memo_values: outlined.func.pruned_memo_values, + }); + } + } + } + + // TS invariant: if there's a module scope opt-out, no functions should have + // been compiled + if has_module_scope_opt_out { + if !compiled_fns.is_empty() { + let mut err = CompilerError::new(); + err.push_error_detail(CompilerErrorDetail::new( + ErrorCategory::Invariant, + "Unexpected compiled functions when module scope opt-out is present", + )); + handle_error(&err, None, &mut context); + } + return CompileResult::Success { + ast: None, + events: context.events, + ordered_log: context.ordered_log, + renames: convert_renames(&context.renames), + timing: Vec::new(), + }; + } + + // Determine gating for each compiled function. + // In the TS compiler, dynamic gating from directives takes precedence over + // plugin-level gating. Gating only applies to 'original' functions, not + // 'outlined' ones. + let function_gating_config = options.gating.clone(); + + // Convert compiled functions to owned representations (dropping borrows) + // so we can mutate the AST. 
+ let replacements: Vec = compiled_fns + .into_iter() + .map(|cf| { + let original_kind = match cf.source.fn_node { + FunctionNode::FunctionDeclaration(_) => OriginalFnKind::FunctionDeclaration, + FunctionNode::FunctionExpression(_) => OriginalFnKind::FunctionExpression, + FunctionNode::ArrowFunctionExpression(_) => OriginalFnKind::ArrowFunctionExpression, + }; + // Determine per-function gating: dynamic gating from directives OR plugin-level + // gating. Dynamic gating (from `use memo if(identifier)`) takes + // precedence. + let gating = if cf.kind == CompileSourceKind::Original { + // Check body directives for dynamic gating + let dynamic_gating = + find_directives_dynamic_gating(&cf.source.body_directives, &options) + .ok() + .flatten() + .map(|r| r.gating); + dynamic_gating.or_else(|| function_gating_config.clone()) + } else { + None + }; + CompiledFnForReplacement { + fn_start: cf.source.fn_start, + original_kind, + codegen_fn: cf.codegen_fn, + source_kind: cf.kind, + fn_name: cf.source.fn_name.clone(), + gating, + } + }) + .collect(); + // Drop queue (and its borrows from file.program) + drop(queue); + + if replacements.is_empty() { + // No functions to replace. Return renames for the Babel plugin to apply + // (e.g., variable shadowing renames in lint mode). Imports are NOT added + // when there are no replacements — matching TS behavior where + // addImportsToProgram is only called when compiledFns.length > 0. + return CompileResult::Success { + ast: None, + events: context.events, + ordered_log: context.ordered_log, + renames: convert_renames(&context.renames), + timing: Vec::new(), + }; + } + + // Now we can mutate file.program + apply_compiled_functions(&replacements, &mut file.program, &mut context); + + // Serialize the modified File AST directly to a JSON string and wrap as + // RawValue. This avoids double-serialization (File→Value→String) by going + // File→String directly. 
The RawValue is embedded verbatim when the + // CompileResult is serialized. + let ast = match serde_json::to_string(&file) { + Ok(s) => match serde_json::value::RawValue::from_string(s) { + Ok(raw) => Some(raw), + Err(e) => { + eprintln!("RUST COMPILER: Failed to create RawValue: {}", e); + None + } + }, + Err(e) => { + eprintln!("RUST COMPILER: Failed to serialize AST: {}", e); + None + } + }; + + let timing_entries = context.timing.into_entries(); + + CompileResult::Success { + ast, + events: context.events, + ordered_log: context.ordered_log, + renames: convert_renames(&context.renames), + timing: timing_entries, + } +} + +/// Convert internal BindingRename structs to the serializable BindingRenameInfo +/// format. +fn convert_renames( + renames: &[react_compiler_hir::environment::BindingRename], +) -> Vec { + renames + .iter() + .map(|r| BindingRenameInfo { + original: r.original.clone(), + renamed: r.renamed.clone(), + declaration_start: r.declaration_start, + }) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_hook_name() { + assert!(is_hook_name("useState")); + assert!(is_hook_name("useEffect")); + assert!(is_hook_name("use0Something")); + assert!(!is_hook_name("use")); + assert!(!is_hook_name("useless")); // lowercase after use + assert!(!is_hook_name("foo")); + assert!(!is_hook_name("")); + } + + #[test] + fn test_is_component_name() { + assert!(is_component_name("MyComponent")); + assert!(is_component_name("App")); + assert!(!is_component_name("myComponent")); + assert!(!is_component_name("app")); + assert!(!is_component_name("")); + } + + #[test] + fn test_is_valid_identifier() { + assert!(is_valid_identifier("foo")); + assert!(is_valid_identifier("_bar")); + assert!(is_valid_identifier("$baz")); + assert!(is_valid_identifier("foo123")); + assert!(!is_valid_identifier("")); + assert!(!is_valid_identifier("123foo")); + assert!(!is_valid_identifier("foo bar")); + } + + #[test] + fn 
test_is_valid_component_params_empty() { + assert!(is_valid_component_params(&[])); + } + + #[test] + fn test_is_valid_component_params_one_identifier() { + let params = vec![PatternLike::Identifier(Identifier { + base: BaseNode::default(), + name: "props".to_string(), + type_annotation: None, + optional: None, + decorators: None, + })]; + assert!(is_valid_component_params(¶ms)); + } + + #[test] + fn test_is_valid_component_params_too_many() { + let params = vec![ + PatternLike::Identifier(Identifier { + base: BaseNode::default(), + name: "a".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + PatternLike::Identifier(Identifier { + base: BaseNode::default(), + name: "b".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + PatternLike::Identifier(Identifier { + base: BaseNode::default(), + name: "c".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + ]; + assert!(!is_valid_component_params(¶ms)); + } + + #[test] + fn test_is_valid_component_params_with_ref() { + let params = vec![ + PatternLike::Identifier(Identifier { + base: BaseNode::default(), + name: "props".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + PatternLike::Identifier(Identifier { + base: BaseNode::default(), + name: "ref".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + ]; + assert!(is_valid_component_params(¶ms)); + } + + #[test] + fn test_should_skip_compilation_no_import() { + let program = Program { + base: BaseNode::default(), + body: vec![], + directives: vec![], + source_type: react_compiler_ast::SourceType::Module, + interpreter: None, + source_file: None, + }; + let options = PluginOptions { + should_compile: true, + enable_reanimated: false, + is_dev: false, + filename: None, + compilation_mode: "infer".to_string(), + panic_threshold: "none".to_string(), + target: 
super::super::plugin_options::CompilerTarget::Version("19".to_string()), + gating: None, + dynamic_gating: None, + no_emit: false, + output_mode: None, + eslint_suppression_rules: None, + flow_suppressions: true, + ignore_use_no_forget: false, + custom_opt_out_directives: None, + environment: EnvironmentConfig::default(), + source_code: None, + profiling: false, + debug: false, + }; + assert!(!should_skip_compilation(&program, &options)); + } +} diff --git a/crates/react_compiler/src/entrypoint/suppression.rs b/crates/react_compiler/src/entrypoint/suppression.rs new file mode 100644 index 000000000000..f2b58036ff73 --- /dev/null +++ b/crates/react_compiler/src/entrypoint/suppression.rs @@ -0,0 +1,311 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +use react_compiler_ast::common::{Comment, CommentData}; +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, CompilerSuggestion, + CompilerSuggestionOperation, ErrorCategory, +}; + +#[derive(Debug, Clone)] +pub enum SuppressionSource { + Eslint, + Flow, +} + +/// Captures the start and end range of a pair of eslint-disable ... +/// eslint-enable comments. In the case of a CommentLine or a relevant Flow +/// suppression, both the disable and enable point to the same comment. +/// +/// The enable comment can be missing in the case where only a disable block is +/// present, ie the rest of the file has potential React violations. 
+#[derive(Debug, Clone)]
+pub struct SuppressionRange {
+    /// Comment that starts the suppression.
+    pub disable_comment: CommentData,
+    /// Comment that ends the suppression; `None` when no end was recorded.
+    pub enable_comment: Option<CommentData>,
+    /// Which tool the suppression comment belongs to.
+    pub source: SuppressionSource,
+}
+
+/// Returns the payload shared by block and line comments.
+fn comment_data(comment: &Comment) -> &CommentData {
+    match comment {
+        Comment::CommentBlock(data) | Comment::CommentLine(data) => data,
+    }
+}
+
+/// Returns true if `value`, ignoring leading whitespace, starts with
+/// `directive` immediately followed by one of `rule_names`.
+///
+/// `directive` is expected to carry its trailing space (for example
+/// `"eslint-disable "`), so `eslint-disable-next-line ...` is not mistaken for
+/// `eslint-disable ...`.
+fn matches_eslint_directive(value: &str, directive: &str, rule_names: &[String]) -> bool {
+    // Comment values often have leading whitespace. Trimming also covers the
+    // untrimmed case: a value that already starts with the directive has no
+    // leading whitespace to remove.
+    value
+        .trim_start()
+        .strip_prefix(directive)
+        .map_or(false, |rest| {
+            rule_names
+                .iter()
+                .any(|name| rest.starts_with(name.as_str()))
+        })
+}
+
+/// Check if a comment value matches `eslint-disable-next-line <rule>` for any
+/// rule in `rule_names`.
+fn matches_eslint_disable_next_line(value: &str, rule_names: &[String]) -> bool {
+    matches_eslint_directive(value, "eslint-disable-next-line ", rule_names)
+}
+
+/// Check if a comment value matches `eslint-disable <rule>` for any rule in
+/// `rule_names`.
+fn matches_eslint_disable(value: &str, rule_names: &[String]) -> bool {
+    matches_eslint_directive(value, "eslint-disable ", rule_names)
+}
+
+/// Check if a comment value matches `eslint-enable <rule>` for any rule in
+/// `rule_names`.
+fn matches_eslint_enable(value: &str, rule_names: &[String]) -> bool {
+    // Comment values often have leading whitespace; trimming also covers the
+    // case where the value starts with the directive directly.
+    value
+        .trim_start()
+        .strip_prefix("eslint-enable ")
+        .map_or(false, |rest| {
+            rule_names
+                .iter()
+                .any(|name| rest.starts_with(name.as_str()))
+        })
+}
+
+/// Check if a comment value matches a Flow suppression pattern.
+/// Matches: $FlowFixMe[react-rule, $FlowFixMe_xxx[react-rule,
+/// $FlowExpectedError[react-rule, $FlowIssue[react-rule
+///
+/// The pattern may appear anywhere in `value`. Every `$Flow` occurrence is
+/// tried, so a comment such as `$FlowFixMe[other] $FlowFixMe[react-rule-x]`
+/// matches even though its first suppression targets a different rule.
+fn matches_flow_suppression(value: &str) -> bool {
+    const MARKER: &str = "$Flow";
+    const RULE_PREFIX: &str = "[react-rule";
+
+    value.match_indices(MARKER).any(|(idx, _)| {
+        let after_marker = &value[idx + MARKER.len()..];
+
+        // Match FixMe (with optional word chars), ExpectedError, or Issue.
+        let after_kind = if let Some(rest) = after_marker.strip_prefix("FixMe") {
+            // Skip any word characters following "FixMe" (e.g. "_xxx").
+            rest.trim_start_matches(|c: char| c.is_alphanumeric() || c == '_')
+        } else if let Some(rest) = after_marker.strip_prefix("ExpectedError") {
+            rest
+        } else if let Some(rest) = after_marker.strip_prefix("Issue") {
+            rest
+        } else {
+            // Not a recognised suppression kind; try the next occurrence.
+            return false;
+        };
+
+        // Must be followed by "[react-rule"
+        after_kind.starts_with(RULE_PREFIX)
+    })
+}
+
+/// Parse eslint-disable/enable and Flow suppression comments from program
+/// comments.
Equivalent to findProgramSuppressions in Suppression.ts +pub fn find_program_suppressions( + comments: &[Comment], + rule_names: Option<&[String]>, + flow_suppressions: bool, +) -> Vec { + let mut suppression_ranges: Vec = Vec::new(); + let mut disable_comment: Option = None; + let mut enable_comment: Option = None; + let mut source: Option = None; + + let has_rules = matches!(rule_names, Some(names) if !names.is_empty()); + + for comment in comments { + let data = comment_data(comment); + + if data.start.is_none() || data.end.is_none() { + continue; + } + + // Check for eslint-disable-next-line (only if not already within a block) + if disable_comment.is_none() && has_rules { + if let Some(names) = rule_names { + if matches_eslint_disable_next_line(&data.value, names) { + disable_comment = Some(data.clone()); + enable_comment = Some(data.clone()); + source = Some(SuppressionSource::Eslint); + } + } + } + + // Check for Flow suppression (only if not already within a block) + if flow_suppressions && disable_comment.is_none() && matches_flow_suppression(&data.value) { + disable_comment = Some(data.clone()); + enable_comment = Some(data.clone()); + source = Some(SuppressionSource::Flow); + } + + // Check for eslint-disable (block start) + if has_rules { + if let Some(names) = rule_names { + if matches_eslint_disable(&data.value, names) { + disable_comment = Some(data.clone()); + source = Some(SuppressionSource::Eslint); + } + } + } + + // Check for eslint-enable (block end) + if has_rules { + if let Some(names) = rule_names { + if matches_eslint_enable(&data.value, names) { + if matches!(source, Some(SuppressionSource::Eslint)) { + enable_comment = Some(data.clone()); + } + } + } + } + + // If we have a complete suppression, push it + if disable_comment.is_some() && source.is_some() { + suppression_ranges.push(SuppressionRange { + disable_comment: disable_comment.take().unwrap(), + enable_comment: enable_comment.take(), + source: source.take().unwrap(), + }); + } + } 
+ + suppression_ranges +} + +/// Check if suppression ranges overlap with a function's source range. +/// A suppression affects a function if: +/// 1. The suppression is within the function's body +/// 2. The suppression wraps the function +pub fn filter_suppressions_that_affect_function( + suppressions: &[SuppressionRange], + fn_start: u32, + fn_end: u32, +) -> Vec<&SuppressionRange> { + let mut suppressions_in_scope: Vec<&SuppressionRange> = Vec::new(); + + for suppression in suppressions { + let disable_start = match suppression.disable_comment.start { + Some(s) => s, + None => continue, + }; + + // The suppression is within the function + if disable_start > fn_start + && (suppression.enable_comment.is_none() + || suppression + .enable_comment + .as_ref() + .and_then(|c| c.end) + .map_or(false, |end| end < fn_end)) + { + suppressions_in_scope.push(suppression); + } + + // The suppression wraps the function + if disable_start < fn_start + && (suppression.enable_comment.is_none() + || suppression + .enable_comment + .as_ref() + .and_then(|c| c.end) + .map_or(false, |end| end > fn_end)) + { + suppressions_in_scope.push(suppression); + } + } + + suppressions_in_scope +} + +/// Convert suppression ranges to a CompilerError. 
+pub fn suppressions_to_compiler_error(suppressions: &[SuppressionRange]) -> CompilerError { + assert!( + !suppressions.is_empty(), + "Expected at least one suppression comment source range" + ); + + let mut error = CompilerError::new(); + + for suppression in suppressions { + let (disable_start, disable_end) = match ( + suppression.disable_comment.start, + suppression.disable_comment.end, + ) { + (Some(s), Some(e)) => (s, e), + _ => continue, + }; + + let (reason, suggestion) = match suppression.source { + SuppressionSource::Eslint => ( + "React Compiler has skipped optimizing this component because one or more React \ + ESLint rules were disabled", + "Remove the ESLint suppression and address the React error", + ), + SuppressionSource::Flow => ( + "React Compiler has skipped optimizing this component because one or more React \ + rule violations were reported by Flow", + "Remove the Flow suppression and address the React error", + ), + }; + + let description = format!( + "React Compiler only works when your components follow all the rules of React, \ + disabling them may result in unexpected or incorrect behavior. 
Found suppression `{}`", + suppression.disable_comment.value.trim() + ); + + let mut diagnostic = + CompilerDiagnostic::new(ErrorCategory::Suppression, reason, Some(description)); + + diagnostic.suggestions = Some(vec![CompilerSuggestion { + description: suggestion.to_string(), + range: (disable_start as usize, disable_end as usize), + op: CompilerSuggestionOperation::Remove, + text: None, + }]); + + // Add error detail with location info + let loc = suppression.disable_comment.loc.as_ref().map(|l| { + react_compiler_diagnostics::SourceLocation { + start: react_compiler_diagnostics::Position { + line: l.start.line, + column: l.start.column, + index: l.start.index, + }, + end: react_compiler_diagnostics::Position { + line: l.end.line, + column: l.end.column, + index: l.end.index, + }, + } + }); + + diagnostic = diagnostic.with_detail(CompilerDiagnosticDetail::Error { + loc, + message: Some("Found React rule suppression".to_string()), + identifier_name: None, + }); + + error.push_diagnostic(diagnostic); + } + + error +} diff --git a/crates/react_compiler/src/entrypoint/validate_source_locations.rs b/crates/react_compiler/src/entrypoint/validate_source_locations.rs new file mode 100644 index 000000000000..a3bff1503ba7 --- /dev/null +++ b/crates/react_compiler/src/entrypoint/validate_source_locations.rs @@ -0,0 +1,1333 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Validates that important source locations from the original code are +//! preserved in the generated AST. This ensures that Istanbul coverage +//! instrumentation can properly map back to the original source code. +//! +//! This validation is test-only, enabled via `@validateSourceLocations` pragma. +//! +//! Analogous to TS `ValidateSourceLocations.ts`. 
+ +use std::collections::{HashMap, HashSet}; + +use react_compiler_ast::{ + common::SourceLocation as AstSourceLocation, + expressions::{ + ArrowFunctionBody, ArrowFunctionExpression, Expression, FunctionExpression, + ObjectExpressionProperty, + }, + patterns::PatternLike, + statements::{ForInOfLeft, ForInit, Statement, VariableDeclaration}, +}; +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory, Position as DiagPosition, + SourceLocation as DiagSourceLocation, +}; +use react_compiler_hir::environment::Environment; +use react_compiler_lowering::FunctionNode; +use react_compiler_reactive_scopes::codegen_reactive_function::CodegenFunction; + +/// Validate that important source locations are preserved in the generated AST. +pub fn validate_source_locations( + func: &FunctionNode<'_>, + codegen: &CodegenFunction, + env: &mut Environment, +) { + // Step 1: Collect important locations from the original source + let important_original = collect_important_original_locations(func); + + // Step 2: Collect all locations from the generated AST + let mut generated = HashMap::>::new(); + collect_generated_from_block(&codegen.body.body, &mut generated); + for outlined in &codegen.outlined { + collect_generated_from_block(&outlined.func.body.body, &mut generated); + } + + // Step 3: Validate that all important locations are preserved + let strict_node_types: HashSet<&str> = + ["VariableDeclaration", "VariableDeclarator", "Identifier"] + .into_iter() + .collect(); + + // Sort entries by source position to match TS output order + // (JS Map preserves insertion order, which is AST traversal order = source + // order) + let mut sorted_entries: Vec<&ImportantLocation> = important_original.values().collect(); + sorted_entries.sort_by(|a, b| { + a.loc + .start + .line + .cmp(&b.loc.start.line) + .then(a.loc.start.column.cmp(&b.loc.start.column)) + // Outer nodes (larger spans) before inner nodes, matching depth-first traversal + 
.then(b.loc.end.line.cmp(&a.loc.end.line)) + .then(b.loc.end.column.cmp(&a.loc.end.column)) + }); + + for entry in &sorted_entries { + let generated_node_types = generated.get(&entry.key); + + if generated_node_types.is_none() { + // Location is completely missing + let mut node_types_str: Vec<&str> = entry.node_types.iter().copied().collect(); + node_types_str.sort(); + report_missing_location(env, &entry.loc, &node_types_str.join(", ")); + } else { + let generated_node_types = generated_node_types.unwrap(); + // Location exists, check each strict node type + for &node_type in &entry.node_types { + if strict_node_types.contains(node_type) + && !generated_node_types.contains(node_type) + { + // For strict node types, the specific node type must be present. + // Check if any generated node type is also an important original node type. + let has_valid_node_type = generated_node_types + .iter() + .any(|gen_type| entry.node_types.contains(gen_type.as_str())); + + if has_valid_node_type { + report_missing_location(env, &entry.loc, node_type); + } else { + report_wrong_node_type(env, &entry.loc, node_type, generated_node_types); + } + } + } + } + } +} + +// ---- Types ---- + +struct ImportantLocation { + key: String, + loc: AstSourceLocation, + node_types: HashSet<&'static str>, +} + +// ---- Location key ---- + +fn location_key(loc: &AstSourceLocation) -> String { + format!( + "{}:{}-{}:{}", + loc.start.line, loc.start.column, loc.end.line, loc.end.column + ) +} + +// ---- AST to diagnostics SourceLocation conversion ---- + +fn ast_to_diag_loc(loc: &AstSourceLocation) -> DiagSourceLocation { + DiagSourceLocation { + start: DiagPosition { + line: loc.start.line, + column: loc.start.column, + index: loc.start.index, + }, + end: DiagPosition { + line: loc.end.line, + column: loc.end.column, + index: loc.end.index, + }, + } +} + +// ---- Error reporting ---- + +fn report_missing_location(env: &mut Environment, loc: &AstSourceLocation, node_type: &str) { + let diag_loc = 
ast_to_diag_loc(loc); + env.record_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Todo, + "Important source location missing in generated code", + Some(format!( + "Source location for {} is missing in the generated output. This can cause \ + coverage instrumentation to fail to track this code properly, resulting in \ + inaccurate coverage reports.", + node_type + )), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: Some(diag_loc), + message: None, + identifier_name: None, + }), + ); +} + +fn report_wrong_node_type( + env: &mut Environment, + loc: &AstSourceLocation, + expected_type: &str, + actual_types: &HashSet, +) { + let diag_loc = ast_to_diag_loc(loc); + let mut actual: Vec<&str> = actual_types.iter().map(|s| s.as_str()).collect(); + actual.sort(); + env.record_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Todo, + "Important source location has wrong node type in generated code", + Some(format!( + "Source location for {} exists in the generated output but with wrong node \ + type(s): {}. This can cause coverage instrumentation to fail to track this code \ + properly, resulting in inaccurate coverage reports.", + expected_type, + actual.join(", ") + )), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: Some(diag_loc), + message: None, + identifier_name: None, + }), + ); +} + +// ---- Important type checking ---- + +/// Returns the Babel type name if this statement variant is an "important +/// instrumented type". 
+fn important_statement_type(stmt: &Statement) -> Option<&'static str> { + match stmt { + Statement::ExpressionStatement(_) => Some("ExpressionStatement"), + Statement::BreakStatement(_) => Some("BreakStatement"), + Statement::ContinueStatement(_) => Some("ContinueStatement"), + Statement::ReturnStatement(_) => Some("ReturnStatement"), + Statement::ThrowStatement(_) => Some("ThrowStatement"), + Statement::TryStatement(_) => Some("TryStatement"), + Statement::IfStatement(_) => Some("IfStatement"), + Statement::ForStatement(_) => Some("ForStatement"), + Statement::ForInStatement(_) => Some("ForInStatement"), + Statement::ForOfStatement(_) => Some("ForOfStatement"), + Statement::WhileStatement(_) => Some("WhileStatement"), + Statement::DoWhileStatement(_) => Some("DoWhileStatement"), + Statement::SwitchStatement(_) => Some("SwitchStatement"), + Statement::WithStatement(_) => Some("WithStatement"), + Statement::FunctionDeclaration(_) => Some("FunctionDeclaration"), + Statement::LabeledStatement(_) => Some("LabeledStatement"), + Statement::VariableDeclaration(_) => Some("VariableDeclaration"), + _ => None, + } +} + +/// Returns the Babel type name if this expression variant is an "important +/// instrumented type". 
+fn important_expression_type(expr: &Expression) -> Option<&'static str> { + match expr { + Expression::ArrowFunctionExpression(_) => Some("ArrowFunctionExpression"), + Expression::FunctionExpression(_) => Some("FunctionExpression"), + Expression::ConditionalExpression(_) => Some("ConditionalExpression"), + Expression::LogicalExpression(_) => Some("LogicalExpression"), + Expression::Identifier(_) => Some("Identifier"), + Expression::AssignmentPattern(_) => Some("AssignmentPattern"), + _ => None, + } +} + +// ---- Manual memoization check ---- + +fn is_manual_memoization(expr: &Expression) -> bool { + if let Expression::CallExpression(call) = expr { + match call.callee.as_ref() { + Expression::Identifier(id) => id.name == "useMemo" || id.name == "useCallback", + Expression::MemberExpression(mem) => { + if let (Expression::Identifier(obj), Expression::Identifier(prop)) = + (mem.object.as_ref(), &*mem.property) + { + obj.name == "React" && (prop.name == "useMemo" || prop.name == "useCallback") + } else { + false + } + } + _ => false, + } + } else { + false + } +} + +// ============================================================================ +// Step 1: Collect important original locations +// ============================================================================ + +fn collect_important_original_locations( + func: &FunctionNode<'_>, +) -> HashMap { + let mut locations = HashMap::new(); + + // Note: TS uses func.traverse() which visits DESCENDANTS only, not the root + // function node itself. So we don't record the root function as important. 
+ match func { + FunctionNode::FunctionDeclaration(f) => { + if let Some(id) = &f.id { + record_important("Identifier", &id.base.loc, &mut locations); + } + for param in &f.params { + collect_original_pattern(param, &mut locations); + } + collect_original_block(&f.body.body, false, &mut locations); + } + FunctionNode::FunctionExpression(f) => { + if let Some(id) = &f.id { + record_important("Identifier", &id.base.loc, &mut locations); + } + for param in &f.params { + collect_original_pattern(param, &mut locations); + } + collect_original_block(&f.body.body, false, &mut locations); + } + FunctionNode::ArrowFunctionExpression(f) => { + for param in &f.params { + collect_original_pattern(param, &mut locations); + } + match f.body.as_ref() { + ArrowFunctionBody::BlockStatement(block) => { + collect_original_block(&block.body, false, &mut locations); + } + ArrowFunctionBody::Expression(expr) => { + collect_original_expression(expr, &mut locations); + } + } + } + } + + locations +} + +fn record_important( + node_type: &'static str, + loc: &Option, + locations: &mut HashMap, +) { + if let Some(loc) = loc { + let key = location_key(loc); + if let Some(existing) = locations.get_mut(&key) { + existing.node_types.insert(node_type); + } else { + let mut node_types = HashSet::new(); + node_types.insert(node_type); + locations.insert( + key.clone(), + ImportantLocation { + key, + loc: loc.clone(), + node_types, + }, + ); + } + } +} + +fn collect_original_block( + stmts: &[Statement], + in_single_return_arrow: bool, + locations: &mut HashMap, +) { + for stmt in stmts { + collect_original_statement(stmt, in_single_return_arrow, locations); + } +} + +fn collect_original_statement( + stmt: &Statement, + in_single_return_arrow: bool, + locations: &mut HashMap, +) { + // Record this statement if it's an important type + if let Some(type_name) = important_statement_type(stmt) { + // Skip return statements inside arrow functions that will be simplified + // to expression body: () => { 
return expr } -> () => expr + if type_name == "ReturnStatement" && in_single_return_arrow { + if let Statement::ReturnStatement(ret) = stmt { + if ret.argument.is_some() { + // Skip recording, but still recurse into children + if let Some(arg) = &ret.argument { + collect_original_expression(arg, locations); + } + return; + } + } + } + + // Skip manual memoization + if type_name == "ExpressionStatement" { + if let Statement::ExpressionStatement(expr_stmt) = stmt { + if is_manual_memoization(&expr_stmt.expression) { + // Still recurse into children + collect_original_expression(&expr_stmt.expression, locations); + return; + } + } + } + + let base_loc = statement_loc(stmt); + record_important(type_name, base_loc, locations); + } + + // Recurse into children + match stmt { + Statement::BlockStatement(node) => { + collect_original_block(&node.body, false, locations); + } + Statement::ReturnStatement(node) => { + if let Some(arg) = &node.argument { + collect_original_expression(arg, locations); + } + } + Statement::ExpressionStatement(node) => { + collect_original_expression(&node.expression, locations); + } + Statement::IfStatement(node) => { + collect_original_expression(&node.test, locations); + collect_original_statement(&node.consequent, false, locations); + if let Some(alt) = &node.alternate { + collect_original_statement(alt, false, locations); + } + } + Statement::ForStatement(node) => { + if let Some(init) = &node.init { + match init.as_ref() { + ForInit::VariableDeclaration(decl) => { + collect_original_var_declaration(decl, locations); + } + ForInit::Expression(expr) => { + collect_original_expression(expr, locations); + } + } + } + if let Some(test) = &node.test { + collect_original_expression(test, locations); + } + if let Some(update) = &node.update { + collect_original_expression(update, locations); + } + collect_original_statement(&node.body, false, locations); + } + Statement::WhileStatement(node) => { + collect_original_expression(&node.test, 
locations); + collect_original_statement(&node.body, false, locations); + } + Statement::DoWhileStatement(node) => { + collect_original_statement(&node.body, false, locations); + collect_original_expression(&node.test, locations); + } + Statement::ForInStatement(node) => { + if let ForInOfLeft::Pattern(pat) = node.left.as_ref() { + collect_original_pattern(pat, locations); + } + collect_original_expression(&node.right, locations); + collect_original_statement(&node.body, false, locations); + } + Statement::ForOfStatement(node) => { + if let ForInOfLeft::Pattern(pat) = node.left.as_ref() { + collect_original_pattern(pat, locations); + } + collect_original_expression(&node.right, locations); + collect_original_statement(&node.body, false, locations); + } + Statement::SwitchStatement(node) => { + collect_original_expression(&node.discriminant, locations); + for case in &node.cases { + // SwitchCase is an important type + record_important("SwitchCase", &case.base.loc, locations); + if let Some(test) = &case.test { + collect_original_expression(test, locations); + } + collect_original_block(&case.consequent, false, locations); + } + } + Statement::ThrowStatement(node) => { + collect_original_expression(&node.argument, locations); + } + Statement::TryStatement(node) => { + collect_original_block(&node.block.body, false, locations); + if let Some(handler) = &node.handler { + if let Some(param) = &handler.param { + collect_original_pattern(param, locations); + } + collect_original_block(&handler.body.body, false, locations); + } + if let Some(finalizer) = &node.finalizer { + collect_original_block(&finalizer.body, false, locations); + } + } + Statement::LabeledStatement(node) => { + // Label identifier + record_important("Identifier", &node.label.base.loc, locations); + collect_original_statement(&node.body, false, locations); + } + Statement::VariableDeclaration(node) => { + collect_original_var_declaration(node, locations); + } + Statement::FunctionDeclaration(node) => { 
+ if let Some(id) = &node.id { + record_important("Identifier", &id.base.loc, locations); + } + for param in &node.params { + collect_original_pattern(param, locations); + } + collect_original_block(&node.body.body, false, locations); + } + Statement::WithStatement(node) => { + collect_original_expression(&node.object, locations); + collect_original_statement(&node.body, false, locations); + } + // Non-runtime statements: no children to recurse into + _ => {} + } +} + +fn collect_original_var_declaration( + decl: &VariableDeclaration, + locations: &mut HashMap, +) { + for declarator in &decl.declarations { + // VariableDeclarator is an important type + record_important("VariableDeclarator", &declarator.base.loc, locations); + collect_original_pattern(&declarator.id, locations); + if let Some(init) = &declarator.init { + collect_original_expression(init, locations); + } + } +} + +fn collect_original_expression( + expr: &Expression, + locations: &mut HashMap, +) { + // Record this expression if it's an important type + if let Some(type_name) = important_expression_type(expr) { + // Skip manual memoization + if !is_manual_memoization(expr) { + let base_loc = expression_loc(expr); + record_important(type_name, base_loc, locations); + } + } + + // Recurse into children + match expr { + Expression::Identifier(_) => { + // Already recorded above if important. No children. 
+ } + Expression::CallExpression(node) => { + collect_original_expression(&node.callee, locations); + for arg in &node.arguments { + collect_original_expression(arg, locations); + } + } + Expression::MemberExpression(node) => { + collect_original_expression(&node.object, locations); + if node.computed { + collect_original_expression(&node.property, locations); + } else { + // Non-computed property is an Identifier - record it + if let Expression::Identifier(id) = node.property.as_ref() { + record_important("Identifier", &id.base.loc, locations); + } + } + } + Expression::OptionalCallExpression(node) => { + collect_original_expression(&node.callee, locations); + for arg in &node.arguments { + collect_original_expression(arg, locations); + } + } + Expression::OptionalMemberExpression(node) => { + collect_original_expression(&node.object, locations); + if node.computed { + collect_original_expression(&node.property, locations); + } else if let Expression::Identifier(id) = node.property.as_ref() { + record_important("Identifier", &id.base.loc, locations); + } + } + Expression::BinaryExpression(node) => { + collect_original_expression(&node.left, locations); + collect_original_expression(&node.right, locations); + } + Expression::LogicalExpression(node) => { + collect_original_expression(&node.left, locations); + collect_original_expression(&node.right, locations); + } + Expression::UnaryExpression(node) => { + collect_original_expression(&node.argument, locations); + } + Expression::UpdateExpression(node) => { + collect_original_expression(&node.argument, locations); + } + Expression::ConditionalExpression(node) => { + collect_original_expression(&node.test, locations); + collect_original_expression(&node.consequent, locations); + collect_original_expression(&node.alternate, locations); + } + Expression::AssignmentExpression(node) => { + collect_original_pattern(&node.left, locations); + collect_original_expression(&node.right, locations); + } + 
Expression::SequenceExpression(node) => { + for e in &node.expressions { + collect_original_expression(e, locations); + } + } + Expression::ArrowFunctionExpression(node) => { + collect_original_arrow_children(node, locations); + } + Expression::FunctionExpression(node) => { + collect_original_fn_expr_children(node, locations); + } + Expression::ObjectExpression(node) => { + for prop in &node.properties { + match prop { + ObjectExpressionProperty::ObjectProperty(p) => { + if p.computed { + collect_original_expression(&p.key, locations); + } else if let Expression::Identifier(id) = p.key.as_ref() { + record_important("Identifier", &id.base.loc, locations); + } + collect_original_expression(&p.value, locations); + } + ObjectExpressionProperty::ObjectMethod(m) => { + // ObjectMethod is an important type + record_important("ObjectMethod", &m.base.loc, locations); + for param in &m.params { + collect_original_pattern(param, locations); + } + collect_original_block(&m.body.body, false, locations); + } + ObjectExpressionProperty::SpreadElement(s) => { + collect_original_expression(&s.argument, locations); + } + } + } + } + Expression::ArrayExpression(node) => { + for elem in node.elements.iter().flatten() { + collect_original_expression(elem, locations); + } + } + Expression::NewExpression(node) => { + collect_original_expression(&node.callee, locations); + for arg in &node.arguments { + collect_original_expression(arg, locations); + } + } + Expression::TemplateLiteral(node) => { + for e in &node.expressions { + collect_original_expression(e, locations); + } + } + Expression::TaggedTemplateExpression(node) => { + collect_original_expression(&node.tag, locations); + for e in &node.quasi.expressions { + collect_original_expression(e, locations); + } + } + Expression::AwaitExpression(node) => { + collect_original_expression(&node.argument, locations); + } + Expression::YieldExpression(node) => { + if let Some(arg) = &node.argument { + collect_original_expression(arg, 
locations); + } + } + Expression::SpreadElement(node) => { + collect_original_expression(&node.argument, locations); + } + Expression::ParenthesizedExpression(node) => { + collect_original_expression(&node.expression, locations); + } + Expression::AssignmentPattern(node) => { + collect_original_pattern(&node.left, locations); + collect_original_expression(&node.right, locations); + } + Expression::ClassExpression(node) => { + if let Some(sc) = &node.super_class { + collect_original_expression(sc, locations); + } + } + // TS/Flow wrappers — traverse inner expression + Expression::TSAsExpression(node) => { + collect_original_expression(&node.expression, locations); + } + Expression::TSSatisfiesExpression(node) => { + collect_original_expression(&node.expression, locations); + } + Expression::TSNonNullExpression(node) => { + collect_original_expression(&node.expression, locations); + } + Expression::TSTypeAssertion(node) => { + collect_original_expression(&node.expression, locations); + } + Expression::TSInstantiationExpression(node) => { + collect_original_expression(&node.expression, locations); + } + Expression::TypeCastExpression(node) => { + collect_original_expression(&node.expression, locations); + } + // Leaf nodes and JSX + _ => {} + } +} + +fn collect_original_arrow_children( + arrow: &ArrowFunctionExpression, + locations: &mut HashMap, +) { + for param in &arrow.params { + collect_original_pattern(param, locations); + } + match arrow.body.as_ref() { + ArrowFunctionBody::BlockStatement(block) => { + let is_single_return = block.body.len() == 1 && block.directives.is_empty(); + collect_original_block(&block.body, is_single_return, locations); + } + ArrowFunctionBody::Expression(expr) => { + collect_original_expression(expr, locations); + } + } +} + +fn collect_original_fn_expr_children( + func: &FunctionExpression, + locations: &mut HashMap, +) { + if let Some(id) = &func.id { + record_important("Identifier", &id.base.loc, locations); + } + for param in 
&func.params { + collect_original_pattern(param, locations); + } + collect_original_block(&func.body.body, false, locations); +} + +fn collect_original_pattern( + pattern: &PatternLike, + locations: &mut HashMap, +) { + match pattern { + PatternLike::Identifier(id) => { + record_important("Identifier", &id.base.loc, locations); + } + PatternLike::AssignmentPattern(ap) => { + record_important("AssignmentPattern", &ap.base.loc, locations); + collect_original_pattern(&ap.left, locations); + collect_original_expression(&ap.right, locations); + } + PatternLike::ObjectPattern(op) => { + for prop in &op.properties { + match prop { + react_compiler_ast::patterns::ObjectPatternProperty::ObjectProperty(p) => { + if p.computed { + collect_original_expression(&p.key, locations); + } else if let Expression::Identifier(id) = p.key.as_ref() { + record_important("Identifier", &id.base.loc, locations); + } + collect_original_pattern(&p.value, locations); + } + react_compiler_ast::patterns::ObjectPatternProperty::RestElement(r) => { + collect_original_pattern(&r.argument, locations); + } + } + } + } + PatternLike::ArrayPattern(ap) => { + for elem in ap.elements.iter().flatten() { + collect_original_pattern(elem, locations); + } + } + PatternLike::RestElement(r) => { + collect_original_pattern(&r.argument, locations); + } + PatternLike::MemberExpression(m) => { + collect_original_expression(&Expression::MemberExpression(m.clone()), locations); + } + } +} + +// ---- Helpers to get loc from statement/expression ---- + +fn statement_loc(stmt: &Statement) -> &Option { + match stmt { + Statement::BlockStatement(n) => &n.base.loc, + Statement::ReturnStatement(n) => &n.base.loc, + Statement::IfStatement(n) => &n.base.loc, + Statement::ForStatement(n) => &n.base.loc, + Statement::WhileStatement(n) => &n.base.loc, + Statement::DoWhileStatement(n) => &n.base.loc, + Statement::ForInStatement(n) => &n.base.loc, + Statement::ForOfStatement(n) => &n.base.loc, + Statement::SwitchStatement(n) => 
&n.base.loc, + Statement::ThrowStatement(n) => &n.base.loc, + Statement::TryStatement(n) => &n.base.loc, + Statement::BreakStatement(n) => &n.base.loc, + Statement::ContinueStatement(n) => &n.base.loc, + Statement::LabeledStatement(n) => &n.base.loc, + Statement::ExpressionStatement(n) => &n.base.loc, + Statement::EmptyStatement(n) => &n.base.loc, + Statement::DebuggerStatement(n) => &n.base.loc, + Statement::WithStatement(n) => &n.base.loc, + Statement::VariableDeclaration(n) => &n.base.loc, + Statement::FunctionDeclaration(n) => &n.base.loc, + Statement::ClassDeclaration(n) => &n.base.loc, + Statement::ImportDeclaration(n) => &n.base.loc, + Statement::ExportNamedDeclaration(n) => &n.base.loc, + Statement::ExportDefaultDeclaration(n) => &n.base.loc, + Statement::ExportAllDeclaration(n) => &n.base.loc, + Statement::TSTypeAliasDeclaration(n) => &n.base.loc, + Statement::TSInterfaceDeclaration(n) => &n.base.loc, + Statement::TSEnumDeclaration(n) => &n.base.loc, + Statement::TSModuleDeclaration(n) => &n.base.loc, + Statement::TSDeclareFunction(n) => &n.base.loc, + Statement::TypeAlias(n) => &n.base.loc, + Statement::OpaqueType(n) => &n.base.loc, + Statement::InterfaceDeclaration(n) => &n.base.loc, + Statement::DeclareVariable(n) => &n.base.loc, + Statement::DeclareFunction(n) => &n.base.loc, + Statement::DeclareClass(n) => &n.base.loc, + Statement::DeclareModule(n) => &n.base.loc, + Statement::DeclareModuleExports(n) => &n.base.loc, + Statement::DeclareExportDeclaration(n) => &n.base.loc, + Statement::DeclareExportAllDeclaration(n) => &n.base.loc, + Statement::DeclareInterface(n) => &n.base.loc, + Statement::DeclareTypeAlias(n) => &n.base.loc, + Statement::DeclareOpaqueType(n) => &n.base.loc, + Statement::EnumDeclaration(n) => &n.base.loc, + } +} + +fn expression_loc(expr: &Expression) -> &Option { + match expr { + Expression::Identifier(n) => &n.base.loc, + Expression::StringLiteral(n) => &n.base.loc, + Expression::NumericLiteral(n) => &n.base.loc, + 
Expression::BooleanLiteral(n) => &n.base.loc, + Expression::NullLiteral(n) => &n.base.loc, + Expression::BigIntLiteral(n) => &n.base.loc, + Expression::RegExpLiteral(n) => &n.base.loc, + Expression::CallExpression(n) => &n.base.loc, + Expression::MemberExpression(n) => &n.base.loc, + Expression::OptionalCallExpression(n) => &n.base.loc, + Expression::OptionalMemberExpression(n) => &n.base.loc, + Expression::BinaryExpression(n) => &n.base.loc, + Expression::LogicalExpression(n) => &n.base.loc, + Expression::UnaryExpression(n) => &n.base.loc, + Expression::UpdateExpression(n) => &n.base.loc, + Expression::ConditionalExpression(n) => &n.base.loc, + Expression::AssignmentExpression(n) => &n.base.loc, + Expression::SequenceExpression(n) => &n.base.loc, + Expression::ArrowFunctionExpression(n) => &n.base.loc, + Expression::FunctionExpression(n) => &n.base.loc, + Expression::ObjectExpression(n) => &n.base.loc, + Expression::ArrayExpression(n) => &n.base.loc, + Expression::NewExpression(n) => &n.base.loc, + Expression::TemplateLiteral(n) => &n.base.loc, + Expression::TaggedTemplateExpression(n) => &n.base.loc, + Expression::AwaitExpression(n) => &n.base.loc, + Expression::YieldExpression(n) => &n.base.loc, + Expression::SpreadElement(n) => &n.base.loc, + Expression::MetaProperty(n) => &n.base.loc, + Expression::ClassExpression(n) => &n.base.loc, + Expression::PrivateName(n) => &n.base.loc, + Expression::Super(n) => &n.base.loc, + Expression::Import(n) => &n.base.loc, + Expression::ThisExpression(n) => &n.base.loc, + Expression::ParenthesizedExpression(n) => &n.base.loc, + Expression::AssignmentPattern(n) => &n.base.loc, + Expression::JSXElement(n) => &n.base.loc, + Expression::JSXFragment(n) => &n.base.loc, + Expression::TSAsExpression(n) => &n.base.loc, + Expression::TSSatisfiesExpression(n) => &n.base.loc, + Expression::TSNonNullExpression(n) => &n.base.loc, + Expression::TSTypeAssertion(n) => &n.base.loc, + Expression::TSInstantiationExpression(n) => &n.base.loc, + 
Expression::TypeCastExpression(n) => &n.base.loc, + } +} + +// ============================================================================ +// Step 2: Collect generated locations (ALL node types, not just important ones) +// ============================================================================ + +fn collect_generated_from_block( + stmts: &[Statement], + locations: &mut HashMap>, +) { + for stmt in stmts { + collect_generated_statement(stmt, locations); + } +} + +fn record_generated( + type_name: &str, + loc: &Option, + locations: &mut HashMap>, +) { + if let Some(loc) = loc { + let key = location_key(loc); + locations + .entry(key) + .or_default() + .insert(type_name.to_string()); + } +} + +fn collect_generated_statement(stmt: &Statement, locations: &mut HashMap>) { + // Record this statement's location + let type_name = statement_type_name(stmt); + record_generated(type_name, statement_loc(stmt), locations); + + // Recurse into children (same structure as original, but record ALL types) + match stmt { + Statement::BlockStatement(node) => { + collect_generated_from_block(&node.body, locations); + } + Statement::ReturnStatement(node) => { + if let Some(arg) = &node.argument { + collect_generated_expression(arg, locations); + } + } + Statement::ExpressionStatement(node) => { + collect_generated_expression(&node.expression, locations); + } + Statement::IfStatement(node) => { + collect_generated_expression(&node.test, locations); + collect_generated_statement(&node.consequent, locations); + if let Some(alt) = &node.alternate { + collect_generated_statement(alt, locations); + } + } + Statement::ForStatement(node) => { + if let Some(init) = &node.init { + match init.as_ref() { + ForInit::VariableDeclaration(decl) => { + collect_generated_var_declaration(decl, locations); + } + ForInit::Expression(expr) => { + collect_generated_expression(expr, locations); + } + } + } + if let Some(test) = &node.test { + collect_generated_expression(test, locations); + } + if let 
Some(update) = &node.update { + collect_generated_expression(update, locations); + } + collect_generated_statement(&node.body, locations); + } + Statement::WhileStatement(node) => { + collect_generated_expression(&node.test, locations); + collect_generated_statement(&node.body, locations); + } + Statement::DoWhileStatement(node) => { + collect_generated_statement(&node.body, locations); + collect_generated_expression(&node.test, locations); + } + Statement::ForInStatement(node) => { + match node.left.as_ref() { + ForInOfLeft::VariableDeclaration(decl) => { + collect_generated_var_declaration(decl, locations); + } + ForInOfLeft::Pattern(pat) => { + collect_generated_pattern(pat, locations); + } + } + collect_generated_expression(&node.right, locations); + collect_generated_statement(&node.body, locations); + } + Statement::ForOfStatement(node) => { + match node.left.as_ref() { + ForInOfLeft::VariableDeclaration(decl) => { + collect_generated_var_declaration(decl, locations); + } + ForInOfLeft::Pattern(pat) => { + collect_generated_pattern(pat, locations); + } + } + collect_generated_expression(&node.right, locations); + collect_generated_statement(&node.body, locations); + } + Statement::SwitchStatement(node) => { + collect_generated_expression(&node.discriminant, locations); + for case in &node.cases { + record_generated("SwitchCase", &case.base.loc, locations); + if let Some(test) = &case.test { + collect_generated_expression(test, locations); + } + collect_generated_from_block(&case.consequent, locations); + } + } + Statement::ThrowStatement(node) => { + collect_generated_expression(&node.argument, locations); + } + Statement::TryStatement(node) => { + collect_generated_from_block(&node.block.body, locations); + if let Some(handler) = &node.handler { + if let Some(param) = &handler.param { + collect_generated_pattern(param, locations); + } + collect_generated_from_block(&handler.body.body, locations); + } + if let Some(finalizer) = &node.finalizer { + 
/// Records every declarator of a variable declaration and walks its parts.
///
/// For each declarator this records a `"VariableDeclarator"` entry at the
/// declarator's location, walks the binding pattern in `id`, and walks the
/// initializer expression when one is present. The declaration node itself
/// is not recorded here.
///
/// NOTE(review): the generic arguments of `HashMap` are missing in this copy
/// of the source; restore them from the upstream file before compiling.
fn collect_generated_var_declaration(
    decl: &VariableDeclaration,
    locations: &mut HashMap>,
) {
    for declarator in &decl.declarations {
        record_generated("VariableDeclarator", &declarator.base.loc, locations);
        collect_generated_pattern(&declarator.id, locations);
        // `init` is optional; a declarator without one has nothing more to walk.
        if let Some(init) = &declarator.init {
            collect_generated_expression(init, locations);
        }
    }
}
Expression::OptionalCallExpression(node) => { + collect_generated_expression(&node.callee, locations); + for arg in &node.arguments { + collect_generated_expression(arg, locations); + } + } + Expression::OptionalMemberExpression(node) => { + collect_generated_expression(&node.object, locations); + collect_generated_expression(&node.property, locations); + } + Expression::BinaryExpression(node) => { + collect_generated_expression(&node.left, locations); + collect_generated_expression(&node.right, locations); + } + Expression::LogicalExpression(node) => { + collect_generated_expression(&node.left, locations); + collect_generated_expression(&node.right, locations); + } + Expression::UnaryExpression(node) => { + collect_generated_expression(&node.argument, locations); + } + Expression::UpdateExpression(node) => { + collect_generated_expression(&node.argument, locations); + } + Expression::ConditionalExpression(node) => { + collect_generated_expression(&node.test, locations); + collect_generated_expression(&node.consequent, locations); + collect_generated_expression(&node.alternate, locations); + } + Expression::AssignmentExpression(node) => { + collect_generated_pattern(&node.left, locations); + collect_generated_expression(&node.right, locations); + } + Expression::SequenceExpression(node) => { + for e in &node.expressions { + collect_generated_expression(e, locations); + } + } + Expression::ArrowFunctionExpression(node) => { + for param in &node.params { + collect_generated_pattern(param, locations); + } + match node.body.as_ref() { + ArrowFunctionBody::BlockStatement(block) => { + collect_generated_from_block(&block.body, locations); + } + ArrowFunctionBody::Expression(e) => { + collect_generated_expression(e, locations); + } + } + } + Expression::FunctionExpression(node) => { + if let Some(id) = &node.id { + record_generated("Identifier", &id.base.loc, locations); + } + for param in &node.params { + collect_generated_pattern(param, locations); + } + 
collect_generated_from_block(&node.body.body, locations); + } + Expression::ObjectExpression(node) => { + for prop in &node.properties { + match prop { + ObjectExpressionProperty::ObjectProperty(p) => { + collect_generated_expression(&p.key, locations); + collect_generated_expression(&p.value, locations); + } + ObjectExpressionProperty::ObjectMethod(m) => { + record_generated("ObjectMethod", &m.base.loc, locations); + for param in &m.params { + collect_generated_pattern(param, locations); + } + collect_generated_from_block(&m.body.body, locations); + } + ObjectExpressionProperty::SpreadElement(s) => { + collect_generated_expression(&s.argument, locations); + } + } + } + } + Expression::ArrayExpression(node) => { + for elem in node.elements.iter().flatten() { + collect_generated_expression(elem, locations); + } + } + Expression::NewExpression(node) => { + collect_generated_expression(&node.callee, locations); + for arg in &node.arguments { + collect_generated_expression(arg, locations); + } + } + Expression::TemplateLiteral(node) => { + for e in &node.expressions { + collect_generated_expression(e, locations); + } + } + Expression::TaggedTemplateExpression(node) => { + collect_generated_expression(&node.tag, locations); + for e in &node.quasi.expressions { + collect_generated_expression(e, locations); + } + } + Expression::AwaitExpression(node) => { + collect_generated_expression(&node.argument, locations); + } + Expression::YieldExpression(node) => { + if let Some(arg) = &node.argument { + collect_generated_expression(arg, locations); + } + } + Expression::SpreadElement(node) => { + collect_generated_expression(&node.argument, locations); + } + Expression::ParenthesizedExpression(node) => { + collect_generated_expression(&node.expression, locations); + } + Expression::AssignmentPattern(node) => { + collect_generated_pattern(&node.left, locations); + collect_generated_expression(&node.right, locations); + } + Expression::ClassExpression(node) => { + if let Some(sc) 
/// Records the location of a pattern node and of everything nested inside it.
///
/// Each pattern kind is recorded under its own node type name. Nested
/// patterns are walked recursively through this function; embedded
/// expressions (the `right` side of an assignment pattern, object property
/// keys, the parts of a member expression) are walked with
/// `collect_generated_expression`.
///
/// NOTE(review): the generic arguments of `HashMap` are missing in this copy
/// of the source; restore them from the upstream file before compiling.
fn collect_generated_pattern(
    pattern: &PatternLike,
    locations: &mut HashMap>,
) {
    match pattern {
        // Leaf: nothing nested below an identifier.
        PatternLike::Identifier(id) => {
            record_generated("Identifier", &id.base.loc, locations);
        }
        PatternLike::AssignmentPattern(ap) => {
            record_generated("AssignmentPattern", &ap.base.loc, locations);
            // `left` is itself a pattern, `right` is an expression.
            collect_generated_pattern(&ap.left, locations);
            collect_generated_expression(&ap.right, locations);
        }
        PatternLike::ObjectPattern(op) => {
            record_generated("ObjectPattern", &op.base.loc, locations);
            for prop in &op.properties {
                match prop {
                    react_compiler_ast::patterns::ObjectPatternProperty::ObjectProperty(p) => {
                        record_generated("ObjectProperty", &p.base.loc, locations);
                        // The key is an expression; the value is the nested pattern.
                        collect_generated_expression(&p.key, locations);
                        collect_generated_pattern(&p.value, locations);
                    }
                    react_compiler_ast::patterns::ObjectPatternProperty::RestElement(r) => {
                        record_generated("RestElement", &r.base.loc, locations);
                        collect_generated_pattern(&r.argument, locations);
                    }
                }
            }
        }
        PatternLike::ArrayPattern(ap) => {
            record_generated("ArrayPattern", &ap.base.loc, locations);
            // `flatten()` skips the `None` slots of the element list.
            for elem in ap.elements.iter().flatten() {
                collect_generated_pattern(elem, locations);
            }
        }
        PatternLike::RestElement(r) => {
            record_generated("RestElement", &r.base.loc, locations);
            collect_generated_pattern(&r.argument, locations);
        }
        PatternLike::MemberExpression(m) => {
            record_generated("MemberExpression", &m.base.loc, locations);
            // Both parts are expressions and record themselves when walked.
            collect_generated_expression(&m.object, locations);
            collect_generated_expression(&m.property, locations);
        }
    }
}
"TSInterfaceDeclaration", + Statement::TSEnumDeclaration(_) => "TSEnumDeclaration", + Statement::TSModuleDeclaration(_) => "TSModuleDeclaration", + Statement::TSDeclareFunction(_) => "TSDeclareFunction", + Statement::TypeAlias(_) => "TypeAlias", + Statement::OpaqueType(_) => "OpaqueType", + Statement::InterfaceDeclaration(_) => "InterfaceDeclaration", + Statement::DeclareVariable(_) => "DeclareVariable", + Statement::DeclareFunction(_) => "DeclareFunction", + Statement::DeclareClass(_) => "DeclareClass", + Statement::DeclareModule(_) => "DeclareModule", + Statement::DeclareModuleExports(_) => "DeclareModuleExports", + Statement::DeclareExportDeclaration(_) => "DeclareExportDeclaration", + Statement::DeclareExportAllDeclaration(_) => "DeclareExportAllDeclaration", + Statement::DeclareInterface(_) => "DeclareInterface", + Statement::DeclareTypeAlias(_) => "DeclareTypeAlias", + Statement::DeclareOpaqueType(_) => "DeclareOpaqueType", + Statement::EnumDeclaration(_) => "EnumDeclaration", + } +} + +fn expression_type_name(expr: &Expression) -> &'static str { + match expr { + Expression::Identifier(_) => "Identifier", + Expression::StringLiteral(_) => "StringLiteral", + Expression::NumericLiteral(_) => "NumericLiteral", + Expression::BooleanLiteral(_) => "BooleanLiteral", + Expression::NullLiteral(_) => "NullLiteral", + Expression::BigIntLiteral(_) => "BigIntLiteral", + Expression::RegExpLiteral(_) => "RegExpLiteral", + Expression::CallExpression(_) => "CallExpression", + Expression::MemberExpression(_) => "MemberExpression", + Expression::OptionalCallExpression(_) => "OptionalCallExpression", + Expression::OptionalMemberExpression(_) => "OptionalMemberExpression", + Expression::BinaryExpression(_) => "BinaryExpression", + Expression::LogicalExpression(_) => "LogicalExpression", + Expression::UnaryExpression(_) => "UnaryExpression", + Expression::UpdateExpression(_) => "UpdateExpression", + Expression::ConditionalExpression(_) => "ConditionalExpression", + 
Expression::AssignmentExpression(_) => "AssignmentExpression", + Expression::SequenceExpression(_) => "SequenceExpression", + Expression::ArrowFunctionExpression(_) => "ArrowFunctionExpression", + Expression::FunctionExpression(_) => "FunctionExpression", + Expression::ObjectExpression(_) => "ObjectExpression", + Expression::ArrayExpression(_) => "ArrayExpression", + Expression::NewExpression(_) => "NewExpression", + Expression::TemplateLiteral(_) => "TemplateLiteral", + Expression::TaggedTemplateExpression(_) => "TaggedTemplateExpression", + Expression::AwaitExpression(_) => "AwaitExpression", + Expression::YieldExpression(_) => "YieldExpression", + Expression::SpreadElement(_) => "SpreadElement", + Expression::MetaProperty(_) => "MetaProperty", + Expression::ClassExpression(_) => "ClassExpression", + Expression::PrivateName(_) => "PrivateName", + Expression::Super(_) => "Super", + Expression::Import(_) => "Import", + Expression::ThisExpression(_) => "ThisExpression", + Expression::ParenthesizedExpression(_) => "ParenthesizedExpression", + Expression::AssignmentPattern(_) => "AssignmentPattern", + Expression::JSXElement(_) => "JSXElement", + Expression::JSXFragment(_) => "JSXFragment", + Expression::TSAsExpression(_) => "TSAsExpression", + Expression::TSSatisfiesExpression(_) => "TSSatisfiesExpression", + Expression::TSNonNullExpression(_) => "TSNonNullExpression", + Expression::TSTypeAssertion(_) => "TSTypeAssertion", + Expression::TSInstantiationExpression(_) => "TSInstantiationExpression", + Expression::TypeCastExpression(_) => "TypeCastExpression", + } +} diff --git a/crates/react_compiler/src/fixture_utils.rs b/crates/react_compiler/src/fixture_utils.rs new file mode 100644 index 000000000000..136d7e773f8c --- /dev/null +++ b/crates/react_compiler/src/fixture_utils.rs @@ -0,0 +1,244 @@ +use react_compiler_ast::{ + declarations::{Declaration, ExportDefaultDecl}, + expressions::Expression, + statements::Statement, + File, +}; +use 
react_compiler_lowering::FunctionNode; + +/// Count the number of top-level functions in an AST file. +/// +/// "Top-level" means: +/// - FunctionDeclaration at program body level +/// - FunctionExpression/ArrowFunctionExpression in a VariableDeclarator at +/// program body level +/// - FunctionDeclaration inside ExportNamedDeclaration +/// - FunctionDeclaration/FunctionExpression/ArrowFunctionExpression inside +/// ExportDefaultDeclaration +/// - VariableDeclaration with function expressions inside +/// ExportNamedDeclaration +/// +/// This matches the TS test binary's traversal behavior. +pub fn count_top_level_functions(ast: &File) -> usize { + let mut count = 0; + for stmt in &ast.program.body { + count += count_functions_in_statement(stmt); + } + count +} + +fn count_functions_in_statement(stmt: &Statement) -> usize { + match stmt { + Statement::FunctionDeclaration(_) => 1, + Statement::VariableDeclaration(var_decl) => { + let mut count = 0; + for declarator in &var_decl.declarations { + if let Some(init) = &declarator.init { + if is_function_expression(init) { + count += 1; + } + } + } + count + } + Statement::ExportNamedDeclaration(export) => { + if let Some(decl) = &export.declaration { + match decl.as_ref() { + Declaration::FunctionDeclaration(_) => 1, + Declaration::VariableDeclaration(var_decl) => { + let mut count = 0; + for declarator in &var_decl.declarations { + if let Some(init) = &declarator.init { + if is_function_expression(init) { + count += 1; + } + } + } + count + } + _ => 0, + } + } else { + 0 + } + } + Statement::ExportDefaultDeclaration(export) => match export.declaration.as_ref() { + ExportDefaultDecl::FunctionDeclaration(_) => 1, + ExportDefaultDecl::Expression(expr) => { + if is_function_expression(expr) { + 1 + } else { + 0 + } + } + _ => 0, + }, + // Expression statements with function expressions (uncommon but possible) + Statement::ExpressionStatement(expr_stmt) => { + if is_function_expression(&expr_stmt.expression) { + 1 + } else { 
+ 0 + } + } + _ => 0, + } +} + +fn is_function_expression(expr: &Expression) -> bool { + matches!( + expr, + Expression::FunctionExpression(_) | Expression::ArrowFunctionExpression(_) + ) +} + +/// Extract the nth top-level function from an AST file as a `FunctionNode`. +/// Also returns the inferred name (e.g. from a variable declarator). +/// Returns None if function_index is out of bounds. +pub fn extract_function( + ast: &File, + function_index: usize, +) -> Option<(FunctionNode<'_>, Option<&str>)> { + let mut index = 0usize; + + for stmt in &ast.program.body { + match stmt { + Statement::FunctionDeclaration(func_decl) => { + if index == function_index { + let name = func_decl.id.as_ref().map(|id| id.name.as_str()); + return Some((FunctionNode::FunctionDeclaration(func_decl), name)); + } + index += 1; + } + Statement::VariableDeclaration(var_decl) => { + for declarator in &var_decl.declarations { + if let Some(init) = &declarator.init { + match init.as_ref() { + Expression::FunctionExpression(func) => { + if index == function_index { + let name = match &declarator.id { + react_compiler_ast::patterns::PatternLike::Identifier( + ident, + ) => Some(ident.name.as_str()), + _ => func.id.as_ref().map(|id| id.name.as_str()), + }; + return Some((FunctionNode::FunctionExpression(func), name)); + } + index += 1; + } + Expression::ArrowFunctionExpression(arrow) => { + if index == function_index { + let name = match &declarator.id { + react_compiler_ast::patterns::PatternLike::Identifier( + ident, + ) => Some(ident.name.as_str()), + _ => None, + }; + return Some(( + FunctionNode::ArrowFunctionExpression(arrow), + name, + )); + } + index += 1; + } + _ => {} + } + } + } + } + Statement::ExportNamedDeclaration(export) => { + if let Some(decl) = &export.declaration { + match decl.as_ref() { + Declaration::FunctionDeclaration(func_decl) => { + if index == function_index { + let name = func_decl.id.as_ref().map(|id| id.name.as_str()); + return 
Some((FunctionNode::FunctionDeclaration(func_decl), name)); + } + index += 1; + } + Declaration::VariableDeclaration(var_decl) => { + for declarator in &var_decl.declarations { + if let Some(init) = &declarator.init { + match init.as_ref() { + Expression::FunctionExpression(func) => { + if index == function_index { + let name = match &declarator.id { + react_compiler_ast::patterns::PatternLike::Identifier(ident) => Some(ident.name.as_str()), + _ => func.id.as_ref().map(|id| id.name.as_str()), + }; + return Some(( + FunctionNode::FunctionExpression(func), + name, + )); + } + index += 1; + } + Expression::ArrowFunctionExpression(arrow) => { + if index == function_index { + let name = match &declarator.id { + react_compiler_ast::patterns::PatternLike::Identifier(ident) => Some(ident.name.as_str()), + _ => None, + }; + return Some(( + FunctionNode::ArrowFunctionExpression(arrow), + name, + )); + } + index += 1; + } + _ => {} + } + } + } + } + _ => {} + } + } + } + Statement::ExportDefaultDeclaration(export) => match export.declaration.as_ref() { + ExportDefaultDecl::FunctionDeclaration(func_decl) => { + if index == function_index { + let name = func_decl.id.as_ref().map(|id| id.name.as_str()); + return Some((FunctionNode::FunctionDeclaration(func_decl), name)); + } + index += 1; + } + ExportDefaultDecl::Expression(expr) => match expr.as_ref() { + Expression::FunctionExpression(func) => { + if index == function_index { + let name = func.id.as_ref().map(|id| id.name.as_str()); + return Some((FunctionNode::FunctionExpression(func), name)); + } + index += 1; + } + Expression::ArrowFunctionExpression(arrow) => { + if index == function_index { + return Some((FunctionNode::ArrowFunctionExpression(arrow), None)); + } + index += 1; + } + _ => {} + }, + _ => {} + }, + Statement::ExpressionStatement(expr_stmt) => match expr_stmt.expression.as_ref() { + Expression::FunctionExpression(func) => { + if index == function_index { + let name = func.id.as_ref().map(|id| 
id.name.as_str()); + return Some((FunctionNode::FunctionExpression(func), name)); + } + index += 1; + } + Expression::ArrowFunctionExpression(arrow) => { + if index == function_index { + return Some((FunctionNode::ArrowFunctionExpression(arrow), None)); + } + index += 1; + } + _ => {} + }, + _ => {} + } + } + None +} diff --git a/crates/react_compiler/src/lib.rs b/crates/react_compiler/src/lib.rs new file mode 100644 index 000000000000..9baa40d09e2a --- /dev/null +++ b/crates/react_compiler/src/lib.rs @@ -0,0 +1,14 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. +#![allow(clippy::all)] + +pub mod debug_print; +pub mod entrypoint; +pub mod fixture_utils; +pub mod timing; + +// Re-export from new crates for backwards compatibility +pub use react_compiler_diagnostics; +pub use react_compiler_hir as hir; +pub use react_compiler_hir::{self, environment}; +pub use react_compiler_lowering::lower; diff --git a/crates/react_compiler/src/timing.rs b/crates/react_compiler/src/timing.rs new file mode 100644 index 000000000000..5228c3021050 --- /dev/null +++ b/crates/react_compiler/src/timing.rs @@ -0,0 +1,77 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Simple timing accumulator for profiling compiler passes. +//! +//! Uses `std::time::Instant` unconditionally (cheap when not storing results). +//! Controlled by the `__profiling` flag in plugin options. + +use std::time::{Duration, Instant}; + +use serde::Serialize; + +/// A single timing entry recording how long a named phase took. +#[derive(Debug, Clone, Serialize)] +pub struct TimingEntry { + pub name: String, + pub duration_us: u64, +} + +/// Accumulates timing data for compiler passes. 
+pub struct TimingData { + enabled: bool, + entries: Vec<(String, Duration)>, + current_name: Option, + current_start: Option, +} + +impl TimingData { + /// Create a new TimingData. If `enabled` is false, all operations are + /// no-ops. + pub fn new(enabled: bool) -> Self { + Self { + enabled, + entries: Vec::new(), + current_name: None, + current_start: None, + } + } + + /// Start timing a named phase. Stops any currently running phase first. + pub fn start(&mut self, name: &str) { + if !self.enabled { + return; + } + // Stop any currently running phase + if self.current_start.is_some() { + self.stop(); + } + self.current_name = Some(name.to_string()); + self.current_start = Some(Instant::now()); + } + + /// Stop the currently running phase and record its duration. + pub fn stop(&mut self) { + if !self.enabled { + return; + } + if let (Some(name), Some(start)) = (self.current_name.take(), self.current_start.take()) { + self.entries.push((name, start.elapsed())); + } + } + + /// Consume this TimingData and return the collected entries. 
+ pub fn into_entries(mut self) -> Vec { + // Stop any still-running phase + self.stop(); + self.entries + .into_iter() + .map(|(name, duration)| TimingEntry { + name, + duration_us: duration.as_micros() as u64, + }) + .collect() + } +} diff --git a/crates/react_compiler_ast/Cargo.toml b/crates/react_compiler_ast/Cargo.toml new file mode 100644 index 000000000000..0ce8e04ba62d --- /dev/null +++ b/crates/react_compiler_ast/Cargo.toml @@ -0,0 +1,16 @@ +[package] +description = "Vendored React Compiler AST from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_ast" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +indexmap = { workspace = true, features = ["serde"] } + +[dev-dependencies] +walkdir = "2" +similar = "2" diff --git a/crates/react_compiler_ast/src/common.rs b/crates/react_compiler_ast/src/common.rs new file mode 100644 index 000000000000..349a21dfcf9b --- /dev/null +++ b/crates/react_compiler_ast/src/common.rs @@ -0,0 +1,109 @@ +use serde::{Deserialize, Serialize}; + +/// Custom deserializer that distinguishes "field absent" from "field: null". 
/// A source span made of a `start` and an `end` position plus optional
/// metadata.
///
/// `filename` and `identifier_name` default to `None` when the field is
/// absent from the input and are left out of the serialized output when
/// `None`. `identifier_name` is read from and written to the JSON key
/// `"identifierName"`.
///
/// NOTE(review): the type arguments of the two `Option` fields are missing in
/// this copy of the source; restore them from the upstream file before
/// compiling.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceLocation {
    pub start: Position,
    pub end: Position,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub filename: Option,
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        rename = "identifierName"
    )]
    pub identifier_name: Option,
}
When deserialized + /// directly, this captures the "type" field for round-trip fidelity. + #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")] + pub node_type: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub start: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub end: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub loc: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub range: Option<(u32, u32)>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub extra: Option, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "leadingComments" + )] + pub leading_comments: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "innerComments" + )] + pub inner_comments: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "trailingComments" + )] + pub trailing_comments: Option>, +} + +impl BaseNode { + /// Create a BaseNode with the given type name. + /// Use this when creating AST nodes for code generation to ensure the + /// `"type"` field is present in serialized output. 
+ pub fn typed(type_name: &str) -> Self { + Self { + node_type: Some(type_name.to_string()), + ..Default::default() + } + } +} diff --git a/crates/react_compiler_ast/src/declarations.rs b/crates/react_compiler_ast/src/declarations.rs new file mode 100644 index 000000000000..f4361a2934d4 --- /dev/null +++ b/crates/react_compiler_ast/src/declarations.rs @@ -0,0 +1,464 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + common::BaseNode, + expressions::{Expression, Identifier}, + literals::StringLiteral, +}; + +/// Union of Declaration types that can appear in export declarations +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum Declaration { + FunctionDeclaration(crate::statements::FunctionDeclaration), + ClassDeclaration(crate::statements::ClassDeclaration), + VariableDeclaration(crate::statements::VariableDeclaration), + TSTypeAliasDeclaration(TSTypeAliasDeclaration), + TSInterfaceDeclaration(TSInterfaceDeclaration), + TSEnumDeclaration(TSEnumDeclaration), + TSModuleDeclaration(TSModuleDeclaration), + TSDeclareFunction(TSDeclareFunction), + TypeAlias(TypeAlias), + OpaqueType(OpaqueType), + InterfaceDeclaration(InterfaceDeclaration), + EnumDeclaration(EnumDeclaration), +} + +/// The declaration/expression that can appear in `export default ` +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ExportDefaultDecl { + FunctionDeclaration(crate::statements::FunctionDeclaration), + ClassDeclaration(crate::statements::ClassDeclaration), + #[serde(untagged)] + Expression(Box), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImportDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub specifiers: Vec, + pub source: StringLiteral, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "importKind" + )] + pub import_kind: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub assertions: Option>, + #[serde(default, skip_serializing_if = 
/// The kind recorded on an import.
///
/// Because of `rename_all = "lowercase"`, the variants are serialized and
/// deserialized as the strings `"value"`, `"type"` and `"typeof"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImportKind {
    Value,
    Type,
    Typeof,
}
Value, + Type, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ExportSpecifier { + ExportSpecifier(ExportSpecifierData), + ExportDefaultSpecifier(ExportDefaultSpecifierData), + ExportNamespaceSpecifier(ExportNamespaceSpecifierData), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExportSpecifierData { + #[serde(flatten)] + pub base: BaseNode, + pub local: ModuleExportName, + pub exported: ModuleExportName, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "exportKind" + )] + pub export_kind: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExportDefaultSpecifierData { + #[serde(flatten)] + pub base: BaseNode, + pub exported: Identifier, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExportNamespaceSpecifierData { + #[serde(flatten)] + pub base: BaseNode, + pub exported: ModuleExportName, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExportDefaultDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub declaration: Box, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "exportKind" + )] + pub export_kind: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExportAllDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub source: StringLiteral, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "exportKind" + )] + pub export_kind: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub assertions: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub attributes: Option>, +} + +// TypeScript declarations (pass-through via serde_json::Value for bodies) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSTypeAliasDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + #[serde(rename = "typeAnnotation")] + pub type_annotation: Box, + #[serde( + default, + 
skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declare: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSInterfaceDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + pub body: Box, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub extends: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declare: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSEnumDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + pub members: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declare: Option, + #[serde(default, skip_serializing_if = "Option::is_none", rename = "const")] + pub is_const: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSModuleDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub id: Box, + pub body: Box, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declare: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub global: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSDeclareFunction { + #[serde(flatten)] + pub base: BaseNode, + pub id: Option, + pub params: Vec, + #[serde(default, skip_serializing_if = "Option::is_none", rename = "async")] + pub is_async: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declare: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub generator: Option, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "returnType" + )] + pub return_type: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" 
+ )] + pub type_parameters: Option>, +} + +// Flow declarations (pass-through) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypeAlias { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + pub right: Box, + #[serde(default, rename = "typeParameters")] + pub type_parameters: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpaqueType { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + #[serde(rename = "supertype")] + pub supertype: Option>, + pub impltype: Box, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InterfaceDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + pub body: Box, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub extends: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub mixins: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub implements: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareVariable { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareFunction { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub predicate: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareClass { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + pub body: Box, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub extends: Option>, + 
#[serde(default, skip_serializing_if = "Option::is_none")] + pub mixins: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub implements: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareModule { + #[serde(flatten)] + pub base: BaseNode, + pub id: Box, + pub body: Box, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub kind: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareModuleExports { + #[serde(flatten)] + pub base: BaseNode, + #[serde(rename = "typeAnnotation")] + pub type_annotation: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareExportDeclaration { + #[serde(flatten)] + pub base: BaseNode, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declaration: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub specifiers: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub source: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub default: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareExportAllDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub source: StringLiteral, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareInterface { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + pub body: Box, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub extends: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub mixins: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub implements: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareTypeAlias { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + pub right: Box, + #[serde( + default, + 
skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DeclareOpaqueType { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub supertype: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub impltype: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EnumDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, + pub body: Box, +} diff --git a/crates/react_compiler_ast/src/expressions.rs b/crates/react_compiler_ast/src/expressions.rs new file mode 100644 index 000000000000..e822b412ef99 --- /dev/null +++ b/crates/react_compiler_ast/src/expressions.rs @@ -0,0 +1,542 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + common::BaseNode, + jsx::{JSXElement, JSXFragment}, + literals::*, + operators::*, + patterns::{AssignmentPattern, PatternLike}, + statements::BlockStatement, +}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Identifier { + #[serde(flatten)] + pub base: BaseNode, + pub name: String, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeAnnotation" + )] + pub type_annotation: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub optional: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decorators: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum Expression { + Identifier(Identifier), + StringLiteral(StringLiteral), + NumericLiteral(NumericLiteral), + BooleanLiteral(BooleanLiteral), + NullLiteral(NullLiteral), + BigIntLiteral(BigIntLiteral), + RegExpLiteral(RegExpLiteral), + CallExpression(CallExpression), + 
MemberExpression(MemberExpression), + OptionalCallExpression(OptionalCallExpression), + OptionalMemberExpression(OptionalMemberExpression), + BinaryExpression(BinaryExpression), + LogicalExpression(LogicalExpression), + UnaryExpression(UnaryExpression), + UpdateExpression(UpdateExpression), + ConditionalExpression(ConditionalExpression), + AssignmentExpression(AssignmentExpression), + SequenceExpression(SequenceExpression), + ArrowFunctionExpression(ArrowFunctionExpression), + FunctionExpression(FunctionExpression), + ObjectExpression(ObjectExpression), + ArrayExpression(ArrayExpression), + NewExpression(NewExpression), + TemplateLiteral(TemplateLiteral), + TaggedTemplateExpression(TaggedTemplateExpression), + AwaitExpression(AwaitExpression), + YieldExpression(YieldExpression), + SpreadElement(SpreadElement), + MetaProperty(MetaProperty), + ClassExpression(ClassExpression), + PrivateName(PrivateName), + Super(Super), + Import(Import), + ThisExpression(ThisExpression), + ParenthesizedExpression(ParenthesizedExpression), + // JSX expressions + JSXElement(Box), + JSXFragment(JSXFragment), + // Pattern (can appear in expression position in error recovery) + AssignmentPattern(AssignmentPattern), + // TypeScript expressions + TSAsExpression(TSAsExpression), + TSSatisfiesExpression(TSSatisfiesExpression), + TSNonNullExpression(TSNonNullExpression), + TSTypeAssertion(TSTypeAssertion), + TSInstantiationExpression(TSInstantiationExpression), + // Flow expressions + TypeCastExpression(TypeCastExpression), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CallExpression { + #[serde(flatten)] + pub base: BaseNode, + pub callee: Box, + pub arguments: Vec, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeArguments" + )] + pub type_arguments: Option>, + #[serde(default, skip_serializing_if = 
"Option::is_none")] + pub optional: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemberExpression { + #[serde(flatten)] + pub base: BaseNode, + pub object: Box, + pub property: Box, + pub computed: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OptionalCallExpression { + #[serde(flatten)] + pub base: BaseNode, + pub callee: Box, + pub arguments: Vec, + pub optional: bool, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeArguments" + )] + pub type_arguments: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OptionalMemberExpression { + #[serde(flatten)] + pub base: BaseNode, + pub object: Box, + pub property: Box, + pub computed: bool, + pub optional: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BinaryExpression { + #[serde(flatten)] + pub base: BaseNode, + pub operator: BinaryOperator, + pub left: Box, + pub right: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LogicalExpression { + #[serde(flatten)] + pub base: BaseNode, + pub operator: LogicalOperator, + pub left: Box, + pub right: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UnaryExpression { + #[serde(flatten)] + pub base: BaseNode, + pub operator: UnaryOperator, + pub prefix: bool, + pub argument: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UpdateExpression { + #[serde(flatten)] + pub base: BaseNode, + pub operator: UpdateOperator, + pub argument: Box, + pub prefix: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConditionalExpression { + #[serde(flatten)] + pub base: BaseNode, + pub test: Box, + pub consequent: Box, + pub alternate: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AssignmentExpression { + #[serde(flatten)] + 
pub base: BaseNode, + pub operator: AssignmentOperator, + pub left: Box, + pub right: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SequenceExpression { + #[serde(flatten)] + pub base: BaseNode, + pub expressions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArrowFunctionExpression { + #[serde(flatten)] + pub base: BaseNode, + pub params: Vec, + pub body: Box, + #[serde(default)] + pub id: Option, + #[serde(default)] + pub generator: bool, + #[serde(default, rename = "async")] + pub is_async: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub expression: Option, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "returnType" + )] + pub return_type: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde(default, skip_serializing_if = "Option::is_none", rename = "predicate")] + pub predicate: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ArrowFunctionBody { + BlockStatement(BlockStatement), + #[serde(untagged)] + Expression(Box), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionExpression { + #[serde(flatten)] + pub base: BaseNode, + pub params: Vec, + pub body: BlockStatement, + #[serde(default)] + pub id: Option, + #[serde(default)] + pub generator: bool, + #[serde(default, rename = "async")] + pub is_async: bool, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "returnType" + )] + pub return_type: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ObjectExpression { + #[serde(flatten)] + pub base: BaseNode, + pub properties: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum 
ObjectExpressionProperty { + ObjectProperty(ObjectProperty), + ObjectMethod(ObjectMethod), + SpreadElement(SpreadElement), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ObjectProperty { + #[serde(flatten)] + pub base: BaseNode, + pub key: Box, + pub value: Box, + pub computed: bool, + pub shorthand: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decorators: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub method: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ObjectMethod { + #[serde(flatten)] + pub base: BaseNode, + pub method: bool, + pub kind: ObjectMethodKind, + pub key: Box, + pub params: Vec, + pub body: BlockStatement, + pub computed: bool, + #[serde(default)] + pub id: Option, + #[serde(default)] + pub generator: bool, + #[serde(default, rename = "async")] + pub is_async: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decorators: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "returnType" + )] + pub return_type: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ObjectMethodKind { + Method, + Get, + Set, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArrayExpression { + #[serde(flatten)] + pub base: BaseNode, + pub elements: Vec>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NewExpression { + #[serde(flatten)] + pub base: BaseNode, + pub callee: Box, + pub arguments: Vec, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + deserialize_with = "crate::common::nullable_value", + rename = "typeArguments" + )] + pub type_arguments: 
Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TemplateLiteral { + #[serde(flatten)] + pub base: BaseNode, + pub quasis: Vec, + pub expressions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaggedTemplateExpression { + #[serde(flatten)] + pub base: BaseNode, + pub tag: Box, + pub quasi: TemplateLiteral, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AwaitExpression { + #[serde(flatten)] + pub base: BaseNode, + pub argument: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct YieldExpression { + #[serde(flatten)] + pub base: BaseNode, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub argument: Option>, + pub delegate: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SpreadElement { + #[serde(flatten)] + pub base: BaseNode, + pub argument: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MetaProperty { + #[serde(flatten)] + pub base: BaseNode, + pub meta: Identifier, + pub property: Identifier, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClassExpression { + #[serde(flatten)] + pub base: BaseNode, + #[serde(default)] + pub id: Option, + #[serde(rename = "superClass")] + pub super_class: Option>, + pub body: ClassBody, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decorators: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "implements" + )] + pub implements: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "superTypeParameters" + )] + pub super_type_parameters: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClassBody { + 
#[serde(flatten)] + pub base: BaseNode, + pub body: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PrivateName { + #[serde(flatten)] + pub base: BaseNode, + pub id: Identifier, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Super { + #[serde(flatten)] + pub base: BaseNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Import { + #[serde(flatten)] + pub base: BaseNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThisExpression { + #[serde(flatten)] + pub base: BaseNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParenthesizedExpression { + #[serde(flatten)] + pub base: BaseNode, + pub expression: Box, +} + +// TypeScript expression nodes (pass-through with serde_json::Value for type +// args) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSAsExpression { + #[serde(flatten)] + pub base: BaseNode, + pub expression: Box, + #[serde(rename = "typeAnnotation")] + pub type_annotation: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSSatisfiesExpression { + #[serde(flatten)] + pub base: BaseNode, + pub expression: Box, + #[serde(rename = "typeAnnotation")] + pub type_annotation: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSNonNullExpression { + #[serde(flatten)] + pub base: BaseNode, + pub expression: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSTypeAssertion { + #[serde(flatten)] + pub base: BaseNode, + pub expression: Box, + #[serde(rename = "typeAnnotation")] + pub type_annotation: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TSInstantiationExpression { + #[serde(flatten)] + pub base: BaseNode, + pub expression: Box, + #[serde(rename = "typeParameters")] + pub type_parameters: Box, +} + +// Flow expression nodes +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TypeCastExpression { + #[serde(flatten)] + pub base: BaseNode, + pub 
expression: Box, + #[serde(rename = "typeAnnotation")] + pub type_annotation: Box, +} diff --git a/crates/react_compiler_ast/src/jsx.rs b/crates/react_compiler_ast/src/jsx.rs new file mode 100644 index 000000000000..155312875d98 --- /dev/null +++ b/crates/react_compiler_ast/src/jsx.rs @@ -0,0 +1,187 @@ +use serde::{Deserialize, Serialize}; + +use crate::{common::BaseNode, expressions::Expression, literals::StringLiteral}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXElement { + #[serde(flatten)] + pub base: BaseNode, + #[serde(rename = "openingElement")] + pub opening_element: JSXOpeningElement, + #[serde(rename = "closingElement")] + pub closing_element: Option, + pub children: Vec, + #[serde( + rename = "selfClosing", + default, + skip_serializing_if = "Option::is_none" + )] + pub self_closing: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXFragment { + #[serde(flatten)] + pub base: BaseNode, + #[serde(rename = "openingFragment")] + pub opening_fragment: JSXOpeningFragment, + #[serde(rename = "closingFragment")] + pub closing_fragment: JSXClosingFragment, + pub children: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXOpeningElement { + #[serde(flatten)] + pub base: BaseNode, + pub name: JSXElementName, + pub attributes: Vec, + #[serde(rename = "selfClosing")] + pub self_closing: bool, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXClosingElement { + #[serde(flatten)] + pub base: BaseNode, + pub name: JSXElementName, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXOpeningFragment { + #[serde(flatten)] + pub base: BaseNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXClosingFragment { + #[serde(flatten)] + pub base: BaseNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = 
"type")] +pub enum JSXElementName { + JSXIdentifier(JSXIdentifier), + JSXMemberExpression(JSXMemberExpression), + JSXNamespacedName(JSXNamespacedName), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum JSXChild { + JSXElement(Box), + JSXFragment(JSXFragment), + JSXExpressionContainer(JSXExpressionContainer), + JSXSpreadChild(JSXSpreadChild), + JSXText(JSXText), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum JSXAttributeItem { + JSXAttribute(JSXAttribute), + JSXSpreadAttribute(JSXSpreadAttribute), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXAttribute { + #[serde(flatten)] + pub base: BaseNode, + pub name: JSXAttributeName, + pub value: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum JSXAttributeName { + JSXIdentifier(JSXIdentifier), + JSXNamespacedName(JSXNamespacedName), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum JSXAttributeValue { + StringLiteral(StringLiteral), + JSXExpressionContainer(JSXExpressionContainer), + JSXElement(Box), + JSXFragment(JSXFragment), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXSpreadAttribute { + #[serde(flatten)] + pub base: BaseNode, + pub argument: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXExpressionContainer { + #[serde(flatten)] + pub base: BaseNode, + pub expression: JSXExpressionContainerExpr, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum JSXExpressionContainerExpr { + JSXEmptyExpression(JSXEmptyExpression), + #[serde(untagged)] + Expression(Box), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXSpreadChild { + #[serde(flatten)] + pub base: BaseNode, + pub expression: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXText { + #[serde(flatten)] + pub base: BaseNode, + pub value: String, +} + 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXEmptyExpression { + #[serde(flatten)] + pub base: BaseNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXIdentifier { + #[serde(flatten)] + pub base: BaseNode, + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXMemberExpression { + #[serde(flatten)] + pub base: BaseNode, + pub object: Box, + pub property: JSXIdentifier, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum JSXMemberExprObject { + JSXIdentifier(JSXIdentifier), + JSXMemberExpression(Box), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JSXNamespacedName { + #[serde(flatten)] + pub base: BaseNode, + pub namespace: JSXIdentifier, + pub name: JSXIdentifier, +} diff --git a/crates/react_compiler_ast/src/lib.rs b/crates/react_compiler_ast/src/lib.rs new file mode 100644 index 000000000000..bed34373e55e --- /dev/null +++ b/crates/react_compiler_ast/src/lib.rs @@ -0,0 +1,66 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. 
+#![allow(clippy::all)] + +pub mod common; +pub mod declarations; +pub mod expressions; +pub mod jsx; +pub mod literals; +pub mod operators; +pub mod patterns; +pub mod scope; +pub mod statements; +pub mod visitor; + +use serde::{Deserialize, Serialize}; + +use crate::{ + common::{BaseNode, Comment}, + statements::{Directive, Statement}, +}; + +/// The root type returned by @babel/parser +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct File { + #[serde(flatten)] + pub base: BaseNode, + pub program: Program, + #[serde(default)] + pub comments: Vec, + #[serde(default)] + pub errors: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Program { + #[serde(flatten)] + pub base: BaseNode, + pub body: Vec, + #[serde(default)] + pub directives: Vec, + #[serde(rename = "sourceType")] + pub source_type: SourceType, + #[serde(default)] + pub interpreter: Option, + #[serde( + rename = "sourceFile", + default, + skip_serializing_if = "Option::is_none" + )] + pub source_file: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SourceType { + Module, + Script, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InterpreterDirective { + #[serde(flatten)] + pub base: BaseNode, + pub value: String, +} diff --git a/crates/react_compiler_ast/src/literals.rs b/crates/react_compiler_ast/src/literals.rs new file mode 100644 index 000000000000..7ba142e32aee --- /dev/null +++ b/crates/react_compiler_ast/src/literals.rs @@ -0,0 +1,60 @@ +use serde::{Deserialize, Serialize}; + +use crate::common::BaseNode; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StringLiteral { + #[serde(flatten)] + pub base: BaseNode, + pub value: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NumericLiteral { + #[serde(flatten)] + pub base: BaseNode, + pub value: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BooleanLiteral { + #[serde(flatten)] + pub 
base: BaseNode, + pub value: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct NullLiteral { + #[serde(flatten)] + pub base: BaseNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BigIntLiteral { + #[serde(flatten)] + pub base: BaseNode, + pub value: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegExpLiteral { + #[serde(flatten)] + pub base: BaseNode, + pub pattern: String, + pub flags: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TemplateElement { + #[serde(flatten)] + pub base: BaseNode, + pub value: TemplateElementValue, + pub tail: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TemplateElementValue { + pub raw: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cooked: Option, +} diff --git a/crates/react_compiler_ast/src/operators.rs b/crates/react_compiler_ast/src/operators.rs new file mode 100644 index 000000000000..d52dbb49128c --- /dev/null +++ b/crates/react_compiler_ast/src/operators.rs @@ -0,0 +1,125 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum BinaryOperator { + #[serde(rename = "+")] + Add, + #[serde(rename = "-")] + Sub, + #[serde(rename = "*")] + Mul, + #[serde(rename = "/")] + Div, + #[serde(rename = "%")] + Rem, + #[serde(rename = "**")] + Exp, + #[serde(rename = "==")] + Eq, + #[serde(rename = "===")] + StrictEq, + #[serde(rename = "!=")] + Neq, + #[serde(rename = "!==")] + StrictNeq, + #[serde(rename = "<")] + Lt, + #[serde(rename = "<=")] + Lte, + #[serde(rename = ">")] + Gt, + #[serde(rename = ">=")] + Gte, + #[serde(rename = "<<")] + Shl, + #[serde(rename = ">>")] + Shr, + #[serde(rename = ">>>")] + UShr, + #[serde(rename = "|")] + BitOr, + #[serde(rename = "^")] + BitXor, + #[serde(rename = "&")] + BitAnd, + #[serde(rename = "in")] + In, + #[serde(rename = "instanceof")] + Instanceof, + #[serde(rename = "|>")] + Pipeline, +} + 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum LogicalOperator { + #[serde(rename = "||")] + Or, + #[serde(rename = "&&")] + And, + #[serde(rename = "??")] + NullishCoalescing, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum UnaryOperator { + #[serde(rename = "-")] + Neg, + #[serde(rename = "+")] + Plus, + #[serde(rename = "!")] + Not, + #[serde(rename = "~")] + BitNot, + #[serde(rename = "typeof")] + TypeOf, + #[serde(rename = "void")] + Void, + #[serde(rename = "delete")] + Delete, + #[serde(rename = "throw")] + Throw, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum UpdateOperator { + #[serde(rename = "++")] + Increment, + #[serde(rename = "--")] + Decrement, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum AssignmentOperator { + #[serde(rename = "=")] + Assign, + #[serde(rename = "+=")] + AddAssign, + #[serde(rename = "-=")] + SubAssign, + #[serde(rename = "*=")] + MulAssign, + #[serde(rename = "/=")] + DivAssign, + #[serde(rename = "%=")] + RemAssign, + #[serde(rename = "**=")] + ExpAssign, + #[serde(rename = "<<=")] + ShlAssign, + #[serde(rename = ">>=")] + ShrAssign, + #[serde(rename = ">>>=")] + UShrAssign, + #[serde(rename = "|=")] + BitOrAssign, + #[serde(rename = "^=")] + BitXorAssign, + #[serde(rename = "&=")] + BitAndAssign, + #[serde(rename = "||=")] + OrAssign, + #[serde(rename = "&&=")] + AndAssign, + #[serde(rename = "??=")] + NullishAssign, +} diff --git a/crates/react_compiler_ast/src/patterns.rs b/crates/react_compiler_ast/src/patterns.rs new file mode 100644 index 000000000000..e9772020ee35 --- /dev/null +++ b/crates/react_compiler_ast/src/patterns.rs @@ -0,0 +1,103 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + common::BaseNode, + expressions::{Expression, Identifier}, +}; + +/// Covers assignment targets and patterns. +/// In Babel, LVal includes Identifier, MemberExpression, ObjectPattern, +/// ArrayPattern, RestElement, AssignmentPattern. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum PatternLike { + Identifier(Identifier), + ObjectPattern(ObjectPattern), + ArrayPattern(ArrayPattern), + AssignmentPattern(AssignmentPattern), + RestElement(RestElement), + // Expressions can appear in pattern positions (e.g., MemberExpression as LVal) + MemberExpression(crate::expressions::MemberExpression), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ObjectPattern { + #[serde(flatten)] + pub base: BaseNode, + pub properties: Vec, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeAnnotation" + )] + pub type_annotation: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decorators: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ObjectPatternProperty { + ObjectProperty(ObjectPatternProp), + RestElement(RestElement), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ObjectPatternProp { + #[serde(flatten)] + pub base: BaseNode, + pub key: Box, + pub value: Box, + pub computed: bool, + pub shorthand: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decorators: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub method: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArrayPattern { + #[serde(flatten)] + pub base: BaseNode, + pub elements: Vec>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeAnnotation" + )] + pub type_annotation: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decorators: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AssignmentPattern { + #[serde(flatten)] + pub base: BaseNode, + pub left: Box, + pub right: Box, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeAnnotation" + )] + pub type_annotation: Option>, + #[serde(default, 
skip_serializing_if = "Option::is_none")] + pub decorators: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RestElement { + #[serde(flatten)] + pub base: BaseNode, + pub argument: Box, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeAnnotation" + )] + pub type_annotation: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decorators: Option>, +} diff --git a/crates/react_compiler_ast/src/scope.rs b/crates/react_compiler_ast/src/scope.rs new file mode 100644 index 000000000000..c17b35698522 --- /dev/null +++ b/crates/react_compiler_ast/src/scope.rs @@ -0,0 +1,169 @@ +use std::collections::HashMap; + +use indexmap::IndexMap; +use serde::{Deserialize, Serialize}; + +/// Identifies a scope in the scope table. Copy-able, used as an index. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct ScopeId(pub u32); + +/// Identifies a binding (variable declaration) in the binding table. Copy-able, +/// used as an index. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct BindingId(pub u32); + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ScopeData { + pub id: ScopeId, + pub parent: Option, + pub kind: ScopeKind, + /// Bindings declared directly in this scope, keyed by name. + /// Maps to BindingId for lookup in the binding table. + pub bindings: HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ScopeKind { + Program, + Function, + Block, + #[serde(rename = "for")] + For, + Class, + Switch, + Catch, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct BindingData { + pub id: BindingId, + pub name: String, + pub kind: BindingKind, + /// The scope this binding is declared in. 
+ pub scope: ScopeId, + /// The type of the declaration AST node (e.g., "FunctionDeclaration", + /// "VariableDeclarator"). Used by the compiler to distinguish function + /// declarations from variable declarations during hoisting. + pub declaration_type: String, + /// The start offset of the binding's declaration identifier. + /// Used to distinguish declaration sites from references in + /// `reference_to_binding`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declaration_start: Option, + /// For import bindings: the source module and import details. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub import: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum BindingKind { + Var, + Let, + Const, + Param, + /// Import bindings (import declarations). + Module, + /// Function declarations (hoisted). + Hoisted, + /// Other local bindings (class declarations, etc.). + Local, + /// Binding kind not recognized by the serializer. + Unknown, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImportBindingData { + /// The module specifier string (e.g., "react" in `import {useState} from + /// 'react'`). + pub source: String, + pub kind: ImportBindingKind, + /// For named imports: the imported name (e.g., "bar" in `import {bar as + /// baz} from 'foo'`). None for default and namespace imports. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub imported: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ImportBindingKind { + Default, + Named, + Namespace, +} + +/// Complete scope information for a program. Stored separately from the AST +/// and linked via position-based lookup maps. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ScopeInfo { + /// All scopes, indexed by ScopeId. scopes[id.0] gives the ScopeData for + /// that scope. 
+ pub scopes: Vec, + /// All bindings, indexed by BindingId. bindings[id.0] gives the + /// BindingData. + pub bindings: Vec, + + /// Maps an AST node's start offset to the scope it creates. + pub node_to_scope: HashMap, + + /// Maps an Identifier AST node's start offset to the binding it resolves + /// to. Only present for identifiers that resolve to a binding (not + /// globals). Uses IndexMap to preserve insertion order (source order + /// from serialization). + pub reference_to_binding: IndexMap, + + /// The program-level (module) scope. Always scopes[0]. + pub program_scope: ScopeId, +} + +impl ScopeInfo { + /// Look up a binding by name starting from the given scope, + /// walking up the parent chain. Returns None for globals. + pub fn get_binding(&self, scope_id: ScopeId, name: &str) -> Option { + let mut current = Some(scope_id); + while let Some(id) = current { + let scope = &self.scopes[id.0 as usize]; + if let Some(&binding_id) = scope.bindings.get(name) { + return Some(binding_id); + } + current = scope.parent; + } + None + } + + /// Look up the binding for an identifier reference by its AST node start + /// offset. Returns None for globals/unresolved references. + pub fn resolve_reference(&self, identifier_start: u32) -> Option<&BindingData> { + self.reference_to_binding + .get(&identifier_start) + .map(|id| &self.bindings[id.0 as usize]) + } + + /// Look up a binding by name in the scope that contains the identifier at + /// `start`. Used as a fallback when position-based lookup + /// (`resolve_reference`) returns a binding whose name doesn't match -- + /// e.g., when Babel's Flow component transform creates multiple params + /// with the same start position. 
+ pub fn resolve_reference_by_name(&self, name: &str, start: u32) -> Option<&BindingData> { + // Find which scope contains this position + let scope_id = self.resolve_reference(start).map(|b| b.scope)?; + // Look for a binding with the matching name in that scope + let scope = &self.scopes[scope_id.0 as usize]; + scope + .bindings + .get(name) + .map(|id| &self.bindings[id.0 as usize]) + } + + /// Get all bindings declared in a scope (for hoisting iteration). + pub fn scope_bindings(&self, scope_id: ScopeId) -> impl Iterator { + self.scopes[scope_id.0 as usize] + .bindings + .values() + .map(|id| &self.bindings[id.0 as usize]) + } +} diff --git a/crates/react_compiler_ast/src/statements.rs b/crates/react_compiler_ast/src/statements.rs new file mode 100644 index 000000000000..9c31be6d92ac --- /dev/null +++ b/crates/react_compiler_ast/src/statements.rs @@ -0,0 +1,363 @@ +use serde::{Deserialize, Serialize}; + +use crate::{ + common::BaseNode, + expressions::{Expression, Identifier}, + patterns::PatternLike, +}; + +fn is_false(v: &bool) -> bool { + !v +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum Statement { + // Statements + BlockStatement(BlockStatement), + ReturnStatement(ReturnStatement), + IfStatement(IfStatement), + ForStatement(ForStatement), + WhileStatement(WhileStatement), + DoWhileStatement(DoWhileStatement), + ForInStatement(ForInStatement), + ForOfStatement(ForOfStatement), + SwitchStatement(SwitchStatement), + ThrowStatement(ThrowStatement), + TryStatement(TryStatement), + BreakStatement(BreakStatement), + ContinueStatement(ContinueStatement), + LabeledStatement(LabeledStatement), + ExpressionStatement(ExpressionStatement), + EmptyStatement(EmptyStatement), + DebuggerStatement(DebuggerStatement), + WithStatement(WithStatement), + // Declarations are also statements + VariableDeclaration(VariableDeclaration), + FunctionDeclaration(FunctionDeclaration), + ClassDeclaration(ClassDeclaration), + // Import/export 
declarations + ImportDeclaration(crate::declarations::ImportDeclaration), + ExportNamedDeclaration(crate::declarations::ExportNamedDeclaration), + ExportDefaultDeclaration(crate::declarations::ExportDefaultDeclaration), + ExportAllDeclaration(crate::declarations::ExportAllDeclaration), + // TypeScript declarations + TSTypeAliasDeclaration(crate::declarations::TSTypeAliasDeclaration), + TSInterfaceDeclaration(crate::declarations::TSInterfaceDeclaration), + TSEnumDeclaration(crate::declarations::TSEnumDeclaration), + TSModuleDeclaration(crate::declarations::TSModuleDeclaration), + TSDeclareFunction(crate::declarations::TSDeclareFunction), + // Flow declarations + TypeAlias(crate::declarations::TypeAlias), + OpaqueType(crate::declarations::OpaqueType), + InterfaceDeclaration(crate::declarations::InterfaceDeclaration), + DeclareVariable(crate::declarations::DeclareVariable), + DeclareFunction(crate::declarations::DeclareFunction), + DeclareClass(crate::declarations::DeclareClass), + DeclareModule(crate::declarations::DeclareModule), + DeclareModuleExports(crate::declarations::DeclareModuleExports), + DeclareExportDeclaration(crate::declarations::DeclareExportDeclaration), + DeclareExportAllDeclaration(crate::declarations::DeclareExportAllDeclaration), + DeclareInterface(crate::declarations::DeclareInterface), + DeclareTypeAlias(crate::declarations::DeclareTypeAlias), + DeclareOpaqueType(crate::declarations::DeclareOpaqueType), + EnumDeclaration(crate::declarations::EnumDeclaration), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BlockStatement { + #[serde(flatten)] + pub base: BaseNode, + pub body: Vec, + #[serde(default)] + pub directives: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Directive { + #[serde(flatten)] + pub base: BaseNode, + pub value: DirectiveLiteral, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DirectiveLiteral { + #[serde(flatten)] + pub base: BaseNode, + pub value: String, +} + 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReturnStatement { + #[serde(flatten)] + pub base: BaseNode, + pub argument: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ExpressionStatement { + #[serde(flatten)] + pub base: BaseNode, + pub expression: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IfStatement { + #[serde(flatten)] + pub base: BaseNode, + pub test: Box, + pub consequent: Box, + pub alternate: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ForStatement { + #[serde(flatten)] + pub base: BaseNode, + pub init: Option>, + pub test: Option>, + pub update: Option>, + pub body: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ForInit { + VariableDeclaration(VariableDeclaration), + #[serde(untagged)] + Expression(Box), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WhileStatement { + #[serde(flatten)] + pub base: BaseNode, + pub test: Box, + pub body: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DoWhileStatement { + #[serde(flatten)] + pub base: BaseNode, + pub test: Box, + pub body: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ForInStatement { + #[serde(flatten)] + pub base: BaseNode, + pub left: Box, + pub right: Box, + pub body: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ForOfStatement { + #[serde(flatten)] + pub base: BaseNode, + pub left: Box, + pub right: Box, + pub body: Box, + #[serde(default, rename = "await")] + pub is_await: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ForInOfLeft { + VariableDeclaration(VariableDeclaration), + #[serde(untagged)] + Pattern(Box), +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SwitchStatement { + #[serde(flatten)] + pub base: BaseNode, + pub discriminant: Box, + pub cases: Vec, +} + +#[derive(Debug, Clone, Serialize, 
Deserialize)] +pub struct SwitchCase { + #[serde(flatten)] + pub base: BaseNode, + pub test: Option>, + pub consequent: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ThrowStatement { + #[serde(flatten)] + pub base: BaseNode, + pub argument: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TryStatement { + #[serde(flatten)] + pub base: BaseNode, + pub block: BlockStatement, + pub handler: Option, + pub finalizer: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CatchClause { + #[serde(flatten)] + pub base: BaseNode, + pub param: Option, + pub body: BlockStatement, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BreakStatement { + #[serde(flatten)] + pub base: BaseNode, + pub label: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ContinueStatement { + #[serde(flatten)] + pub base: BaseNode, + pub label: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LabeledStatement { + #[serde(flatten)] + pub base: BaseNode, + pub label: Identifier, + pub body: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmptyStatement { + #[serde(flatten)] + pub base: BaseNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DebuggerStatement { + #[serde(flatten)] + pub base: BaseNode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WithStatement { + #[serde(flatten)] + pub base: BaseNode, + pub object: Box, + pub body: Box, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct VariableDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub declarations: Vec, + pub kind: VariableDeclarationKind, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declare: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum VariableDeclarationKind { + Var, + Let, + Const, + Using, +} + +#[derive(Debug, Clone, Serialize, 
Deserialize)] +pub struct VariableDeclarator { + #[serde(flatten)] + pub base: BaseNode, + pub id: PatternLike, + pub init: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub definite: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub id: Option, + pub params: Vec, + pub body: BlockStatement, + #[serde(default)] + pub generator: bool, + #[serde(default, rename = "async")] + pub is_async: bool, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declare: Option, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "returnType" + )] + pub return_type: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde(default, skip_serializing_if = "Option::is_none", rename = "predicate")] + pub predicate: Option>, + /// Set by the Hermes parser for Flow `component Foo(...) { ... }` syntax + #[serde( + default, + skip_serializing_if = "is_false", + rename = "__componentDeclaration" + )] + pub component_declaration: bool, + /// Set by the Hermes parser for Flow `hook useFoo(...) { ... 
}` syntax + #[serde( + default, + skip_serializing_if = "is_false", + rename = "__hookDeclaration" + )] + pub hook_declaration: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClassDeclaration { + #[serde(flatten)] + pub base: BaseNode, + pub id: Option, + #[serde(rename = "superClass")] + pub super_class: Option>, + pub body: crate::expressions::ClassBody, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub decorators: Option>, + #[serde(default, skip_serializing_if = "Option::is_none", rename = "abstract")] + pub is_abstract: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub declare: Option, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "implements" + )] + pub implements: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "superTypeParameters" + )] + pub super_type_parameters: Option>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + rename = "typeParameters" + )] + pub type_parameters: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub mixins: Option>, +} diff --git a/crates/react_compiler_ast/src/visitor.rs b/crates/react_compiler_ast/src/visitor.rs new file mode 100644 index 000000000000..2b181e6c1f06 --- /dev/null +++ b/crates/react_compiler_ast/src/visitor.rs @@ -0,0 +1,1400 @@ +//! AST visitor with automatic scope tracking. +//! +//! Provides a [`Visitor`] trait with enter/leave hooks for specific node types, +//! and an [`AstWalker`] that traverses the AST while tracking the active scope +//! via the scope tree's `node_to_scope` map. + +use crate::{ + declarations::*, + expressions::*, + jsx::*, + patterns::*, + scope::{ScopeId, ScopeInfo}, + statements::*, + Program, +}; + +/// Trait for visiting Babel AST nodes. All methods default to no-ops. +/// Override specific methods to intercept nodes of interest. 
+/// +/// The `'ast` lifetime ties visitor hooks to the AST being walked, allowing +/// visitors to store references into the AST (e.g., for deferred processing). +/// +/// The `scope_stack` parameter provides the current scope context during +/// traversal. The active scope is `scope_stack.last()`. +pub trait Visitor<'ast> { + /// Controls whether the walker recurses into function/arrow/method bodies. + /// Returns `true` by default. Override to `false` to skip function bodies + /// (similar to Babel's `path.skip()` in traverse visitors). + /// + /// When `false`, the walker still calls `enter_*` / `leave_*` for functions + /// but does not walk their params or body. + fn traverse_function_bodies(&self) -> bool { + true + } + + fn enter_function_declaration( + &mut self, + _node: &'ast FunctionDeclaration, + _scope_stack: &[ScopeId], + ) { + } + fn leave_function_declaration( + &mut self, + _node: &'ast FunctionDeclaration, + _scope_stack: &[ScopeId], + ) { + } + fn enter_function_expression( + &mut self, + _node: &'ast FunctionExpression, + _scope_stack: &[ScopeId], + ) { + } + fn leave_function_expression( + &mut self, + _node: &'ast FunctionExpression, + _scope_stack: &[ScopeId], + ) { + } + fn enter_arrow_function_expression( + &mut self, + _node: &'ast ArrowFunctionExpression, + _scope_stack: &[ScopeId], + ) { + } + fn leave_arrow_function_expression( + &mut self, + _node: &'ast ArrowFunctionExpression, + _scope_stack: &[ScopeId], + ) { + } + fn enter_object_method(&mut self, _node: &'ast ObjectMethod, _scope_stack: &[ScopeId]) {} + fn leave_object_method(&mut self, _node: &'ast ObjectMethod, _scope_stack: &[ScopeId]) {} + fn enter_assignment_expression( + &mut self, + _node: &'ast AssignmentExpression, + _scope_stack: &[ScopeId], + ) { + } + fn enter_update_expression(&mut self, _node: &'ast UpdateExpression, _scope_stack: &[ScopeId]) { + } + fn enter_identifier(&mut self, _node: &'ast Identifier, _scope_stack: &[ScopeId]) {} + fn enter_jsx_identifier(&mut 
self, _node: &'ast JSXIdentifier, _scope_stack: &[ScopeId]) {} + fn enter_jsx_opening_element( + &mut self, + _node: &'ast JSXOpeningElement, + _scope_stack: &[ScopeId], + ) { + } + fn leave_jsx_opening_element( + &mut self, + _node: &'ast JSXOpeningElement, + _scope_stack: &[ScopeId], + ) { + } + + fn enter_variable_declarator( + &mut self, + _node: &'ast VariableDeclarator, + _scope_stack: &[ScopeId], + ) { + } + fn leave_variable_declarator( + &mut self, + _node: &'ast VariableDeclarator, + _scope_stack: &[ScopeId], + ) { + } + + fn enter_call_expression(&mut self, _node: &'ast CallExpression, _scope_stack: &[ScopeId]) {} + fn leave_call_expression(&mut self, _node: &'ast CallExpression, _scope_stack: &[ScopeId]) {} + + /// Called when the walker enters a loop expression context (while.test, + /// do-while.test, for-in.right, for-of.right). Functions found in these + /// positions are treated as non-program-scope by Babel, even though the + /// walker doesn't push a scope for them. + fn enter_loop_expression(&mut self) {} + fn leave_loop_expression(&mut self) {} +} + +/// Walks the AST while tracking scope context via `node_to_scope`. +pub struct AstWalker<'a> { + scope_info: &'a ScopeInfo, + scope_stack: Vec, + /// Depth counter for loop/iteration expression positions (while.test, + /// do-while.test, for-in.right, for-of.right). These positions are + /// NOT inside a scope in the walker's model, but Babel's scope analysis + /// treats them as non-program-scope. Visitors can check this via + /// `in_loop_expression_depth()` to implement Babel-compatible scope checks. + loop_expression_depth: usize, +} + +impl<'a> AstWalker<'a> { + pub fn new(scope_info: &'a ScopeInfo) -> Self { + AstWalker { + scope_info, + scope_stack: Vec::new(), + loop_expression_depth: 0, + } + } + + /// Create a walker with an initial scope already on the stack. 
+ pub fn with_initial_scope(scope_info: &'a ScopeInfo, initial_scope: ScopeId) -> Self { + AstWalker { + scope_info, + scope_stack: vec![initial_scope], + loop_expression_depth: 0, + } + } + + pub fn scope_stack(&self) -> &[ScopeId] { + &self.scope_stack + } + + /// Returns the current loop-expression depth. Non-zero when the walker is + /// inside a loop's test/right expression (while.test, do-while.test, + /// for-in.right, for-of.right). Visitors can use this to implement + /// Babel-compatible scope checks in 'all' compilation mode. + pub fn loop_expression_depth(&self) -> usize { + self.loop_expression_depth + } + + /// Try to push a scope for a node. Returns true if a scope was pushed. + fn try_push_scope(&mut self, start: Option) -> bool { + if let Some(start) = start { + if let Some(&scope_id) = self.scope_info.node_to_scope.get(&start) { + self.scope_stack.push(scope_id); + return true; + } + } + false + } + + // ---- Public walk methods ---- + + pub fn walk_program<'ast>(&mut self, v: &mut impl Visitor<'ast>, node: &'ast Program) { + let pushed = self.try_push_scope(node.base.start); + for stmt in &node.body { + self.walk_statement(v, stmt); + } + if pushed { + self.scope_stack.pop(); + } + } + + pub fn walk_block_statement<'ast>( + &mut self, + v: &mut impl Visitor<'ast>, + node: &'ast BlockStatement, + ) { + let pushed = self.try_push_scope(node.base.start); + for stmt in &node.body { + self.walk_statement(v, stmt); + } + if pushed { + self.scope_stack.pop(); + } + } + + pub fn walk_statement<'ast>(&mut self, v: &mut impl Visitor<'ast>, stmt: &'ast Statement) { + match stmt { + Statement::BlockStatement(node) => self.walk_block_statement(v, node), + Statement::ReturnStatement(node) => { + if let Some(arg) = &node.argument { + self.walk_expression(v, arg); + } + } + Statement::ExpressionStatement(node) => { + self.walk_expression(v, &node.expression); + } + Statement::IfStatement(node) => { + self.walk_expression(v, &node.test); + self.walk_statement(v, 
&node.consequent); + if let Some(alt) = &node.alternate { + self.walk_statement(v, alt); + } + } + Statement::ForStatement(node) => { + let pushed = self.try_push_scope(node.base.start); + if let Some(init) = &node.init { + match init.as_ref() { + ForInit::VariableDeclaration(decl) => { + self.walk_variable_declaration(v, decl) + } + ForInit::Expression(expr) => self.walk_expression(v, expr), + } + } + if let Some(test) = &node.test { + self.walk_expression(v, test); + } + if let Some(update) = &node.update { + self.walk_expression(v, update); + } + self.walk_statement(v, &node.body); + if pushed { + self.scope_stack.pop(); + } + } + Statement::WhileStatement(node) => { + self.loop_expression_depth += 1; + v.enter_loop_expression(); + self.walk_expression(v, &node.test); + v.leave_loop_expression(); + self.loop_expression_depth -= 1; + self.walk_statement(v, &node.body); + } + Statement::DoWhileStatement(node) => { + self.walk_statement(v, &node.body); + self.loop_expression_depth += 1; + v.enter_loop_expression(); + self.walk_expression(v, &node.test); + v.leave_loop_expression(); + self.loop_expression_depth -= 1; + } + Statement::ForInStatement(node) => { + let pushed = self.try_push_scope(node.base.start); + self.walk_for_in_of_left(v, &node.left); + self.loop_expression_depth += 1; + v.enter_loop_expression(); + self.walk_expression(v, &node.right); + v.leave_loop_expression(); + self.loop_expression_depth -= 1; + self.walk_statement(v, &node.body); + if pushed { + self.scope_stack.pop(); + } + } + Statement::ForOfStatement(node) => { + let pushed = self.try_push_scope(node.base.start); + self.walk_for_in_of_left(v, &node.left); + self.loop_expression_depth += 1; + v.enter_loop_expression(); + self.walk_expression(v, &node.right); + v.leave_loop_expression(); + self.loop_expression_depth -= 1; + self.walk_statement(v, &node.body); + if pushed { + self.scope_stack.pop(); + } + } + Statement::SwitchStatement(node) => { + let pushed = 
self.try_push_scope(node.base.start); + self.walk_expression(v, &node.discriminant); + for case in &node.cases { + if let Some(test) = &case.test { + self.walk_expression(v, test); + } + for consequent in &case.consequent { + self.walk_statement(v, consequent); + } + } + if pushed { + self.scope_stack.pop(); + } + } + Statement::ThrowStatement(node) => { + self.walk_expression(v, &node.argument); + } + Statement::TryStatement(node) => { + self.walk_block_statement(v, &node.block); + if let Some(handler) = &node.handler { + let pushed = self.try_push_scope(handler.base.start); + if let Some(param) = &handler.param { + self.walk_pattern(v, param); + } + self.walk_block_statement(v, &handler.body); + if pushed { + self.scope_stack.pop(); + } + } + if let Some(finalizer) = &node.finalizer { + self.walk_block_statement(v, finalizer); + } + } + Statement::LabeledStatement(node) => { + self.walk_statement(v, &node.body); + } + Statement::VariableDeclaration(node) => { + self.walk_variable_declaration(v, node); + } + Statement::FunctionDeclaration(node) => { + self.walk_function_declaration_inner(v, node); + } + Statement::ClassDeclaration(node) => { + if let Some(sc) = &node.super_class { + self.walk_expression(v, sc); + } + } + Statement::WithStatement(node) => { + self.walk_expression(v, &node.object); + self.walk_statement(v, &node.body); + } + Statement::ExportNamedDeclaration(node) => { + if let Some(decl) = &node.declaration { + self.walk_declaration(v, decl); + } + } + Statement::ExportDefaultDeclaration(node) => { + self.walk_export_default_decl(v, &node.declaration); + } + // No runtime expressions to traverse + Statement::BreakStatement(_) + | Statement::ContinueStatement(_) + | Statement::EmptyStatement(_) + | Statement::DebuggerStatement(_) + | Statement::ImportDeclaration(_) + | Statement::ExportAllDeclaration(_) + | Statement::TSTypeAliasDeclaration(_) + | Statement::TSInterfaceDeclaration(_) + | Statement::TSEnumDeclaration(_) + | 
Statement::TSModuleDeclaration(_) + | Statement::TSDeclareFunction(_) + | Statement::TypeAlias(_) + | Statement::OpaqueType(_) + | Statement::InterfaceDeclaration(_) + | Statement::DeclareVariable(_) + | Statement::DeclareFunction(_) + | Statement::DeclareClass(_) + | Statement::DeclareModule(_) + | Statement::DeclareModuleExports(_) + | Statement::DeclareExportDeclaration(_) + | Statement::DeclareExportAllDeclaration(_) + | Statement::DeclareInterface(_) + | Statement::DeclareTypeAlias(_) + | Statement::DeclareOpaqueType(_) + | Statement::EnumDeclaration(_) => {} + } + } + + pub fn walk_expression<'ast>(&mut self, v: &mut impl Visitor<'ast>, expr: &'ast Expression) { + match expr { + Expression::Identifier(node) => { + v.enter_identifier(node, &self.scope_stack); + } + Expression::CallExpression(node) => { + v.enter_call_expression(node, &self.scope_stack); + self.walk_expression(v, &node.callee); + for arg in &node.arguments { + self.walk_expression(v, arg); + } + v.leave_call_expression(node, &self.scope_stack); + } + Expression::MemberExpression(node) => { + self.walk_expression(v, &node.object); + if node.computed { + self.walk_expression(v, &node.property); + } + } + Expression::OptionalCallExpression(node) => { + self.walk_expression(v, &node.callee); + for arg in &node.arguments { + self.walk_expression(v, arg); + } + } + Expression::OptionalMemberExpression(node) => { + self.walk_expression(v, &node.object); + if node.computed { + self.walk_expression(v, &node.property); + } + } + Expression::BinaryExpression(node) => { + self.walk_expression(v, &node.left); + self.walk_expression(v, &node.right); + } + Expression::LogicalExpression(node) => { + self.walk_expression(v, &node.left); + self.walk_expression(v, &node.right); + } + Expression::UnaryExpression(node) => { + self.walk_expression(v, &node.argument); + } + Expression::UpdateExpression(node) => { + v.enter_update_expression(node, &self.scope_stack); + self.walk_expression(v, &node.argument); + } + 
Expression::ConditionalExpression(node) => { + self.walk_expression(v, &node.test); + self.walk_expression(v, &node.consequent); + self.walk_expression(v, &node.alternate); + } + Expression::AssignmentExpression(node) => { + v.enter_assignment_expression(node, &self.scope_stack); + self.walk_pattern(v, &node.left); + self.walk_expression(v, &node.right); + } + Expression::SequenceExpression(node) => { + for expr in &node.expressions { + self.walk_expression(v, expr); + } + } + Expression::ArrowFunctionExpression(node) => { + let pushed = self.try_push_scope(node.base.start); + v.enter_arrow_function_expression(node, &self.scope_stack); + if v.traverse_function_bodies() { + for param in &node.params { + self.walk_pattern(v, param); + } + match node.body.as_ref() { + ArrowFunctionBody::BlockStatement(block) => { + self.walk_block_statement(v, block); + } + ArrowFunctionBody::Expression(expr) => { + self.walk_expression(v, expr); + } + } + } + v.leave_arrow_function_expression(node, &self.scope_stack); + if pushed { + self.scope_stack.pop(); + } + } + Expression::FunctionExpression(node) => { + let pushed = self.try_push_scope(node.base.start); + v.enter_function_expression(node, &self.scope_stack); + if v.traverse_function_bodies() { + for param in &node.params { + self.walk_pattern(v, param); + } + self.walk_block_statement(v, &node.body); + } + v.leave_function_expression(node, &self.scope_stack); + if pushed { + self.scope_stack.pop(); + } + } + Expression::ObjectExpression(node) => { + for prop in &node.properties { + self.walk_object_expression_property(v, prop); + } + } + Expression::ArrayExpression(node) => { + for element in &node.elements { + if let Some(el) = element { + self.walk_expression(v, el); + } + } + } + Expression::NewExpression(node) => { + self.walk_expression(v, &node.callee); + for arg in &node.arguments { + self.walk_expression(v, arg); + } + } + Expression::TemplateLiteral(node) => { + for expr in &node.expressions { + 
self.walk_expression(v, expr); + } + } + Expression::TaggedTemplateExpression(node) => { + self.walk_expression(v, &node.tag); + for expr in &node.quasi.expressions { + self.walk_expression(v, expr); + } + } + Expression::AwaitExpression(node) => { + self.walk_expression(v, &node.argument); + } + Expression::YieldExpression(node) => { + if let Some(arg) = &node.argument { + self.walk_expression(v, arg); + } + } + Expression::SpreadElement(node) => { + self.walk_expression(v, &node.argument); + } + Expression::ParenthesizedExpression(node) => { + self.walk_expression(v, &node.expression); + } + Expression::AssignmentPattern(node) => { + self.walk_pattern(v, &node.left); + self.walk_expression(v, &node.right); + } + Expression::ClassExpression(node) => { + if let Some(sc) = &node.super_class { + self.walk_expression(v, sc); + } + } + // JSX + Expression::JSXElement(node) => self.walk_jsx_element(v, node), + Expression::JSXFragment(node) => self.walk_jsx_fragment(v, node), + // TS/Flow wrappers - traverse inner expression + Expression::TSAsExpression(node) => self.walk_expression(v, &node.expression), + Expression::TSSatisfiesExpression(node) => self.walk_expression(v, &node.expression), + Expression::TSNonNullExpression(node) => self.walk_expression(v, &node.expression), + Expression::TSTypeAssertion(node) => self.walk_expression(v, &node.expression), + Expression::TSInstantiationExpression(node) => { + self.walk_expression(v, &node.expression) + } + Expression::TypeCastExpression(node) => self.walk_expression(v, &node.expression), + // Leaf nodes + Expression::StringLiteral(_) + | Expression::NumericLiteral(_) + | Expression::BooleanLiteral(_) + | Expression::NullLiteral(_) + | Expression::BigIntLiteral(_) + | Expression::RegExpLiteral(_) + | Expression::MetaProperty(_) + | Expression::PrivateName(_) + | Expression::Super(_) + | Expression::Import(_) + | Expression::ThisExpression(_) => {} + } + } + + pub fn walk_pattern<'ast>(&mut self, v: &mut impl Visitor<'ast>, 
pat: &'ast PatternLike) { + match pat { + PatternLike::Identifier(node) => { + v.enter_identifier(node, &self.scope_stack); + } + PatternLike::ObjectPattern(node) => { + for prop in &node.properties { + match prop { + ObjectPatternProperty::ObjectProperty(p) => { + if p.computed { + self.walk_expression(v, &p.key); + } + self.walk_pattern(v, &p.value); + } + ObjectPatternProperty::RestElement(p) => { + self.walk_pattern(v, &p.argument); + } + } + } + } + PatternLike::ArrayPattern(node) => { + for element in &node.elements { + if let Some(el) = element { + self.walk_pattern(v, el); + } + } + } + PatternLike::AssignmentPattern(node) => { + self.walk_pattern(v, &node.left); + self.walk_expression(v, &node.right); + } + PatternLike::RestElement(node) => { + self.walk_pattern(v, &node.argument); + } + PatternLike::MemberExpression(node) => { + self.walk_expression(v, &node.object); + if node.computed { + self.walk_expression(v, &node.property); + } + } + } + } + + // ---- Private helper walk methods ---- + + fn walk_for_in_of_left<'ast>(&mut self, v: &mut impl Visitor<'ast>, left: &'ast ForInOfLeft) { + match left { + ForInOfLeft::VariableDeclaration(decl) => self.walk_variable_declaration(v, decl), + ForInOfLeft::Pattern(pat) => self.walk_pattern(v, pat), + } + } + + fn walk_variable_declaration<'ast>( + &mut self, + v: &mut impl Visitor<'ast>, + decl: &'ast VariableDeclaration, + ) { + for declarator in &decl.declarations { + v.enter_variable_declarator(declarator, &self.scope_stack); + self.walk_pattern(v, &declarator.id); + if let Some(init) = &declarator.init { + self.walk_expression(v, init); + } + v.leave_variable_declarator(declarator, &self.scope_stack); + } + } + + fn walk_function_declaration_inner<'ast>( + &mut self, + v: &mut impl Visitor<'ast>, + node: &'ast FunctionDeclaration, + ) { + let pushed = self.try_push_scope(node.base.start); + v.enter_function_declaration(node, &self.scope_stack); + if v.traverse_function_bodies() { + for param in &node.params 
{ + self.walk_pattern(v, param); + } + self.walk_block_statement(v, &node.body); + } + v.leave_function_declaration(node, &self.scope_stack); + if pushed { + self.scope_stack.pop(); + } + } + + fn walk_object_expression_property<'ast>( + &mut self, + v: &mut impl Visitor<'ast>, + prop: &'ast ObjectExpressionProperty, + ) { + match prop { + ObjectExpressionProperty::ObjectProperty(p) => { + if p.computed { + self.walk_expression(v, &p.key); + } + self.walk_expression(v, &p.value); + } + ObjectExpressionProperty::ObjectMethod(node) => { + let pushed = self.try_push_scope(node.base.start); + v.enter_object_method(node, &self.scope_stack); + if v.traverse_function_bodies() { + if node.computed { + self.walk_expression(v, &node.key); + } + for param in &node.params { + self.walk_pattern(v, param); + } + self.walk_block_statement(v, &node.body); + } + v.leave_object_method(node, &self.scope_stack); + if pushed { + self.scope_stack.pop(); + } + } + ObjectExpressionProperty::SpreadElement(p) => { + self.walk_expression(v, &p.argument); + } + } + } + + fn walk_declaration<'ast>(&mut self, v: &mut impl Visitor<'ast>, decl: &'ast Declaration) { + match decl { + Declaration::FunctionDeclaration(node) => { + self.walk_function_declaration_inner(v, node); + } + Declaration::ClassDeclaration(node) => { + if let Some(sc) = &node.super_class { + self.walk_expression(v, sc); + } + } + Declaration::VariableDeclaration(node) => { + self.walk_variable_declaration(v, node); + } + // TS/Flow declarations - no runtime expressions + _ => {} + } + } + + fn walk_export_default_decl<'ast>( + &mut self, + v: &mut impl Visitor<'ast>, + decl: &'ast ExportDefaultDecl, + ) { + match decl { + ExportDefaultDecl::FunctionDeclaration(node) => { + self.walk_function_declaration_inner(v, node); + } + ExportDefaultDecl::ClassDeclaration(node) => { + if let Some(sc) = &node.super_class { + self.walk_expression(v, sc); + } + } + ExportDefaultDecl::Expression(expr) => { + self.walk_expression(v, expr); + } 
+ } + } + + fn walk_jsx_element<'ast>(&mut self, v: &mut impl Visitor<'ast>, node: &'ast JSXElement) { + v.enter_jsx_opening_element(&node.opening_element, &self.scope_stack); + self.walk_jsx_element_name(v, &node.opening_element.name); + v.leave_jsx_opening_element(&node.opening_element, &self.scope_stack); + for attr in &node.opening_element.attributes { + match attr { + JSXAttributeItem::JSXAttribute(a) => { + if let Some(value) = &a.value { + match value { + JSXAttributeValue::JSXExpressionContainer(c) => { + self.walk_jsx_expr_container(v, c); + } + JSXAttributeValue::JSXElement(el) => { + self.walk_jsx_element(v, el); + } + JSXAttributeValue::JSXFragment(f) => { + self.walk_jsx_fragment(v, f); + } + JSXAttributeValue::StringLiteral(_) => {} + } + } + } + JSXAttributeItem::JSXSpreadAttribute(a) => { + self.walk_expression(v, &a.argument); + } + } + } + for child in &node.children { + self.walk_jsx_child(v, child); + } + } + + fn walk_jsx_fragment<'ast>(&mut self, v: &mut impl Visitor<'ast>, node: &'ast JSXFragment) { + for child in &node.children { + self.walk_jsx_child(v, child); + } + } + + fn walk_jsx_child<'ast>(&mut self, v: &mut impl Visitor<'ast>, child: &'ast JSXChild) { + match child { + JSXChild::JSXElement(el) => self.walk_jsx_element(v, el), + JSXChild::JSXFragment(f) => self.walk_jsx_fragment(v, f), + JSXChild::JSXExpressionContainer(c) => self.walk_jsx_expr_container(v, c), + JSXChild::JSXSpreadChild(s) => self.walk_expression(v, &s.expression), + JSXChild::JSXText(_) => {} + } + } + + fn walk_jsx_expr_container<'ast>( + &mut self, + v: &mut impl Visitor<'ast>, + node: &'ast JSXExpressionContainer, + ) { + match &node.expression { + JSXExpressionContainerExpr::Expression(expr) => self.walk_expression(v, expr), + JSXExpressionContainerExpr::JSXEmptyExpression(_) => {} + } + } + + fn walk_jsx_element_name<'ast>( + &mut self, + v: &mut impl Visitor<'ast>, + name: &'ast JSXElementName, + ) { + match name { + JSXElementName::JSXIdentifier(id) => { + 
/// Outcome of a mutable visitor hook, telling the walker how to proceed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VisitResult {
    /// Keep going: descend into the node's children.
    Continue,
    /// Abort the whole traversal right away.
    Stop,
}

impl VisitResult {
    /// Returns `true` for [`VisitResult::Stop`] and `false` for
    /// [`VisitResult::Continue`].
    pub fn is_stop(self) -> bool {
        matches!(self, VisitResult::Stop)
    }
}
+ fn visit_identifier(&mut self, _node: &mut Identifier) -> VisitResult { + VisitResult::Continue + } +} + +/// Walk a program's body mutably, calling visitor hooks for each node. +pub fn walk_program_mut(v: &mut impl MutVisitor, program: &mut Program) -> VisitResult { + for stmt in program.body.iter_mut() { + if walk_statement_mut(v, stmt).is_stop() { + return VisitResult::Stop; + } + } + VisitResult::Continue +} + +/// Walk a single statement mutably, calling visitor hooks and recursing into +/// children. +pub fn walk_statement_mut(v: &mut impl MutVisitor, stmt: &mut Statement) -> VisitResult { + if v.visit_statement(stmt).is_stop() { + return VisitResult::Stop; + } + match stmt { + Statement::BlockStatement(node) => { + for s in node.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + Statement::ReturnStatement(node) => { + if let Some(ref mut arg) = node.argument { + if walk_expression_mut(v, arg).is_stop() { + return VisitResult::Stop; + } + } + } + Statement::ExpressionStatement(node) => { + if walk_expression_mut(v, &mut node.expression).is_stop() { + return VisitResult::Stop; + } + } + Statement::IfStatement(node) => { + if walk_expression_mut(v, &mut node.test).is_stop() { + return VisitResult::Stop; + } + if walk_statement_mut(v, &mut node.consequent).is_stop() { + return VisitResult::Stop; + } + if let Some(ref mut alt) = node.alternate { + if walk_statement_mut(v, alt).is_stop() { + return VisitResult::Stop; + } + } + } + Statement::ForStatement(node) => { + if let Some(ref mut init) = node.init { + match init.as_mut() { + ForInit::VariableDeclaration(decl) => { + if walk_variable_declaration_mut(v, decl).is_stop() { + return VisitResult::Stop; + } + } + ForInit::Expression(expr) => { + if walk_expression_mut(v, expr).is_stop() { + return VisitResult::Stop; + } + } + } + } + if let Some(ref mut test) = node.test { + if walk_expression_mut(v, test).is_stop() { + return VisitResult::Stop; + } + } + if let 
Some(ref mut update) = node.update { + if walk_expression_mut(v, update).is_stop() { + return VisitResult::Stop; + } + } + if walk_statement_mut(v, &mut node.body).is_stop() { + return VisitResult::Stop; + } + } + Statement::WhileStatement(node) => { + if walk_expression_mut(v, &mut node.test).is_stop() { + return VisitResult::Stop; + } + if walk_statement_mut(v, &mut node.body).is_stop() { + return VisitResult::Stop; + } + } + Statement::DoWhileStatement(node) => { + if walk_statement_mut(v, &mut node.body).is_stop() { + return VisitResult::Stop; + } + if walk_expression_mut(v, &mut node.test).is_stop() { + return VisitResult::Stop; + } + } + Statement::ForInStatement(node) => { + if walk_expression_mut(v, &mut node.right).is_stop() { + return VisitResult::Stop; + } + if walk_statement_mut(v, &mut node.body).is_stop() { + return VisitResult::Stop; + } + } + Statement::ForOfStatement(node) => { + if walk_expression_mut(v, &mut node.right).is_stop() { + return VisitResult::Stop; + } + if walk_statement_mut(v, &mut node.body).is_stop() { + return VisitResult::Stop; + } + } + Statement::SwitchStatement(node) => { + if walk_expression_mut(v, &mut node.discriminant).is_stop() { + return VisitResult::Stop; + } + for case in node.cases.iter_mut() { + if let Some(ref mut test) = case.test { + if walk_expression_mut(v, test).is_stop() { + return VisitResult::Stop; + } + } + for s in case.consequent.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + } + Statement::ThrowStatement(node) => { + if walk_expression_mut(v, &mut node.argument).is_stop() { + return VisitResult::Stop; + } + } + Statement::TryStatement(node) => { + for s in node.block.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + if let Some(ref mut handler) = node.handler { + for s in handler.body.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + if let Some(ref mut 
finalizer) = node.finalizer { + for s in finalizer.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + } + Statement::LabeledStatement(node) => { + if walk_statement_mut(v, &mut node.body).is_stop() { + return VisitResult::Stop; + } + } + Statement::VariableDeclaration(node) => { + if walk_variable_declaration_mut(v, node).is_stop() { + return VisitResult::Stop; + } + } + Statement::FunctionDeclaration(node) => { + for s in node.body.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + Statement::ClassDeclaration(node) => { + if let Some(ref mut sc) = node.super_class { + if walk_expression_mut(v, sc).is_stop() { + return VisitResult::Stop; + } + } + } + Statement::WithStatement(node) => { + if walk_expression_mut(v, &mut node.object).is_stop() { + return VisitResult::Stop; + } + if walk_statement_mut(v, &mut node.body).is_stop() { + return VisitResult::Stop; + } + } + Statement::ExportNamedDeclaration(node) => { + if let Some(ref mut decl) = node.declaration { + if walk_declaration_mut(v, decl).is_stop() { + return VisitResult::Stop; + } + } + } + Statement::ExportDefaultDeclaration(node) => { + if walk_export_default_decl_mut(v, &mut node.declaration).is_stop() { + return VisitResult::Stop; + } + } + // No runtime expressions to traverse + Statement::BreakStatement(_) + | Statement::ContinueStatement(_) + | Statement::EmptyStatement(_) + | Statement::DebuggerStatement(_) + | Statement::ImportDeclaration(_) + | Statement::ExportAllDeclaration(_) + | Statement::TSTypeAliasDeclaration(_) + | Statement::TSInterfaceDeclaration(_) + | Statement::TSEnumDeclaration(_) + | Statement::TSModuleDeclaration(_) + | Statement::TSDeclareFunction(_) + | Statement::TypeAlias(_) + | Statement::OpaqueType(_) + | Statement::InterfaceDeclaration(_) + | Statement::DeclareVariable(_) + | Statement::DeclareFunction(_) + | Statement::DeclareClass(_) + | Statement::DeclareModule(_) + | 
Statement::DeclareModuleExports(_) + | Statement::DeclareExportDeclaration(_) + | Statement::DeclareExportAllDeclaration(_) + | Statement::DeclareInterface(_) + | Statement::DeclareTypeAlias(_) + | Statement::DeclareOpaqueType(_) + | Statement::EnumDeclaration(_) => {} + } + VisitResult::Continue +} + +/// Walk an expression mutably, calling visitor hooks and recursing into +/// children. +pub fn walk_expression_mut(v: &mut impl MutVisitor, expr: &mut Expression) -> VisitResult { + if v.visit_expression(expr).is_stop() { + return VisitResult::Stop; + } + match expr { + Expression::Identifier(node) => { + if v.visit_identifier(node).is_stop() { + return VisitResult::Stop; + } + } + Expression::CallExpression(node) => { + if walk_expression_mut(v, &mut node.callee).is_stop() { + return VisitResult::Stop; + } + for arg in node.arguments.iter_mut() { + if walk_expression_mut(v, arg).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::MemberExpression(node) => { + if walk_expression_mut(v, &mut node.object).is_stop() { + return VisitResult::Stop; + } + if node.computed { + if walk_expression_mut(v, &mut node.property).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::OptionalCallExpression(node) => { + if walk_expression_mut(v, &mut node.callee).is_stop() { + return VisitResult::Stop; + } + for arg in node.arguments.iter_mut() { + if walk_expression_mut(v, arg).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::OptionalMemberExpression(node) => { + if walk_expression_mut(v, &mut node.object).is_stop() { + return VisitResult::Stop; + } + if node.computed { + if walk_expression_mut(v, &mut node.property).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::BinaryExpression(node) => { + if walk_expression_mut(v, &mut node.left).is_stop() { + return VisitResult::Stop; + } + if walk_expression_mut(v, &mut node.right).is_stop() { + return VisitResult::Stop; + } + } + Expression::LogicalExpression(node) => { + if 
walk_expression_mut(v, &mut node.left).is_stop() { + return VisitResult::Stop; + } + if walk_expression_mut(v, &mut node.right).is_stop() { + return VisitResult::Stop; + } + } + Expression::UnaryExpression(node) => { + if walk_expression_mut(v, &mut node.argument).is_stop() { + return VisitResult::Stop; + } + } + Expression::UpdateExpression(node) => { + if walk_expression_mut(v, &mut node.argument).is_stop() { + return VisitResult::Stop; + } + } + Expression::ConditionalExpression(node) => { + if walk_expression_mut(v, &mut node.test).is_stop() { + return VisitResult::Stop; + } + if walk_expression_mut(v, &mut node.consequent).is_stop() { + return VisitResult::Stop; + } + if walk_expression_mut(v, &mut node.alternate).is_stop() { + return VisitResult::Stop; + } + } + Expression::AssignmentExpression(node) => { + if walk_expression_mut(v, &mut node.right).is_stop() { + return VisitResult::Stop; + } + } + Expression::SequenceExpression(node) => { + for e in node.expressions.iter_mut() { + if walk_expression_mut(v, e).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::ArrowFunctionExpression(node) => match node.body.as_mut() { + ArrowFunctionBody::BlockStatement(block) => { + for s in block.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + ArrowFunctionBody::Expression(e) => { + if walk_expression_mut(v, e).is_stop() { + return VisitResult::Stop; + } + } + }, + Expression::FunctionExpression(node) => { + for s in node.body.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::ObjectExpression(node) => { + for prop in node.properties.iter_mut() { + match prop { + ObjectExpressionProperty::ObjectProperty(p) => { + if p.computed { + if walk_expression_mut(v, &mut p.key).is_stop() { + return VisitResult::Stop; + } + } + if walk_expression_mut(v, &mut p.value).is_stop() { + return VisitResult::Stop; + } + } + 
ObjectExpressionProperty::ObjectMethod(m) => { + for s in m.body.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + ObjectExpressionProperty::SpreadElement(s) => { + if walk_expression_mut(v, &mut s.argument).is_stop() { + return VisitResult::Stop; + } + } + } + } + } + Expression::ArrayExpression(node) => { + for elem in node.elements.iter_mut().flatten() { + if walk_expression_mut(v, elem).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::NewExpression(node) => { + if walk_expression_mut(v, &mut node.callee).is_stop() { + return VisitResult::Stop; + } + for arg in node.arguments.iter_mut() { + if walk_expression_mut(v, arg).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::TemplateLiteral(node) => { + for e in node.expressions.iter_mut() { + if walk_expression_mut(v, e).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::TaggedTemplateExpression(node) => { + if walk_expression_mut(v, &mut node.tag).is_stop() { + return VisitResult::Stop; + } + for e in node.quasi.expressions.iter_mut() { + if walk_expression_mut(v, e).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::AwaitExpression(node) => { + if walk_expression_mut(v, &mut node.argument).is_stop() { + return VisitResult::Stop; + } + } + Expression::YieldExpression(node) => { + if let Some(ref mut arg) = node.argument { + if walk_expression_mut(v, arg).is_stop() { + return VisitResult::Stop; + } + } + } + Expression::SpreadElement(node) => { + if walk_expression_mut(v, &mut node.argument).is_stop() { + return VisitResult::Stop; + } + } + Expression::ParenthesizedExpression(node) => { + if walk_expression_mut(v, &mut node.expression).is_stop() { + return VisitResult::Stop; + } + } + Expression::AssignmentPattern(node) => { + if walk_expression_mut(v, &mut node.right).is_stop() { + return VisitResult::Stop; + } + } + Expression::ClassExpression(node) => { + if let Some(ref mut sc) = node.super_class { 
+ if walk_expression_mut(v, sc).is_stop() { + return VisitResult::Stop; + } + } + } + // JSX — not walked for current use cases + Expression::JSXElement(_) | Expression::JSXFragment(_) => {} + // TS/Flow wrappers — traverse inner expression + Expression::TSAsExpression(node) => { + if walk_expression_mut(v, &mut node.expression).is_stop() { + return VisitResult::Stop; + } + } + Expression::TSSatisfiesExpression(node) => { + if walk_expression_mut(v, &mut node.expression).is_stop() { + return VisitResult::Stop; + } + } + Expression::TSNonNullExpression(node) => { + if walk_expression_mut(v, &mut node.expression).is_stop() { + return VisitResult::Stop; + } + } + Expression::TSTypeAssertion(node) => { + if walk_expression_mut(v, &mut node.expression).is_stop() { + return VisitResult::Stop; + } + } + Expression::TSInstantiationExpression(node) => { + if walk_expression_mut(v, &mut node.expression).is_stop() { + return VisitResult::Stop; + } + } + Expression::TypeCastExpression(node) => { + if walk_expression_mut(v, &mut node.expression).is_stop() { + return VisitResult::Stop; + } + } + // Leaf nodes + Expression::StringLiteral(_) + | Expression::NumericLiteral(_) + | Expression::BooleanLiteral(_) + | Expression::NullLiteral(_) + | Expression::BigIntLiteral(_) + | Expression::RegExpLiteral(_) + | Expression::MetaProperty(_) + | Expression::PrivateName(_) + | Expression::Super(_) + | Expression::Import(_) + | Expression::ThisExpression(_) => {} + } + VisitResult::Continue +} + +// ---- Private helper walk-mut functions ---- + +fn walk_variable_declaration_mut( + v: &mut impl MutVisitor, + decl: &mut VariableDeclaration, +) -> VisitResult { + for declarator in decl.declarations.iter_mut() { + if let Some(ref mut init) = declarator.init { + if walk_expression_mut(v, init).is_stop() { + return VisitResult::Stop; + } + } + } + VisitResult::Continue +} + +fn walk_declaration_mut(v: &mut impl MutVisitor, decl: &mut Declaration) -> VisitResult { + match decl { + 
Declaration::FunctionDeclaration(node) => { + for s in node.body.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + Declaration::VariableDeclaration(node) => { + if walk_variable_declaration_mut(v, node).is_stop() { + return VisitResult::Stop; + } + } + Declaration::ClassDeclaration(node) => { + if let Some(ref mut sc) = node.super_class { + if walk_expression_mut(v, sc).is_stop() { + return VisitResult::Stop; + } + } + } + _ => {} + } + VisitResult::Continue +} + +fn walk_export_default_decl_mut( + v: &mut impl MutVisitor, + decl: &mut ExportDefaultDecl, +) -> VisitResult { + match decl { + ExportDefaultDecl::FunctionDeclaration(node) => { + for s in node.body.body.iter_mut() { + if walk_statement_mut(v, s).is_stop() { + return VisitResult::Stop; + } + } + } + ExportDefaultDecl::Expression(expr) => { + if walk_expression_mut(v, expr).is_stop() { + return VisitResult::Stop; + } + } + ExportDefaultDecl::ClassDeclaration(node) => { + if let Some(ref mut sc) = node.super_class { + if walk_expression_mut(v, sc).is_stop() { + return VisitResult::Stop; + } + } + } + } + VisitResult::Continue +} diff --git a/crates/react_compiler_ast/tests/round_trip.rs b/crates/react_compiler_ast/tests/round_trip.rs new file mode 100644 index 000000000000..f9c6ecca7e46 --- /dev/null +++ b/crates/react_compiler_ast/tests/round_trip.rs @@ -0,0 +1,140 @@ +use std::path::PathBuf; + +fn get_fixture_json_dir() -> PathBuf { + if let Ok(dir) = std::env::var("FIXTURE_JSON_DIR") { + return PathBuf::from(dir); + } + // Default: fixtures checked in alongside the test + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures") +} + +/// Recursively sort all keys in a JSON value for order-independent comparison. 
+fn normalize_json(value: &serde_json::Value) -> serde_json::Value { + match value { + serde_json::Value::Object(map) => { + let mut sorted: Vec<(String, serde_json::Value)> = map + .iter() + .map(|(k, v)| (k.clone(), normalize_json(v))) + .collect(); + sorted.sort_by(|a, b| a.0.cmp(&b.0)); + serde_json::Value::Object(sorted.into_iter().collect()) + } + serde_json::Value::Array(arr) => { + serde_json::Value::Array(arr.iter().map(normalize_json).collect()) + } + // Normalize numbers: f64 values like 1.0 should compare equal to integer 1 + serde_json::Value::Number(n) => { + if let Some(f) = n.as_f64() { + if f.fract() == 0.0 && f.is_finite() && f.abs() < (i64::MAX as f64) { + serde_json::Value::Number(serde_json::Number::from(f as i64)) + } else { + value.clone() + } + } else { + value.clone() + } + } + other => other.clone(), + } +} + +fn compute_diff(original: &str, round_tripped: &str) -> String { + use similar::{ChangeTag, TextDiff}; + + let diff = TextDiff::from_lines(original, round_tripped); + let mut output = String::new(); + let mut lines_written = 0; + const MAX_DIFF_LINES: usize = 50; + + for change in diff.iter_all_changes() { + if lines_written >= MAX_DIFF_LINES { + output.push_str("... 
/// Deserialize every AST fixture into `react_compiler_ast::File`, serialize
/// it back, and require the normalized JSON to match the original.
///
/// `*.scope.json` and `*.renamed.json` files under the fixture directory are
/// not AST fixtures and are skipped. All failures are collected first; the
/// test then panics with a summary that shows at most five of them.
#[test]
fn round_trip_all_fixtures() {
    let json_dir = get_fixture_json_dir();

    // (fixture name, description of what went wrong)
    let mut failures: Vec<(String, String)> = Vec::new();
    let mut total = 0;
    let mut passed = 0;

    let fixtures = walkdir::WalkDir::new(&json_dir)
        .into_iter()
        .filter_map(|e| e.ok())
        .filter(|e| {
            let path = e.path();
            let is_json = path.extension().is_some_and(|ext| ext == "json");
            let is_scope = path.to_string_lossy().ends_with(".scope.json");
            let is_renamed = path.to_string_lossy().ends_with(".renamed.json");
            is_json && !is_scope && !is_renamed
        });

    for entry in fixtures {
        let path = entry.path();
        let fixture_name = path.strip_prefix(&json_dir).unwrap().display().to_string();
        let original_json = std::fs::read_to_string(path).unwrap();
        total += 1;

        // Deserialize into our Rust types
        let parsed: Result<react_compiler_ast::File, _> = serde_json::from_str(&original_json);
        let ast = match parsed {
            Err(e) => {
                failures.push((fixture_name, format!("Deserialization error: {e}")));
                continue;
            }
            Ok(ast) => ast,
        };

        // Re-serialize back to JSON, then compare both sides as normalized values
        let round_tripped = serde_json::to_string_pretty(&ast).unwrap();
        let original_value: serde_json::Value = serde_json::from_str(&original_json).unwrap();
        let round_tripped_value: serde_json::Value = serde_json::from_str(&round_tripped).unwrap();
        let original_normalized = normalize_json(&original_value);
        let round_tripped_normalized = normalize_json(&round_tripped_value);

        if original_normalized == round_tripped_normalized {
            passed += 1;
            continue;
        }

        let orig_str = serde_json::to_string_pretty(&original_normalized).unwrap();
        let rt_str = serde_json::to_string_pretty(&round_tripped_normalized).unwrap();
        failures.push((fixture_name, compute_diff(&orig_str, &rt_str)));
    }

    println!("\n{passed}/{total} fixtures passed round-trip");

    if failures.is_empty() {
        return;
    }

    let show_count = failures.len().min(5);
    let mut msg = format!(
        "\n{} of {total} fixtures failed round-trip (showing first {show_count}):\n\n",
        failures.len()
    );
    for (name, diff) in failures.iter().take(show_count) {
        msg.push_str(&format!("--- {name} ---\n{diff}\n\n"));
    }
    let hidden = failures.len() - show_count;
    if hidden > 0 {
        msg.push_str(&format!("... and {} more failures\n", hidden));
    }
    panic!("{msg}");
}
+fn normalize_json(value: &serde_json::Value) -> serde_json::Value { + match value { + serde_json::Value::Object(map) => { + let mut sorted: Vec<(String, serde_json::Value)> = map + .iter() + .map(|(k, v)| (k.clone(), normalize_json(v))) + .collect(); + sorted.sort_by(|a, b| a.0.cmp(&b.0)); + serde_json::Value::Object(sorted.into_iter().collect()) + } + serde_json::Value::Array(arr) => { + serde_json::Value::Array(arr.iter().map(normalize_json).collect()) + } + serde_json::Value::Number(n) => { + if let Some(f) = n.as_f64() { + if f.fract() == 0.0 && f.is_finite() && f.abs() < (i64::MAX as f64) { + serde_json::Value::Number(serde_json::Number::from(f as i64)) + } else { + value.clone() + } + } else { + value.clone() + } + } + other => other.clone(), + } +} + +fn compute_diff(original: &str, round_tripped: &str) -> String { + use similar::{ChangeTag, TextDiff}; + let diff = TextDiff::from_lines(original, round_tripped); + let mut output = String::new(); + let mut lines_written = 0; + const MAX_DIFF_LINES: usize = 50; + for change in diff.iter_all_changes() { + if lines_written >= MAX_DIFF_LINES { + output.push_str("... 
(diff truncated)\n"); + break; + } + let sign = match change.tag() { + ChangeTag::Delete => "-", + ChangeTag::Insert => "+", + ChangeTag::Equal => continue, + }; + output.push_str(&format!("{sign} {change}")); + lines_written += 1; + } + output +} + +#[test] +fn scope_info_round_trip() { + let json_dir = get_fixture_json_dir(); + let mut failures: Vec<(String, String)> = Vec::new(); + let mut total = 0; + let mut passed = 0; + let mut skipped = 0; + + for entry in walkdir::WalkDir::new(&json_dir) + .into_iter() + .filter_map(|e| e.ok()) + .filter(|e| { + e.path().extension().is_some_and(|ext| ext == "json") + && !e.path().to_string_lossy().contains(".scope.") + && !e.path().to_string_lossy().contains(".renamed.") + }) + { + let ast_path_str = entry.path().to_string_lossy().to_string(); + let scope_path_str = ast_path_str.replace(".json", ".scope.json"); + let scope_path = std::path::Path::new(&scope_path_str); + + if !scope_path.exists() { + skipped += 1; + continue; + } + + let fixture_name = entry + .path() + .strip_prefix(&json_dir) + .unwrap() + .display() + .to_string(); + total += 1; + + let scope_json = std::fs::read_to_string(scope_path).unwrap(); + + let scope_info: react_compiler_ast::scope::ScopeInfo = + match serde_json::from_str(&scope_json) { + Ok(info) => info, + Err(e) => { + failures.push((fixture_name, format!("Scope deserialization error: {e}"))); + continue; + } + }; + + let round_tripped = serde_json::to_string_pretty(&scope_info).unwrap(); + let original_value: serde_json::Value = serde_json::from_str(&scope_json).unwrap(); + let round_tripped_value: serde_json::Value = serde_json::from_str(&round_tripped).unwrap(); + + let original_normalized = normalize_json(&original_value); + let round_tripped_normalized = normalize_json(&round_tripped_value); + + if original_normalized != round_tripped_normalized { + let orig_str = serde_json::to_string_pretty(&original_normalized).unwrap(); + let rt_str = 
serde_json::to_string_pretty(&round_tripped_normalized).unwrap(); + let diff = compute_diff(&orig_str, &rt_str); + failures.push((fixture_name, format!("Round-trip mismatch:\n{diff}"))); + continue; + } + + let mut consistency_error = None; + + for binding in &scope_info.bindings { + if binding.scope.0 as usize >= scope_info.scopes.len() { + consistency_error = Some(format!( + "Binding {} has scope {} but only {} scopes exist", + binding.name, + binding.scope.0, + scope_info.scopes.len() + )); + break; + } + } + + if consistency_error.is_none() { + for scope in &scope_info.scopes { + for (name, &bid) in &scope.bindings { + if bid.0 as usize >= scope_info.bindings.len() { + consistency_error = Some(format!( + "Scope {} has binding '{}' with id {} but only {} bindings exist", + scope.id.0, + name, + bid.0, + scope_info.bindings.len() + )); + break; + } + } + if consistency_error.is_some() { + break; + } + if let Some(parent) = scope.parent { + if parent.0 as usize >= scope_info.scopes.len() { + consistency_error = Some(format!( + "Scope {} has parent {} but only {} scopes exist", + scope.id.0, + parent.0, + scope_info.scopes.len() + )); + break; + } + } + } + } + + if consistency_error.is_none() { + for (&_offset, &bid) in &scope_info.reference_to_binding { + if bid.0 as usize >= scope_info.bindings.len() { + consistency_error = Some(format!( + "reference_to_binding has binding id {} but only {} bindings exist", + bid.0, + scope_info.bindings.len() + )); + break; + } + } + } + + if consistency_error.is_none() { + for (&_offset, &sid) in &scope_info.node_to_scope { + if sid.0 as usize >= scope_info.scopes.len() { + consistency_error = Some(format!( + "node_to_scope has scope id {} but only {} scopes exist", + sid.0, + scope_info.scopes.len() + )); + break; + } + } + } + + if let Some(err) = consistency_error { + failures.push((fixture_name, format!("Consistency error: {err}"))); + continue; + } + + passed += 1; + } + + println!( + "\n{passed}/{total} fixtures passed 
scope info round-trip ({skipped} skipped - no \ + scope.json)" + ); + + if !failures.is_empty() { + let show_count = failures.len().min(5); + let mut msg = format!( + "\n{} of {total} fixtures failed scope info test (showing first {show_count}):\n\n", + failures.len() + ); + for (name, err) in failures.iter().take(show_count) { + msg.push_str(&format!("--- {name} ---\n{err}\n\n")); + } + if failures.len() > show_count { + msg.push_str(&format!( + "... and {} more failures\n", + failures.len() - show_count + )); + } + panic!("{msg}"); + } +} + +// ============================================================================ +// Typed AST traversal for identifier renaming +// ============================================================================ + +/// Rename an Identifier if it has a binding in reference_to_binding. +/// Uses the declaring scope from the binding table — no scope stack needed. +fn rename_id(id: &mut Identifier, si: &ScopeInfo) { + if let Some(start) = id.base.start { + if let Some(&bid) = si.reference_to_binding.get(&start) { + let scope = si.bindings[bid.0 as usize].scope.0; + id.name = format!("{}_{}", id.name, format_args!("{scope}_{}", bid.0)); + } + } + visit_json_opt(&mut id.type_annotation, si); + if let Some(decorators) = &mut id.decorators { + visit_json_vec(decorators, si); + } +} + +/// Fallback walker for serde_json::Value fields (class bodies, type +/// annotations, decorators, etc.) 
+fn visit_json(val: &mut serde_json::Value, si: &ScopeInfo) { + match val { + serde_json::Value::Object(map) => { + if map.get("type").and_then(|v| v.as_str()) == Some("Identifier") { + if let Some(start) = map.get("start").and_then(|v| v.as_u64()) { + if let Some(&bid) = si.reference_to_binding.get(&(start as u32)) { + let scope = si.bindings[bid.0 as usize].scope.0; + if let Some(name) = map + .get("name") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + { + map.insert( + "name".to_string(), + serde_json::Value::String(format!("{name}_{scope}_{}", bid.0)), + ); + } + } + } + } + let keys: Vec = map.keys().cloned().collect(); + for key in keys { + if let Some(child) = map.get_mut(&key) { + visit_json(child, si); + } + } + } + serde_json::Value::Array(arr) => { + for item in arr.iter_mut() { + visit_json(item, si); + } + } + _ => {} + } +} + +fn visit_json_vec(vals: &mut [serde_json::Value], si: &ScopeInfo) { + for val in vals.iter_mut() { + visit_json(val, si); + } +} + +fn visit_json_opt(val: &mut Option>, si: &ScopeInfo) { + if let Some(v) = val { + visit_json(v, si); + } +} + +fn rename_identifiers(file: &mut react_compiler_ast::File, si: &ScopeInfo) { + visit_program(&mut file.program, si); +} + +fn visit_program(prog: &mut react_compiler_ast::Program, si: &ScopeInfo) { + for stmt in &mut prog.body { + visit_stmt(stmt, si); + } +} + +fn visit_block(block: &mut BlockStatement, si: &ScopeInfo) { + for stmt in &mut block.body { + visit_stmt(stmt, si); + } +} + +fn visit_stmt(stmt: &mut Statement, si: &ScopeInfo) { + match stmt { + Statement::BlockStatement(s) => visit_block(s, si), + Statement::ReturnStatement(s) => { + if let Some(arg) = &mut s.argument { + visit_expr(arg, si); + } + } + Statement::ExpressionStatement(s) => visit_expr(&mut s.expression, si), + Statement::IfStatement(s) => { + visit_expr(&mut s.test, si); + visit_stmt(&mut s.consequent, si); + if let Some(alt) = &mut s.alternate { + visit_stmt(alt, si); + } + } + Statement::ForStatement(s) 
=> { + if let Some(init) = &mut s.init { + match init.as_mut() { + ForInit::VariableDeclaration(d) => visit_var_decl(d, si), + ForInit::Expression(e) => visit_expr(e, si), + } + } + if let Some(test) = &mut s.test { + visit_expr(test, si); + } + if let Some(update) = &mut s.update { + visit_expr(update, si); + } + visit_stmt(&mut s.body, si); + } + Statement::WhileStatement(s) => { + visit_expr(&mut s.test, si); + visit_stmt(&mut s.body, si); + } + Statement::DoWhileStatement(s) => { + visit_stmt(&mut s.body, si); + visit_expr(&mut s.test, si); + } + Statement::ForInStatement(s) => { + visit_for_left(&mut s.left, si); + visit_expr(&mut s.right, si); + visit_stmt(&mut s.body, si); + } + Statement::ForOfStatement(s) => { + visit_for_left(&mut s.left, si); + visit_expr(&mut s.right, si); + visit_stmt(&mut s.body, si); + } + Statement::SwitchStatement(s) => { + visit_expr(&mut s.discriminant, si); + for case in &mut s.cases { + if let Some(test) = &mut case.test { + visit_expr(test, si); + } + for child in &mut case.consequent { + visit_stmt(child, si); + } + } + } + Statement::ThrowStatement(s) => visit_expr(&mut s.argument, si), + Statement::TryStatement(s) => { + visit_block(&mut s.block, si); + if let Some(handler) = &mut s.handler { + if let Some(param) = &mut handler.param { + visit_pat(param, si); + } + visit_block(&mut handler.body, si); + } + if let Some(fin) = &mut s.finalizer { + visit_block(fin, si); + } + } + Statement::LabeledStatement(s) => visit_stmt(&mut s.body, si), + Statement::WithStatement(s) => { + visit_expr(&mut s.object, si); + visit_stmt(&mut s.body, si); + } + Statement::VariableDeclaration(d) => visit_var_decl(d, si), + Statement::FunctionDeclaration(f) => visit_func_decl(f, si), + Statement::ClassDeclaration(c) => visit_class_decl(c, si), + Statement::ImportDeclaration(d) => visit_import_decl(d, si), + Statement::ExportNamedDeclaration(d) => visit_export_named(d, si), + Statement::ExportDefaultDeclaration(d) => visit_export_default(d, si), 
+ Statement::TSTypeAliasDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.type_annotation, si); + visit_json_opt(&mut d.type_parameters, si); + } + Statement::TSInterfaceDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.body, si); + visit_json_opt(&mut d.type_parameters, si); + if let Some(ext) = &mut d.extends { + visit_json_vec(ext, si); + } + } + Statement::TSEnumDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json_vec(&mut d.members, si); + } + Statement::TSModuleDeclaration(d) => { + visit_json(&mut d.id, si); + visit_json(&mut d.body, si); + } + Statement::TSDeclareFunction(d) => { + if let Some(id) = &mut d.id { + rename_id(id, si); + } + visit_json_vec(&mut d.params, si); + visit_json_opt(&mut d.return_type, si); + visit_json_opt(&mut d.type_parameters, si); + } + Statement::TypeAlias(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.right, si); + visit_json_opt(&mut d.type_parameters, si); + } + Statement::OpaqueType(d) => { + rename_id(&mut d.id, si); + if let Some(st) = &mut d.supertype { + visit_json(st, si); + } + visit_json(&mut d.impltype, si); + visit_json_opt(&mut d.type_parameters, si); + } + Statement::InterfaceDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.body, si); + visit_json_opt(&mut d.type_parameters, si); + if let Some(ext) = &mut d.extends { + visit_json_vec(ext, si); + } + } + Statement::DeclareVariable(d) => rename_id(&mut d.id, si), + Statement::DeclareFunction(d) => { + rename_id(&mut d.id, si); + if let Some(pred) = &mut d.predicate { + visit_json(pred, si); + } + } + Statement::DeclareClass(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.body, si); + visit_json_opt(&mut d.type_parameters, si); + if let Some(ext) = &mut d.extends { + visit_json_vec(ext, si); + } + } + Statement::DeclareModule(d) => { + visit_json(&mut d.id, si); + visit_json(&mut d.body, si); + } + Statement::DeclareModuleExports(d) => visit_json(&mut d.type_annotation, si), + 
Statement::DeclareExportDeclaration(d) => { + if let Some(decl) = &mut d.declaration { + visit_json(decl, si); + } + if let Some(specs) = &mut d.specifiers { + visit_json_vec(specs, si); + } + } + Statement::DeclareInterface(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.body, si); + visit_json_opt(&mut d.type_parameters, si); + if let Some(ext) = &mut d.extends { + visit_json_vec(ext, si); + } + } + Statement::DeclareTypeAlias(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.right, si); + visit_json_opt(&mut d.type_parameters, si); + } + Statement::DeclareOpaqueType(d) => { + rename_id(&mut d.id, si); + if let Some(st) = &mut d.supertype { + visit_json(st, si); + } + if let Some(impl_) = &mut d.impltype { + visit_json(impl_, si); + } + visit_json_opt(&mut d.type_parameters, si); + } + Statement::EnumDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.body, si); + } + Statement::BreakStatement(_) + | Statement::ContinueStatement(_) + | Statement::EmptyStatement(_) + | Statement::DebuggerStatement(_) + | Statement::ExportAllDeclaration(_) + | Statement::DeclareExportAllDeclaration(_) => {} + } +} + +fn visit_expr(expr: &mut Expression, si: &ScopeInfo) { + match expr { + Expression::Identifier(id) => rename_id(id, si), + Expression::CallExpression(e) => { + visit_expr(&mut e.callee, si); + for arg in &mut e.arguments { + visit_expr(arg, si); + } + visit_json_opt(&mut e.type_parameters, si); + visit_json_opt(&mut e.type_arguments, si); + } + Expression::MemberExpression(e) => { + visit_expr(&mut e.object, si); + visit_expr(&mut e.property, si); + } + Expression::OptionalCallExpression(e) => { + visit_expr(&mut e.callee, si); + for arg in &mut e.arguments { + visit_expr(arg, si); + } + visit_json_opt(&mut e.type_parameters, si); + visit_json_opt(&mut e.type_arguments, si); + } + Expression::OptionalMemberExpression(e) => { + visit_expr(&mut e.object, si); + visit_expr(&mut e.property, si); + } + Expression::BinaryExpression(e) => { + 
visit_expr(&mut e.left, si); + visit_expr(&mut e.right, si); + } + Expression::LogicalExpression(e) => { + visit_expr(&mut e.left, si); + visit_expr(&mut e.right, si); + } + Expression::UnaryExpression(e) => visit_expr(&mut e.argument, si), + Expression::UpdateExpression(e) => visit_expr(&mut e.argument, si), + Expression::ConditionalExpression(e) => { + visit_expr(&mut e.test, si); + visit_expr(&mut e.consequent, si); + visit_expr(&mut e.alternate, si); + } + Expression::AssignmentExpression(e) => { + visit_pat(&mut e.left, si); + visit_expr(&mut e.right, si); + } + Expression::SequenceExpression(e) => { + for child in &mut e.expressions { + visit_expr(child, si); + } + } + Expression::ArrowFunctionExpression(e) => { + if let Some(id) = &mut e.id { + rename_id(id, si); + } + for param in &mut e.params { + visit_pat(param, si); + } + match e.body.as_mut() { + ArrowFunctionBody::BlockStatement(block) => visit_block(block, si), + ArrowFunctionBody::Expression(expr) => visit_expr(expr, si), + } + visit_json_opt(&mut e.return_type, si); + visit_json_opt(&mut e.type_parameters, si); + visit_json_opt(&mut e.predicate, si); + } + Expression::FunctionExpression(e) => { + if let Some(id) = &mut e.id { + rename_id(id, si); + } + for param in &mut e.params { + visit_pat(param, si); + } + visit_block(&mut e.body, si); + visit_json_opt(&mut e.return_type, si); + visit_json_opt(&mut e.type_parameters, si); + } + Expression::ObjectExpression(e) => { + for prop in &mut e.properties { + match prop { + ObjectExpressionProperty::ObjectProperty(p) => { + visit_expr(&mut p.key, si); + visit_expr(&mut p.value, si); + } + ObjectExpressionProperty::ObjectMethod(m) => { + visit_expr(&mut m.key, si); + for param in &mut m.params { + visit_pat(param, si); + } + visit_block(&mut m.body, si); + visit_json_opt(&mut m.return_type, si); + visit_json_opt(&mut m.type_parameters, si); + } + ObjectExpressionProperty::SpreadElement(s) => visit_expr(&mut s.argument, si), + } + } + } + 
Expression::ArrayExpression(e) => { + for el in e.elements.iter_mut().flatten() { + visit_expr(el, si); + } + } + Expression::NewExpression(e) => { + visit_expr(&mut e.callee, si); + for arg in &mut e.arguments { + visit_expr(arg, si); + } + visit_json_opt(&mut e.type_parameters, si); + visit_json_opt(&mut e.type_arguments, si); + } + Expression::TemplateLiteral(e) => { + for child in &mut e.expressions { + visit_expr(child, si); + } + } + Expression::TaggedTemplateExpression(e) => { + visit_expr(&mut e.tag, si); + for child in &mut e.quasi.expressions { + visit_expr(child, si); + } + visit_json_opt(&mut e.type_parameters, si); + } + Expression::AwaitExpression(e) => visit_expr(&mut e.argument, si), + Expression::YieldExpression(e) => { + if let Some(arg) = &mut e.argument { + visit_expr(arg, si); + } + } + Expression::SpreadElement(e) => visit_expr(&mut e.argument, si), + Expression::MetaProperty(e) => { + rename_id(&mut e.meta, si); + rename_id(&mut e.property, si); + } + Expression::ClassExpression(e) => { + if let Some(id) = &mut e.id { + rename_id(id, si); + } + if let Some(sc) = &mut e.super_class { + visit_expr(sc, si); + } + visit_json_vec(&mut e.body.body, si); + if let Some(dec) = &mut e.decorators { + visit_json_vec(dec, si); + } + visit_json_opt(&mut e.super_type_parameters, si); + visit_json_opt(&mut e.type_parameters, si); + if let Some(imp) = &mut e.implements { + visit_json_vec(imp, si); + } + } + Expression::PrivateName(e) => rename_id(&mut e.id, si), + Expression::ParenthesizedExpression(e) => visit_expr(&mut e.expression, si), + Expression::AssignmentPattern(p) => { + visit_pat(&mut p.left, si); + visit_expr(&mut p.right, si); + } + Expression::TSAsExpression(e) => { + visit_expr(&mut e.expression, si); + visit_json(&mut e.type_annotation, si); + } + Expression::TSSatisfiesExpression(e) => { + visit_expr(&mut e.expression, si); + visit_json(&mut e.type_annotation, si); + } + Expression::TSNonNullExpression(e) => visit_expr(&mut e.expression, si), 
+ Expression::TSTypeAssertion(e) => { + visit_expr(&mut e.expression, si); + visit_json(&mut e.type_annotation, si); + } + Expression::TSInstantiationExpression(e) => { + visit_expr(&mut e.expression, si); + visit_json(&mut e.type_parameters, si); + } + Expression::TypeCastExpression(e) => { + visit_expr(&mut e.expression, si); + visit_json(&mut e.type_annotation, si); + } + Expression::JSXElement(e) => visit_jsx_element(e, si), + Expression::JSXFragment(f) => { + for child in &mut f.children { + visit_jsx_child(child, si); + } + } + Expression::StringLiteral(_) + | Expression::NumericLiteral(_) + | Expression::BooleanLiteral(_) + | Expression::NullLiteral(_) + | Expression::BigIntLiteral(_) + | Expression::RegExpLiteral(_) + | Expression::Super(_) + | Expression::Import(_) + | Expression::ThisExpression(_) => {} + } +} + +fn visit_pat(pat: &mut PatternLike, si: &ScopeInfo) { + match pat { + PatternLike::Identifier(id) => rename_id(id, si), + PatternLike::ObjectPattern(op) => { + for prop in &mut op.properties { + match prop { + ObjectPatternProperty::ObjectProperty(pp) => { + visit_expr(&mut pp.key, si); + visit_pat(&mut pp.value, si); + } + ObjectPatternProperty::RestElement(r) => { + visit_pat(&mut r.argument, si); + visit_json_opt(&mut r.type_annotation, si); + } + } + } + visit_json_opt(&mut op.type_annotation, si); + } + PatternLike::ArrayPattern(ap) => { + for el in ap.elements.iter_mut().flatten() { + visit_pat(el, si); + } + visit_json_opt(&mut ap.type_annotation, si); + } + PatternLike::AssignmentPattern(ap) => { + visit_pat(&mut ap.left, si); + visit_expr(&mut ap.right, si); + visit_json_opt(&mut ap.type_annotation, si); + } + PatternLike::RestElement(re) => { + visit_pat(&mut re.argument, si); + visit_json_opt(&mut re.type_annotation, si); + } + PatternLike::MemberExpression(e) => { + visit_expr(&mut e.object, si); + visit_expr(&mut e.property, si); + } + } +} + +fn visit_for_left(left: &mut Box, si: &ScopeInfo) { + match left.as_mut() { + 
ForInOfLeft::VariableDeclaration(d) => visit_var_decl(d, si), + ForInOfLeft::Pattern(p) => visit_pat(p, si), + } +} + +fn visit_var_decl(d: &mut VariableDeclaration, si: &ScopeInfo) { + for decl in &mut d.declarations { + visit_pat(&mut decl.id, si); + if let Some(init) = &mut decl.init { + visit_expr(init, si); + } + } +} + +fn visit_func_decl(f: &mut FunctionDeclaration, si: &ScopeInfo) { + if let Some(id) = &mut f.id { + rename_id(id, si); + } + for param in &mut f.params { + visit_pat(param, si); + } + visit_block(&mut f.body, si); + visit_json_opt(&mut f.return_type, si); + visit_json_opt(&mut f.type_parameters, si); + visit_json_opt(&mut f.predicate, si); +} + +fn visit_class_decl(c: &mut ClassDeclaration, si: &ScopeInfo) { + if let Some(id) = &mut c.id { + rename_id(id, si); + } + if let Some(sc) = &mut c.super_class { + visit_expr(sc, si); + } + visit_json_vec(&mut c.body.body, si); + if let Some(dec) = &mut c.decorators { + visit_json_vec(dec, si); + } + visit_json_opt(&mut c.super_type_parameters, si); + visit_json_opt(&mut c.type_parameters, si); + if let Some(imp) = &mut c.implements { + visit_json_vec(imp, si); + } +} + +fn visit_import_decl(d: &mut ImportDeclaration, si: &ScopeInfo) { + for spec in &mut d.specifiers { + match spec { + ImportSpecifier::ImportSpecifier(s) => { + rename_id(&mut s.local, si); + visit_module_export_name(&mut s.imported, si); + } + ImportSpecifier::ImportDefaultSpecifier(s) => rename_id(&mut s.local, si), + ImportSpecifier::ImportNamespaceSpecifier(s) => rename_id(&mut s.local, si), + } + } +} + +fn visit_export_named(d: &mut ExportNamedDeclaration, si: &ScopeInfo) { + if let Some(decl) = &mut d.declaration { + visit_declaration(decl, si); + } + for spec in &mut d.specifiers { + match spec { + ExportSpecifier::ExportSpecifier(s) => { + visit_module_export_name(&mut s.local, si); + visit_module_export_name(&mut s.exported, si); + } + ExportSpecifier::ExportDefaultSpecifier(s) => rename_id(&mut s.exported, si), + 
ExportSpecifier::ExportNamespaceSpecifier(s) => { + visit_module_export_name(&mut s.exported, si); + } + } + } +} + +fn visit_export_default(d: &mut ExportDefaultDeclaration, si: &ScopeInfo) { + match d.declaration.as_mut() { + ExportDefaultDecl::FunctionDeclaration(f) => visit_func_decl(f, si), + ExportDefaultDecl::ClassDeclaration(c) => visit_class_decl(c, si), + ExportDefaultDecl::Expression(e) => visit_expr(e, si), + } +} + +fn visit_declaration(d: &mut Declaration, si: &ScopeInfo) { + match d { + Declaration::FunctionDeclaration(f) => visit_func_decl(f, si), + Declaration::ClassDeclaration(c) => visit_class_decl(c, si), + Declaration::VariableDeclaration(v) => visit_var_decl(v, si), + Declaration::TSTypeAliasDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.type_annotation, si); + visit_json_opt(&mut d.type_parameters, si); + } + Declaration::TSInterfaceDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.body, si); + visit_json_opt(&mut d.type_parameters, si); + if let Some(ext) = &mut d.extends { + visit_json_vec(ext, si); + } + } + Declaration::TSEnumDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json_vec(&mut d.members, si); + } + Declaration::TSModuleDeclaration(d) => { + visit_json(&mut d.id, si); + visit_json(&mut d.body, si); + } + Declaration::TSDeclareFunction(d) => { + if let Some(id) = &mut d.id { + rename_id(id, si); + } + visit_json_vec(&mut d.params, si); + visit_json_opt(&mut d.return_type, si); + visit_json_opt(&mut d.type_parameters, si); + } + Declaration::TypeAlias(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.right, si); + visit_json_opt(&mut d.type_parameters, si); + } + Declaration::OpaqueType(d) => { + rename_id(&mut d.id, si); + if let Some(st) = &mut d.supertype { + visit_json(st, si); + } + visit_json(&mut d.impltype, si); + visit_json_opt(&mut d.type_parameters, si); + } + Declaration::InterfaceDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.body, si); + 
visit_json_opt(&mut d.type_parameters, si); + if let Some(ext) = &mut d.extends { + visit_json_vec(ext, si); + } + } + Declaration::EnumDeclaration(d) => { + rename_id(&mut d.id, si); + visit_json(&mut d.body, si); + } + } +} + +fn visit_module_export_name(n: &mut ModuleExportName, si: &ScopeInfo) { + match n { + ModuleExportName::Identifier(id) => rename_id(id, si), + ModuleExportName::StringLiteral(_) => {} + } +} + +fn visit_jsx_element(el: &mut JSXElement, si: &ScopeInfo) { + for attr in &mut el.opening_element.attributes { + match attr { + JSXAttributeItem::JSXAttribute(a) => { + if let Some(val) = &mut a.value { + match val { + JSXAttributeValue::JSXExpressionContainer(c) => { + visit_jsx_expr(&mut c.expression, si); + } + JSXAttributeValue::JSXElement(e) => visit_jsx_element(e, si), + JSXAttributeValue::JSXFragment(f) => { + for child in &mut f.children { + visit_jsx_child(child, si); + } + } + JSXAttributeValue::StringLiteral(_) => {} + } + } + } + JSXAttributeItem::JSXSpreadAttribute(s) => visit_expr(&mut s.argument, si), + } + } + visit_json_opt(&mut el.opening_element.type_parameters, si); + for child in &mut el.children { + visit_jsx_child(child, si); + } +} + +fn visit_jsx_child(child: &mut JSXChild, si: &ScopeInfo) { + match child { + JSXChild::JSXElement(e) => visit_jsx_element(e, si), + JSXChild::JSXFragment(f) => { + for child in &mut f.children { + visit_jsx_child(child, si); + } + } + JSXChild::JSXExpressionContainer(c) => visit_jsx_expr(&mut c.expression, si), + JSXChild::JSXSpreadChild(s) => visit_expr(&mut s.expression, si), + JSXChild::JSXText(_) => {} + } +} + +fn visit_jsx_expr(expr: &mut JSXExpressionContainerExpr, si: &ScopeInfo) { + match expr { + JSXExpressionContainerExpr::Expression(e) => visit_expr(e, si), + JSXExpressionContainerExpr::JSXEmptyExpression(_) => {} + } +} + +#[test] +fn scope_resolution_rename() { + let json_dir = get_fixture_json_dir(); + let mut failures: Vec<(String, String)> = Vec::new(); + let mut total = 0; + let 
mut passed = 0; + let mut skipped = 0; + + for entry in walkdir::WalkDir::new(&json_dir) + .into_iter() + .filter_map(|e| e.ok()) + .filter(|e| { + e.path().extension().is_some_and(|ext| ext == "json") + && !e.path().to_string_lossy().contains(".scope.") + && !e.path().to_string_lossy().contains(".renamed.") + }) + { + let ast_path_str = entry.path().to_string_lossy().to_string(); + let scope_path_str = ast_path_str.replace(".json", ".scope.json"); + let renamed_path_str = ast_path_str.replace(".json", ".renamed.json"); + let scope_path = std::path::Path::new(&scope_path_str); + let renamed_path = std::path::Path::new(&renamed_path_str); + + if !scope_path.exists() || !renamed_path.exists() { + skipped += 1; + continue; + } + + let fixture_name = entry + .path() + .strip_prefix(&json_dir) + .unwrap() + .display() + .to_string(); + total += 1; + + let ast_json = std::fs::read_to_string(entry.path()).unwrap(); + let scope_json = std::fs::read_to_string(scope_path).unwrap(); + let babel_renamed_json = std::fs::read_to_string(renamed_path).unwrap(); + + let scope_info: react_compiler_ast::scope::ScopeInfo = + match serde_json::from_str(&scope_json) { + Ok(info) => info, + Err(e) => { + failures.push((fixture_name, format!("Scope deserialization error: {e}"))); + continue; + } + }; + + // Deserialize into typed AST, rename using scope info, re-serialize + let mut file: react_compiler_ast::File = match serde_json::from_str(&ast_json) { + Ok(f) => f, + Err(e) => { + failures.push((fixture_name, format!("AST deserialization error: {e}"))); + continue; + } + }; + rename_identifiers(&mut file, &scope_info); + let rust_renamed = serde_json::to_value(&file).unwrap(); + + let babel_renamed_value: serde_json::Value = + serde_json::from_str(&babel_renamed_json).unwrap(); + + let rust_normalized = normalize_json(&rust_renamed); + let babel_normalized = normalize_json(&babel_renamed_value); + + if rust_normalized != babel_normalized { + let rust_str = 
serde_json::to_string_pretty(&rust_normalized).unwrap(); + let babel_str = serde_json::to_string_pretty(&babel_normalized).unwrap(); + let diff = compute_diff(&babel_str, &rust_str); + failures.push((fixture_name, format!("Rename mismatch:\n{diff}"))); + } else { + passed += 1; + } + } + + println!("\n{passed}/{total} fixtures passed scope resolution rename ({skipped} skipped)"); + + if !failures.is_empty() { + let show_count = failures.len().min(5); + let mut msg = format!( + "\n{} of {total} fixtures failed scope resolution rename (showing first \ + {show_count}):\n\n", + failures.len() + ); + for (name, err) in failures.iter().take(show_count) { + msg.push_str(&format!("--- {name} ---\n{err}\n\n")); + } + if failures.len() > show_count { + msg.push_str(&format!( + "... and {} more failures\n", + failures.len() - show_count + )); + } + panic!("{msg}"); + } +} diff --git a/crates/react_compiler_diagnostics/Cargo.toml b/crates/react_compiler_diagnostics/Cargo.toml new file mode 100644 index 000000000000..3782506aeb27 --- /dev/null +++ b/crates/react_compiler_diagnostics/Cargo.toml @@ -0,0 +1,10 @@ +[package] +description = "Vendored React Compiler diagnostics from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_diagnostics" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +serde = { workspace = true, features = ["derive"] } diff --git a/crates/react_compiler_diagnostics/src/code_frame.rs b/crates/react_compiler_diagnostics/src/code_frame.rs new file mode 100644 index 000000000000..1bfca252f410 --- /dev/null +++ b/crates/react_compiler_diagnostics/src/code_frame.rs @@ -0,0 +1,451 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +use crate::{CompilerDiagnosticDetail, CompilerError, CompilerErrorOrDiagnostic}; + +const CODEFRAME_LINES_ABOVE: u32 = 2; +const CODEFRAME_LINES_BELOW: u32 = 3; +const CODEFRAME_MAX_LINES: u32 = 10; +const CODEFRAME_ABBREVIATED_SOURCE_LINES: usize = 5; + +/// Split source text on newlines, matching Babel's NEWLINE regex: +/// /\r\n|[\n\r\u2028\u2029]/ +fn split_lines(source: &str) -> Vec<&str> { + let mut lines = Vec::new(); + let mut start = 0; + let bytes = source.as_bytes(); + let len = bytes.len(); + let mut i = 0; + while i < len { + let ch = bytes[i]; + if ch == b'\r' { + lines.push(&source[start..i]); + if i + 1 < len && bytes[i + 1] == b'\n' { + i += 2; + } else { + i += 1; + } + start = i; + } else if ch == b'\n' { + lines.push(&source[start..i]); + i += 1; + start = i; + } else { + // Check for Unicode line separators U+2028 and U+2029 + // These are encoded as E2 80 A8 and E2 80 A9 in UTF-8 + if ch == 0xe2 + && i + 2 < len + && bytes[i + 1] == 0x80 + && (bytes[i + 2] == 0xa8 || bytes[i + 2] == 0xa9) + { + lines.push(&source[start..i]); + i += 3; + start = i; + } else { + i += 1; + } + } + } + lines.push(&source[start..]); + lines +} + +/// Represents a marker line entry: either mark the whole line (true) or a +/// [column, length] range. +#[derive(Clone, Debug)] +enum MarkerEntry { + WholeLine, + Range(usize, usize), // (start_column_1based, length) +} + +/// Compute marker lines matching Babel's getMarkerLines(). +/// All column values here are 1-based (Babel convention). 
+fn get_marker_lines( + start_line: u32, + start_column: u32, // 1-based + end_line: u32, + end_column: u32, // 1-based + source_line_count: usize, + lines_above: u32, + lines_below: u32, +) -> (usize, usize, Vec<(usize, MarkerEntry)>) { + let start_line = start_line as usize; + let end_line = end_line as usize; + let start_column = start_column as usize; + let end_column = end_column as usize; + + // Compute display range + let start = start_line.saturating_sub(lines_above as usize + 1); + let end = std::cmp::min(source_line_count, end_line + lines_below as usize); + + let line_diff = end_line - start_line; + let mut marker_lines: Vec<(usize, MarkerEntry)> = Vec::new(); + + if line_diff > 0 { + // Multi-line error + for i in 0..=line_diff { + let line_number = i + start_line; + if start_column == 0 { + marker_lines.push((line_number, MarkerEntry::WholeLine)); + } else if i == 0 { + // First line: from start_column to end of source line + // source[lineNumber - 1] gives us the source line (0-indexed array, 1-indexed + // line numbers) But we don't have access to source lines here, + // so we pass the length through. Actually, Babel accesses + // source[lineNumber - 1].length. We need to thread source lines. + // For now, this is handled in code_frame_columns where we have access to source + // lines. We use a placeholder that will be filled in later. 
+ marker_lines.push((line_number, MarkerEntry::Range(start_column, 0))); + // 0 = placeholder + } else if i == line_diff { + marker_lines.push((line_number, MarkerEntry::Range(0, end_column))); + } else { + marker_lines.push((line_number, MarkerEntry::Range(0, 0))); // 0 + // = + // placeholder + // for + // full + // line + } + } + } else { + // Single-line error + if start_column == end_column { + if start_column != 0 { + marker_lines.push((start_line, MarkerEntry::Range(start_column, 0))); + } else { + marker_lines.push((start_line, MarkerEntry::WholeLine)); + } + } else { + marker_lines.push(( + start_line, + MarkerEntry::Range(start_column, end_column - start_column), + )); + } + } + + (start, end, marker_lines) +} + +/// Produce a code frame matching @babel/code-frame's codeFrameColumns() in +/// non-highlighted mode. +/// +/// Columns are 0-based (matching the Rust/AST convention). They are converted +/// to 1-based internally to match Babel's convention (the JS caller already +/// does column + 1). 
+pub fn code_frame_columns( + source: &str, + start_line: u32, + start_col: u32, + end_line: u32, + end_col: u32, + message: &str, +) -> String { + // Convert 0-based columns to 1-based (Babel convention) + let start_column_1 = start_col + 1; + let end_column_1 = end_col + 1; + + let lines = split_lines(source); + let source_line_count = lines.len(); + + let (start, end, marker_lines_raw) = get_marker_lines( + start_line, + start_column_1, + end_line, + end_column_1, + source_line_count, + CODEFRAME_LINES_ABOVE, + CODEFRAME_LINES_BELOW, + ); + + let has_columns = start_column_1 > 0; + let number_max_width = format!("{end}").len(); + + // Build a lookup map for marker lines + let mut marker_map: std::collections::HashMap = + std::collections::HashMap::new(); + let line_diff = end_line as usize - start_line as usize; + for (line_number, entry) in marker_lines_raw { + // Resolve placeholder lengths using actual source lines + let resolved = match &entry { + MarkerEntry::Range(col, len) => { + if line_diff > 0 { + let i = line_number - start_line as usize; + if i == 0 && *len == 0 { + // First line of multi-line: from start_column to end of line + let source_length = if line_number >= 1 && line_number <= lines.len() { + lines[line_number - 1].len() + } else { + 0 + }; + MarkerEntry::Range(*col, source_length.saturating_sub(*col) + 1) + } else if i > 0 && i < line_diff && *col == 0 && *len == 0 { + // Middle line of multi-line: Babel uses source[lineNumber - i].length + // which evaluates to source[startLine] (0-indexed array, 1-indexed line + // number). This means all middle lines use the + // length of source[startLine], which is the line at + // 0-indexed position startLine in the source array. 
+ let source_length = if (start_line as usize) < lines.len() { + lines[start_line as usize].len() + } else { + 0 + }; + MarkerEntry::Range(0, source_length) + } else { + entry + } + } else { + entry + } + } + _ => entry, + }; + marker_map.insert(line_number, resolved); + } + + // Build frame lines + let mut frame_parts: Vec = Vec::new(); + let display_lines = &lines[start..end]; + + for (index, line) in display_lines.iter().enumerate() { + let number = start + 1 + index; + // Right-align the line number: ` ${number}`.slice(-numberMaxWidth) + let number_str = format!("{number}"); + let padded_number = if number_str.len() >= number_max_width { + number_str + } else { + let padding = " ".repeat(number_max_width - number_str.len()); + format!("{padding}{number_str}") + }; + let gutter = format!(" {padded_number} |"); + + let has_marker = marker_map.get(&number); + let has_next_marker = marker_map.contains_key(&(number + 1)); + let last_marker_line = has_marker.is_some() && !has_next_marker; + + if let Some(marker_entry) = has_marker { + // This is a marked line + let line_content = if line.is_empty() { + String::new() + } else { + format!(" {line}") + }; + + let marker_line_str = match marker_entry { + MarkerEntry::Range(col, len) => { + // Build marker spacing: replace non-tab chars with spaces + let max_col = if *col > 0 { col - 1 } else { 0 }; + let byte_end = std::cmp::min(max_col, line.len()); + // Ensure we don't slice in the middle of a multi-byte UTF-8 character + let safe_end = if byte_end < line.len() && !line.is_char_boundary(byte_end) { + let mut safe_end = byte_end; + while safe_end > 0 && !line.is_char_boundary(safe_end) { + safe_end -= 1; + } + safe_end + } else { + byte_end + }; + let prefix = &line[..safe_end]; + let marker_spacing: String = prefix + .chars() + .map(|c| if c == '\t' { '\t' } else { ' ' }) + .collect(); + let number_of_markers = if *len == 0 { 1 } else { *len }; + let carets = "^".repeat(number_of_markers); + let gutter_spaces = 
gutter.replace(|c: char| c.is_ascii_digit(), " "); + let mut marker_str = format!("\n {gutter_spaces} {marker_spacing}{carets}"); + if last_marker_line && !message.is_empty() { + marker_str.push(' '); + marker_str.push_str(message); + } + marker_str + } + MarkerEntry::WholeLine => String::new(), + }; + + frame_parts.push(format!(">{gutter}{line_content}{marker_line_str}")); + } else { + // Non-marked line + let line_content = if line.is_empty() { + String::new() + } else { + format!(" {line}") + }; + frame_parts.push(format!(" {gutter}{line_content}")); + } + } + + let mut frame = frame_parts.join("\n"); + + // If message is set but no columns, prepend the message + if !message.is_empty() && !has_columns { + frame = format!("{}{}\n{}", " ".repeat(number_max_width + 1), message, frame); + } + + frame +} + +/// Format a code frame with abbreviation for long spans, +/// matching the JS printCodeFrame() function. +pub fn print_code_frame( + source: &str, + start_line: u32, + start_col: u32, + end_line: u32, + end_col: u32, + message: &str, +) -> String { + let printed = code_frame_columns(source, start_line, start_col, end_line, end_col, message); + + if end_line - start_line < CODEFRAME_MAX_LINES { + return printed; + } + + // Abbreviate: truncate middle + let lines: Vec<&str> = printed.split('\n').collect(); + let head_count = CODEFRAME_LINES_ABOVE as usize + CODEFRAME_ABBREVIATED_SOURCE_LINES; + let tail_count = CODEFRAME_LINES_BELOW as usize + CODEFRAME_ABBREVIATED_SOURCE_LINES; + + if lines.len() <= head_count + tail_count { + return printed; + } + + // Find the pipe index from the first line + let pipe_index = lines[0].find('|').unwrap_or(0); + let tail_start = lines.len() - tail_count; + + let mut parts: Vec = Vec::new(); + for line in &lines[..head_count] { + parts.push(line.to_string()); + } + parts.push(format!("{}\u{2026}", " ".repeat(pipe_index))); + for line in &lines[tail_start..] 
{ + parts.push(line.to_string()); + } + parts.join("\n") +} + +use crate::format_category_heading; + +/// Format a CompilerError into a message string matching the TS compiler's +/// CompilerError.printErrorMessage() / formatCompilerError() format. +/// +/// The source parameter is the full source code of the file being compiled. +/// The filename parameter is the source filename (e.g., "foo.ts") used in +/// location displays. +pub fn format_compiler_error(err: &CompilerError, source: &str, filename: Option<&str>) -> String { + let detail_messages: Vec = err + .details + .iter() + .map(|d| format_error_detail(d, source, filename)) + .collect(); + + let count = err.details.len(); + let plural = if count == 1 { "" } else { "s" }; + let header = format!("Found {count} error{plural}:\n\n"); + + let trimmed: Vec = detail_messages + .iter() + .map(|m| m.trim().to_string()) + .collect(); + format!("{}{}", header, trimmed.join("\n\n")) +} + +/// Format a single error detail (either Diagnostic or ErrorDetail). +fn format_error_detail( + detail: &CompilerErrorOrDiagnostic, + source: &str, + filename: Option<&str>, +) -> String { + match detail { + CompilerErrorOrDiagnostic::Diagnostic(d) => { + let heading = format_category_heading(d.category); + let mut buffer = vec![format!("{}: {}", heading, d.reason)]; + + if let Some(ref description) = d.description { + buffer.push(format!("\n\n{description}.")); + } + for item in &d.details { + match item { + CompilerDiagnosticDetail::Error { loc, message, .. 
} => { + if let Some(loc) = loc { + let frame = print_code_frame( + source, + loc.start.line, + loc.start.column, + loc.end.line, + loc.end.column, + message.as_deref().unwrap_or(""), + ); + buffer.push("\n\n".to_string()); + if let Some(fname) = filename { + buffer.push(format!( + "{}:{}:{}\n", + fname, loc.start.line, loc.start.column + )); + } + buffer.push(frame); + } + } + CompilerDiagnosticDetail::Hint { message } => { + buffer.push("\n\n".to_string()); + buffer.push(message.clone()); + } + } + } + + buffer.join("") + } + CompilerErrorOrDiagnostic::ErrorDetail(d) => { + let heading = format_category_heading(d.category); + let mut buffer = vec![format!("{}: {}", heading, d.reason)]; + + if let Some(ref description) = d.description { + buffer.push(format!("\n\n{description}.")); + if let Some(ref loc) = d.loc { + let frame = print_code_frame( + source, + loc.start.line, + loc.start.column, + loc.end.line, + loc.end.column, + &d.reason, + ); + buffer.push("\n\n".to_string()); + if let Some(fname) = filename { + buffer.push(format!( + "{}:{}:{}\n", + fname, loc.start.line, loc.start.column + )); + } + buffer.push(frame); + buffer.push("\n\n".to_string()); + } + } else if let Some(ref loc) = d.loc { + let frame = print_code_frame( + source, + loc.start.line, + loc.start.column, + loc.end.line, + loc.end.column, + &d.reason, + ); + buffer.push("\n\n".to_string()); + if let Some(fname) = filename { + buffer.push(format!( + "{}:{}:{}\n", + fname, loc.start.line, loc.start.column + )); + } + buffer.push(frame); + buffer.push("\n\n".to_string()); + } + + buffer.join("") + } + } +} diff --git a/crates/react_compiler_diagnostics/src/lib.rs b/crates/react_compiler_diagnostics/src/lib.rs new file mode 100644 index 000000000000..c45935003dd1 --- /dev/null +++ b/crates/react_compiler_diagnostics/src/lib.rs @@ -0,0 +1,462 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. 
+#![allow(clippy::all)] + +pub mod code_frame; + +use serde::{Deserialize, Serialize}; + +/// Error categories matching the TS ErrorCategory enum +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum ErrorCategory { + Hooks, + CapitalizedCalls, + StaticComponents, + UseMemo, + VoidUseMemo, + PreserveManualMemo, + MemoDependencies, + IncompatibleLibrary, + Immutability, + Globals, + Refs, + EffectDependencies, + EffectExhaustiveDependencies, + EffectSetState, + EffectDerivationsOfState, + ErrorBoundaries, + Purity, + RenderSetState, + Invariant, + Todo, + Syntax, + UnsupportedSyntax, + Config, + Gating, + Suppression, + FBT, +} + +/// Error severity levels +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum ErrorSeverity { + Error, + Warning, + Hint, + Off, +} + +impl ErrorCategory { + pub fn severity(&self) -> ErrorSeverity { + match self { + // These map to "Compilation Skipped" (Warning severity) + ErrorCategory::EffectDependencies + | ErrorCategory::IncompatibleLibrary + | ErrorCategory::PreserveManualMemo + | ErrorCategory::UnsupportedSyntax => ErrorSeverity::Warning, + + // Todo is Hint + ErrorCategory::Todo => ErrorSeverity::Hint, + + // Invariant and all others are Error severity + _ => ErrorSeverity::Error, + } + } + + /// The severity to use in logged output, matching the TS compiler's + /// `getRuleForCategory()`. This may differ from the internal `severity()` + /// used for panicThreshold logic. In particular, `PreserveManualMemo` is + /// `Warning` internally (so it doesn't trigger panicThreshold throws) but + /// `Error` in logged output (matching TS behavior). 
+ pub fn logged_severity(&self) -> ErrorSeverity { + match self { + ErrorCategory::PreserveManualMemo => ErrorSeverity::Error, + _ => self.severity(), + } + } +} + +/// Suggestion operations for auto-fixes +#[derive(Debug, Clone, Serialize)] +pub enum CompilerSuggestionOperation { + InsertBefore, + InsertAfter, + Remove, + Replace, +} + +/// A compiler suggestion for fixing an error +#[derive(Debug, Clone, Serialize)] +pub struct CompilerSuggestion { + pub op: CompilerSuggestionOperation, + pub range: (usize, usize), + pub description: String, + pub text: Option, // None for Remove operations +} + +/// Source location (matches Babel's SourceLocation format) +/// This is the HIR source location, separate from AST's BaseNode location. +/// GeneratedSource is represented as None. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct SourceLocation { + pub start: Position, + pub end: Position, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct Position { + pub line: u32, + pub column: u32, + /// Byte offset in the source file. Preserved for logger event + /// serialization. + #[serde(default, skip_serializing)] + pub index: Option, +} + +/// Sentinel value for generated/synthetic source locations +pub const GENERATED_SOURCE: Option = None; + +/// Detail for a diagnostic +#[derive(Debug, Clone, Serialize)] +pub enum CompilerDiagnosticDetail { + Error { + loc: Option, + message: Option, + /// The identifier name from the AST source location, if this error + /// points to an identifier node. Preserved for logger event + /// serialization to match Babel's SourceLocation.identifierName + /// field. 
+ #[serde(skip)] + identifier_name: Option, + }, + Hint { + message: String, + }, +} + +/// A single compiler diagnostic (new-style) +#[derive(Debug, Clone)] +pub struct CompilerDiagnostic { + pub category: ErrorCategory, + pub reason: String, + pub description: Option, + pub details: Vec, + pub suggestions: Option>, +} + +impl CompilerDiagnostic { + pub fn new( + category: ErrorCategory, + reason: impl Into, + description: Option, + ) -> Self { + Self { + category, + reason: reason.into(), + description, + details: Vec::new(), + suggestions: None, + } + } + + pub fn severity(&self) -> ErrorSeverity { + self.category.severity() + } + + pub fn logged_severity(&self) -> ErrorSeverity { + self.category.logged_severity() + } + + pub fn with_detail(mut self, detail: CompilerDiagnosticDetail) -> Self { + self.details.push(detail); + self + } + + /// Create a Todo diagnostic (matches TS `CompilerError.throwTodo()`). + pub fn todo(reason: impl Into, loc: Option) -> Self { + let reason = reason.into(); + let mut diag = Self::new(ErrorCategory::Todo, reason.clone(), None); + diag.details.push(CompilerDiagnosticDetail::Error { + loc, + message: Some(reason), + identifier_name: None, + }); + diag + } + + /// Create a diagnostic from a CompilerErrorDetail. + pub fn from_detail(detail: CompilerErrorDetail) -> Self { + Self::new( + detail.category, + detail.reason.clone(), + detail.description.clone(), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: detail.loc, + message: Some(detail.reason), + identifier_name: None, + }) + } + + pub fn primary_location(&self) -> Option<&SourceLocation> { + self.details.iter().find_map(|d| match d { + CompilerDiagnosticDetail::Error { loc, .. } => loc.as_ref(), /* identifier_name */ + // covered by .. 
+ _ => None, + }) + } +} + +/// Legacy-style error detail (matches CompilerErrorDetail in TS) +#[derive(Debug, Clone, Serialize)] +pub struct CompilerErrorDetail { + pub category: ErrorCategory, + pub reason: String, + pub description: Option, + pub loc: Option, + pub suggestions: Option>, +} + +impl CompilerErrorDetail { + pub fn new(category: ErrorCategory, reason: impl Into) -> Self { + Self { + category, + reason: reason.into(), + description: None, + loc: None, + suggestions: None, + } + } + + pub fn with_description(mut self, description: impl Into) -> Self { + self.description = Some(description.into()); + self + } + + pub fn with_loc(mut self, loc: Option) -> Self { + self.loc = loc; + self + } + + pub fn severity(&self) -> ErrorSeverity { + self.category.severity() + } + + pub fn logged_severity(&self) -> ErrorSeverity { + self.category.logged_severity() + } +} + +/// Aggregate compiler error - can contain multiple diagnostics. +/// This is the main error type thrown/returned by the compiler. +#[derive(Debug, Clone)] +pub struct CompilerError { + pub details: Vec, + /// When false, this error was accumulated on the Environment via + /// `record_error()` / `record_diagnostic()` and returned at the end + /// of the pipeline. In TS, `CompileUnexpectedThrow` is only emitted + /// for errors that are **thrown** (not accumulated). Defaults to `true` + /// because errors created directly (e.g., via `?` from a pass) are + /// analogous to thrown errors in the TS code. 
+ pub is_thrown: bool, +} + +/// Either a new-style diagnostic or legacy error detail +#[derive(Debug, Clone)] +pub enum CompilerErrorOrDiagnostic { + Diagnostic(CompilerDiagnostic), + ErrorDetail(CompilerErrorDetail), +} + +impl CompilerErrorOrDiagnostic { + pub fn severity(&self) -> ErrorSeverity { + match self { + Self::Diagnostic(d) => d.severity(), + Self::ErrorDetail(d) => d.severity(), + } + } + + pub fn logged_severity(&self) -> ErrorSeverity { + match self { + Self::Diagnostic(d) => d.logged_severity(), + Self::ErrorDetail(d) => d.logged_severity(), + } + } +} + +impl CompilerError { + pub fn new() -> Self { + Self { + details: Vec::new(), + is_thrown: true, + } + } + + pub fn push_diagnostic(&mut self, diagnostic: CompilerDiagnostic) { + if diagnostic.severity() != ErrorSeverity::Off { + self.details + .push(CompilerErrorOrDiagnostic::Diagnostic(diagnostic)); + } + } + + pub fn push_error_detail(&mut self, detail: CompilerErrorDetail) { + if detail.severity() != ErrorSeverity::Off { + self.details + .push(CompilerErrorOrDiagnostic::ErrorDetail(detail)); + } + } + + pub fn has_errors(&self) -> bool { + self.details + .iter() + .any(|d| d.severity() == ErrorSeverity::Error) + } + + pub fn has_any_errors(&self) -> bool { + !self.details.is_empty() + } + + /// Check if any error detail has Invariant category. + pub fn has_invariant_errors(&self) -> bool { + self.details.iter().any(|d| { + let cat = match d { + CompilerErrorOrDiagnostic::Diagnostic(d) => d.category, + CompilerErrorOrDiagnostic::ErrorDetail(d) => d.category, + }; + cat == ErrorCategory::Invariant + }) + } + + pub fn merge(&mut self, other: CompilerError) { + self.details.extend(other.details); + } + + /// Check if all error details are non-invariant. + /// In TS, this is used to determine if an error thrown during compilation + /// should be logged as CompileUnexpectedThrow. 
+ pub fn is_all_non_invariant(&self) -> bool { + self.details.iter().all(|d| { + let cat = match d { + CompilerErrorOrDiagnostic::Diagnostic(d) => d.category, + CompilerErrorOrDiagnostic::ErrorDetail(d) => d.category, + }; + cat != ErrorCategory::Invariant + }) + } + + /// Format as a string matching the TS `CompilerError.toString()` output. + /// Used for the `data` field of `CompileUnexpectedThrow` events. + /// + /// Format per detail: `"Category: reason. Description. (line:column)"` + /// Multiple details are joined with `"\n\n"`. + pub fn to_string_for_event(&self) -> String { + self.details + .iter() + .map(|d| { + let (category, reason, description, loc) = match d { + CompilerErrorOrDiagnostic::Diagnostic(d) => { + let loc = d.primary_location().cloned(); + (d.category, &d.reason, &d.description, loc) + } + CompilerErrorOrDiagnostic::ErrorDetail(d) => { + (d.category, &d.reason, &d.description, d.loc) + } + }; + let mut buf = format!("{}: {}", format_category_heading(category), reason); + if let Some(desc) = description { + buf.push_str(&format!(". {desc}.")); + } + if let Some(loc) = loc { + buf.push_str(&format!(" ({}:{})", loc.start.line, loc.start.column)); + } + buf + }) + .collect::>() + .join("\n\n") + } +} + +impl Default for CompilerError { + fn default() -> Self { + Self::new() + } +} + +/// Allow `?` to convert a `CompilerError` into a `CompilerDiagnostic` +/// when the enclosing function returns `Result`. +/// +/// This typically happens when `record_error()` returns `Err(CompilerError)` +/// for an Invariant error, and the calling function already returns +/// `Result`. The conversion extracts the first +/// error detail from the aggregate error. 
+impl From for CompilerDiagnostic { + fn from(err: CompilerError) -> Self { + if let Some(first) = err.details.into_iter().next() { + match first { + CompilerErrorOrDiagnostic::Diagnostic(d) => d, + CompilerErrorOrDiagnostic::ErrorDetail(d) => CompilerDiagnostic::from_detail(d), + } + } else { + CompilerDiagnostic::new(ErrorCategory::Invariant, "Unknown compiler error", None) + } + } +} + +impl From for CompilerError { + fn from(diagnostic: CompilerDiagnostic) -> Self { + let mut error = CompilerError::new(); + // Todo diagnostics should produce ErrorDetail (flat loc format), matching + // the TS behavior where CompilerError.throwTodo() creates a CompilerErrorDetail + // with loc directly on it, not a CompilerDiagnostic with sub-details. + if diagnostic.category == ErrorCategory::Todo { + let loc = diagnostic.primary_location().cloned(); + error.push_error_detail(CompilerErrorDetail { + category: diagnostic.category, + reason: diagnostic.reason, + description: diagnostic.description, + loc, + suggestions: diagnostic.suggestions, + }); + } else { + error.push_diagnostic(diagnostic); + } + error + } +} + +impl std::fmt::Display for CompilerError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for detail in &self.details { + match detail { + CompilerErrorOrDiagnostic::Diagnostic(d) => { + write!(f, "{}: {}", format_category_heading(d.category), d.reason)?; + if let Some(desc) = &d.description { + write!(f, ". {desc}.")?; + } + } + CompilerErrorOrDiagnostic::ErrorDetail(d) => { + write!(f, "{}: {}", format_category_heading(d.category), d.reason)?; + if let Some(desc) = &d.description { + write!(f, ". 
{desc}.")?; + } + } + } + writeln!(f)?; + } + Ok(()) + } +} + +impl std::error::Error for CompilerError {} + +pub fn format_category_heading(category: ErrorCategory) -> &'static str { + match category { + ErrorCategory::EffectDependencies + | ErrorCategory::IncompatibleLibrary + | ErrorCategory::PreserveManualMemo + | ErrorCategory::UnsupportedSyntax => "Compilation Skipped", + ErrorCategory::Invariant => "Invariant", + ErrorCategory::Todo => "Todo", + _ => "Error", + } +} diff --git a/crates/react_compiler_hir/Cargo.toml b/crates/react_compiler_hir/Cargo.toml new file mode 100644 index 000000000000..400229aa56b0 --- /dev/null +++ b/crates/react_compiler_hir/Cargo.toml @@ -0,0 +1,13 @@ +[package] +description = "Vendored React Compiler HIR from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_hir" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +react_compiler_diagnostics = { path = "../react_compiler_diagnostics" } +indexmap = { workspace = true, features = ["serde"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } diff --git a/crates/react_compiler_hir/src/default_module_type_provider.rs b/crates/react_compiler_hir/src/default_module_type_provider.rs new file mode 100644 index 000000000000..27226717d1ce --- /dev/null +++ b/crates/react_compiler_hir/src/default_module_type_provider.rs @@ -0,0 +1,109 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Default module type provider, ported from DefaultModuleTypeProvider.ts. +//! +//! Provides hardcoded type overrides for known-incompatible third-party +//! libraries. 
+ +use indexmap::IndexMap; + +use crate::{ + type_config::{ + BuiltInTypeRef, FunctionTypeConfig, HookTypeConfig, ObjectTypeConfig, TypeConfig, + TypeReferenceConfig, ValueKind, + }, + Effect, +}; + +/// Returns type configuration for known third-party modules that are +/// incompatible with memoization. Ported from TS `defaultModuleTypeProvider`. +pub fn default_module_type_provider(module_name: &str) -> Option { + match module_name { + "react-hook-form" => Some(TypeConfig::Object(ObjectTypeConfig { + properties: Some(IndexMap::from([( + "useForm".to_string(), + TypeConfig::Hook(HookTypeConfig { + return_type: Box::new(TypeConfig::Object(ObjectTypeConfig { + properties: Some(IndexMap::from([( + "watch".to_string(), + TypeConfig::Function(FunctionTypeConfig { + positional_params: Vec::new(), + rest_param: Some(Effect::Read), + callee_effect: Effect::Read, + return_type: Box::new(TypeConfig::TypeReference( + TypeReferenceConfig { + name: BuiltInTypeRef::Any, + }, + )), + return_value_kind: ValueKind::Mutable, + no_alias: None, + mutable_only_if_operands_are_mutable: None, + impure: None, + canonical_name: None, + aliasing: None, + known_incompatible: Some( + "React Hook Form's `useForm()` API returns a `watch()` \ + function which cannot be memoized safely." 
+ .to_string(), + ), + }), + )])), + })), + positional_params: None, + rest_param: None, + return_value_kind: None, + no_alias: None, + aliasing: None, + known_incompatible: None, + }), + )])), + })), + + "@tanstack/react-table" => Some(TypeConfig::Object(ObjectTypeConfig { + properties: Some(IndexMap::from([( + "useReactTable".to_string(), + TypeConfig::Hook(HookTypeConfig { + positional_params: Some(Vec::new()), + rest_param: Some(Effect::Read), + return_type: Box::new(TypeConfig::TypeReference(TypeReferenceConfig { + name: BuiltInTypeRef::Any, + })), + return_value_kind: None, + no_alias: None, + aliasing: None, + known_incompatible: Some( + "TanStack Table's `useReactTable()` API returns functions that cannot be \ + memoized safely" + .to_string(), + ), + }), + )])), + })), + + "@tanstack/react-virtual" => Some(TypeConfig::Object(ObjectTypeConfig { + properties: Some(IndexMap::from([( + "useVirtualizer".to_string(), + TypeConfig::Hook(HookTypeConfig { + positional_params: Some(Vec::new()), + rest_param: Some(Effect::Read), + return_type: Box::new(TypeConfig::TypeReference(TypeReferenceConfig { + name: BuiltInTypeRef::Any, + })), + return_value_kind: None, + no_alias: None, + aliasing: None, + known_incompatible: Some( + "TanStack Virtual's `useVirtualizer()` API returns functions that cannot \ + be memoized safely" + .to_string(), + ), + }), + )])), + })), + + _ => None, + } +} diff --git a/crates/react_compiler_hir/src/dominator.rs b/crates/react_compiler_hir/src/dominator.rs new file mode 100644 index 000000000000..682cda8aefdb --- /dev/null +++ b/crates/react_compiler_hir/src/dominator.rs @@ -0,0 +1,362 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Dominator and post-dominator tree computation. +//! +//! Port of Dominator.ts and ComputeUnconditionalBlocks.ts. +//! 
Uses the Cooper/Harvey/Kennedy algorithm from +//! https://www.cs.rice.edu/~keith/Embed/dom.pdf + +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory}; + +use crate::{visitors::each_terminal_successor, BlockId, HirFunction, Terminal}; + +// ============================================================================= +// Public types +// ============================================================================= + +/// Stores the immediate post-dominator for each block. +pub struct PostDominator { + /// The exit node (synthetic node representing function exit). + pub exit: BlockId, + nodes: HashMap, +} + +impl PostDominator { + /// Returns the immediate post-dominator of the given block, or None if + /// the block post-dominates itself (i.e., it is the exit node). + pub fn get(&self, id: BlockId) -> Option { + let dominator = self + .nodes + .get(&id) + .expect("Unknown node in post-dominator tree"); + if *dominator == id { + None + } else { + Some(*dominator) + } + } +} + +// ============================================================================= +// Graph representation +// ============================================================================= + +struct Node { + id: BlockId, + index: usize, + preds: HashSet, + succs: HashSet, +} + +struct Graph { + entry: BlockId, + /// Nodes stored in iteration order (RPO for reverse graph). + nodes: Vec, + /// Map from BlockId to index in the nodes vec. + node_index: HashMap, +} + +impl Graph { + fn get_node(&self, id: BlockId) -> &Node { + let idx = self.node_index[&id]; + &self.nodes[idx] + } +} + +// ============================================================================= +// Post-dominator tree computation +// ============================================================================= + +/// Compute the post-dominator tree for a function. 
+/// +/// If `include_throws_as_exit_node` is true, throw terminals are treated as +/// exit nodes (like return). Otherwise, only return terminals feed into exit. +pub fn compute_post_dominator_tree( + func: &HirFunction, + next_block_id_counter: u32, + include_throws_as_exit_node: bool, +) -> Result { + let graph = build_reverse_graph(func, next_block_id_counter, include_throws_as_exit_node); + let mut nodes = compute_immediate_dominators(&graph)?; + + // When include_throws_as_exit_node is false, nodes that flow into a throw + // terminal and don't reach the exit won't be in the node map. Add them + // with themselves as dominator. + if !include_throws_as_exit_node { + for (id, _) in &func.body.blocks { + nodes.entry(*id).or_insert(*id); + } + } + + Ok(PostDominator { + exit: graph.entry, + nodes, + }) +} + +/// Build the reverse graph from the HIR function. +/// +/// Reverses all edges and adds a synthetic exit node that receives edges from +/// return (and optionally throw) terminals. The result is put into RPO order. +fn build_reverse_graph( + func: &HirFunction, + next_block_id_counter: u32, + include_throws_as_exit_node: bool, +) -> Graph { + let exit_id = BlockId(next_block_id_counter); + + // Build initial nodes with reversed edges + let mut raw_nodes: HashMap = HashMap::new(); + + // Create exit node + raw_nodes.insert( + exit_id, + Node { + id: exit_id, + index: 0, + preds: HashSet::new(), + succs: HashSet::new(), + }, + ); + + for (id, block) in &func.body.blocks { + let successors = each_terminal_successor(&block.terminal); + let mut preds_set: HashSet = successors.into_iter().collect(); + let succs_set: HashSet = block.preds.iter().copied().collect(); + + let is_return = matches!(&block.terminal, Terminal::Return { .. }); + let is_throw = matches!(&block.terminal, Terminal::Throw { .. 
}); + + if is_return || (is_throw && include_throws_as_exit_node) { + preds_set.insert(exit_id); + raw_nodes.get_mut(&exit_id).unwrap().succs.insert(*id); + } + + raw_nodes.insert( + *id, + Node { + id: *id, + index: 0, + preds: preds_set, + succs: succs_set, + }, + ); + } + + // DFS from exit to compute RPO + let mut visited = HashSet::new(); + let mut postorder = Vec::new(); + dfs_postorder(exit_id, &raw_nodes, &mut visited, &mut postorder); + + // Reverse postorder + postorder.reverse(); + + let mut nodes = Vec::with_capacity(postorder.len()); + let mut node_index = HashMap::new(); + for (idx, id) in postorder.into_iter().enumerate() { + let mut node = raw_nodes.remove(&id).unwrap(); + node.index = idx; + node_index.insert(id, idx); + nodes.push(node); + } + + Graph { + entry: exit_id, + nodes, + node_index, + } +} + +fn dfs_postorder( + id: BlockId, + nodes: &HashMap, + visited: &mut HashSet, + postorder: &mut Vec, +) { + if !visited.insert(id) { + return; + } + if let Some(node) = nodes.get(&id) { + for &succ in &node.succs { + dfs_postorder(succ, nodes, visited, postorder); + } + } + postorder.push(id); +} + +// ============================================================================= +// Dominator fixpoint (Cooper/Harvey/Kennedy) +// ============================================================================= + +fn compute_immediate_dominators( + graph: &Graph, +) -> Result, CompilerDiagnostic> { + let mut doms: HashMap = HashMap::new(); + doms.insert(graph.entry, graph.entry); + + let mut changed = true; + while changed { + changed = false; + for node in &graph.nodes { + if node.id == graph.entry { + continue; + } + + // Find first processed predecessor + let mut new_idom: Option = None; + for &pred in &node.preds { + if doms.contains_key(&pred) { + new_idom = Some(pred); + break; + } + } + let mut new_idom = match new_idom { + Some(idom) => idom, + None => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "At least one 
predecessor must have been visited for block {:?}", + node.id + ), + None, + )); + } + }; + + // Intersect with other processed predecessors + for &pred in &node.preds { + if pred == new_idom { + continue; + } + if doms.contains_key(&pred) { + new_idom = intersect(pred, new_idom, graph, &doms); + } + } + + if doms.get(&node.id) != Some(&new_idom) { + doms.insert(node.id, new_idom); + changed = true; + } + } + } + Ok(doms) +} + +fn intersect(a: BlockId, b: BlockId, graph: &Graph, doms: &HashMap) -> BlockId { + let mut block1 = graph.get_node(a); + let mut block2 = graph.get_node(b); + while block1.id != block2.id { + while block1.index > block2.index { + let dom = doms[&block1.id]; + block1 = graph.get_node(dom); + } + while block2.index > block1.index { + let dom = doms[&block2.id]; + block2 = graph.get_node(dom); + } + } + block1.id +} + +// ============================================================================= +// Post-dominator frontier +// ============================================================================= + +/// Computes the post-dominator frontier of `target_id`. These are immediate +/// predecessors of nodes that post-dominate `target_id` from which execution +/// may not reach `target_id`. Intuitively, these are the earliest blocks from +/// which execution branches such that it may or may not reach the target block. 
+pub fn post_dominator_frontier( + func: &HirFunction, + post_dominators: &PostDominator, + target_id: BlockId, +) -> HashSet { + let target_post_dominators = post_dominators_of(func, post_dominators, target_id); + let mut visited = HashSet::new(); + let mut frontier = HashSet::new(); + + let mut to_visit: Vec = target_post_dominators.iter().copied().collect(); + to_visit.push(target_id); + + for block_id in to_visit { + if !visited.insert(block_id) { + continue; + } + if let Some(block) = func.body.blocks.get(&block_id) { + for &pred in &block.preds { + if !target_post_dominators.contains(&pred) { + frontier.insert(pred); + } + } + } + } + frontier +} + +/// Walks up the post-dominator tree to collect all blocks that post-dominate +/// `target_id`. +pub fn post_dominators_of( + func: &HirFunction, + post_dominators: &PostDominator, + target_id: BlockId, +) -> HashSet { + let mut result = HashSet::new(); + let mut visited = HashSet::new(); + let mut queue = vec![target_id]; + + while let Some(current_id) = queue.pop() { + if !visited.insert(current_id) { + continue; + } + if let Some(block) = func.body.blocks.get(¤t_id) { + for &pred in &block.preds { + let pred_post_dom = post_dominators.get(pred).unwrap_or(pred); + if pred_post_dom == target_id || result.contains(&pred_post_dom) { + result.insert(pred); + } + queue.push(pred); + } + } + } + result +} + +// ============================================================================= +// Unconditional blocks +// ============================================================================= + +/// Compute the set of blocks that are unconditionally executed from the entry. +/// +/// Port of ComputeUnconditionalBlocks.ts. Walks the immediate post-dominator +/// chain starting from the function entry. A block is unconditional if it lies +/// on this chain (meaning every path through the function must pass through +/// it). 
+pub fn compute_unconditional_blocks( + func: &HirFunction, + next_block_id_counter: u32, +) -> Result, CompilerDiagnostic> { + let mut unconditional = HashSet::new(); + let dominators = compute_post_dominator_tree(func, next_block_id_counter, false)?; + let exit = dominators.exit; + let mut current: Option = Some(func.body.entry); + + while let Some(block_id) = current { + if block_id == exit { + break; + } + assert!( + !unconditional.contains(&block_id), + "Internal error: non-terminating loop in ComputeUnconditionalBlocks" + ); + unconditional.insert(block_id); + current = dominators.get(block_id); + } + + Ok(unconditional) +} diff --git a/crates/react_compiler_hir/src/environment.rs b/crates/react_compiler_hir/src/environment.rs new file mode 100644 index 000000000000..a654ef033992 --- /dev/null +++ b/crates/react_compiler_hir/src/environment.rs @@ -0,0 +1,1095 @@ +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerError, CompilerErrorDetail, ErrorCategory, +}; + +use crate::{ + default_module_type_provider::default_module_type_provider, + environment_config::EnvironmentConfig, + globals::{self, Global, GlobalRegistry}, + object_shape::{ + add_hook, default_mutating_hook, default_nonmutating_hook, FunctionSignature, HookKind, + HookSignatureBuilder, ShapeRegistry, BUILT_IN_MIXED_READONLY_ID, + }, + *, +}; + +/// A variable rename from lowering: the binding at `declaration_start` position +/// was renamed from `original` to `renamed`. +#[derive(Debug, Clone)] +pub struct BindingRename { + pub original: String, + pub renamed: String, + pub declaration_start: u32, +} + +/// Output mode for the compiler, mirrored from the entrypoint's +/// CompilerOutputMode. Stored on Environment so pipeline passes can access it. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OutputMode { + Ssr, + Client, + Lint, +} + +pub struct Environment { + // Counters + pub next_block_id_counter: u32, + pub next_scope_id_counter: u32, + + // Arenas (use direct field access for sliced borrows) + pub identifiers: Vec, + pub types: Vec, + pub scopes: Vec, + pub functions: Vec, + + // Error accumulation + pub errors: CompilerError, + + // Function type classification (Component, Hook, Other) + pub fn_type: ReactFunctionType, + + // Output mode (Client, Ssr, Lint) + pub output_mode: OutputMode, + + // Source file code (for fast refresh hash computation) + pub code: Option, + + // Source file name (for instrumentation) + pub filename: Option, + + // Pre-resolved import local names for instrumentation/hook guards. + // Set by the program-level code before compilation. + pub instrument_fn_name: Option, + pub instrument_gating_name: Option, + pub hook_guard_name: Option, + + // Renames: tracks variable renames from lowering (original_name → new_name) + // keyed by binding declaration position, for applying back to the Babel AST. + pub renames: Vec, + + // Hoisted identifiers: tracks which bindings have already been hoisted + // via DeclareContext to avoid duplicate hoisting. + // Uses u32 to avoid depending on react_compiler_ast types. + hoisted_identifiers: HashSet, + + // Config flags for validation passes (kept for backwards compat with existing pipeline code) + pub validate_preserve_existing_memoization_guarantees: bool, + pub validate_no_set_state_in_render: bool, + pub enable_preserve_existing_memoization_guarantees: bool, + + // Type system registries + globals: GlobalRegistry, + pub shapes: ShapeRegistry, + module_types: HashMap>, + module_type_errors: HashMap>, + + // Environment configuration (feature flags, custom hooks, etc.) 
+ pub config: EnvironmentConfig, + + // Cached default hook types (lazily initialized) + default_nonmutating_hook: Option, + default_mutating_hook: Option, + + // Outlined functions: functions extracted from the component during outlining passes + outlined_functions: Vec, + + // Counter for generating globally unique identifier names + uid_counter: u32, +} + +/// An outlined function entry, stored on Environment during compilation. +/// Corresponds to TS `{ fn: HIRFunction, type: ReactFunctionType | null }`. +#[derive(Debug, Clone)] +pub struct OutlinedFunctionEntry { + pub func: HirFunction, + pub fn_type: Option, +} + +impl Environment { + pub fn new() -> Self { + Self::with_config(EnvironmentConfig::default()) + } + + /// Create a new Environment with the given configuration. + /// + /// Initializes the shape and global registries, registers custom hooks, + /// and sets up the module type cache. + pub fn with_config(config: EnvironmentConfig) -> Self { + let mut shapes = ShapeRegistry::with_base(globals::base_shapes()); + let mut global_registry = GlobalRegistry::with_base(globals::base_globals()); + + // Register custom hooks from config + for (hook_name, hook) in &config.custom_hooks { + // Don't overwrite existing globals (matches TS invariant) + if global_registry.contains_key(hook_name) { + continue; + } + let return_type = if hook.transitive_mixed_data { + Type::Object { + shape_id: Some(BUILT_IN_MIXED_READONLY_ID.to_string()), + } + } else { + Type::Poly + }; + let hook_type = add_hook( + &mut shapes, + HookSignatureBuilder { + rest_param: Some(hook.effect_kind), + return_type, + return_value_kind: hook.value_kind, + hook_kind: HookKind::Custom, + no_alias: hook.no_alias, + ..Default::default() + }, + None, + ); + global_registry.insert(hook_name.clone(), hook_type); + } + + // Register reanimated module type when enabled + let mut module_types: HashMap> = HashMap::new(); + if config.enable_custom_type_definition_for_reanimated { + let 
reanimated_module_type = globals::get_reanimated_module_type(&mut shapes); + module_types.insert( + "react-native-reanimated".to_string(), + Some(reanimated_module_type), + ); + } + + Self { + next_block_id_counter: 0, + next_scope_id_counter: 0, + identifiers: Vec::new(), + types: Vec::new(), + scopes: Vec::new(), + functions: Vec::new(), + errors: CompilerError::new(), + fn_type: ReactFunctionType::Other, + output_mode: OutputMode::Client, + code: None, + filename: None, + instrument_fn_name: None, + instrument_gating_name: None, + hook_guard_name: None, + renames: Vec::new(), + hoisted_identifiers: HashSet::new(), + validate_preserve_existing_memoization_guarantees: config + .validate_preserve_existing_memoization_guarantees, + validate_no_set_state_in_render: config.validate_no_set_state_in_render, + enable_preserve_existing_memoization_guarantees: config + .enable_preserve_existing_memoization_guarantees, + globals: global_registry, + shapes, + module_types, + module_type_errors: HashMap::new(), + default_nonmutating_hook: None, + default_mutating_hook: None, + outlined_functions: Vec::new(), + uid_counter: 0, + config, + } + } + + /// Create a child Environment for compiling an outlined function. + /// + /// The child shares the same config, globals, and shapes, and receives + /// copies of all arenas (identifiers, types, scopes, functions) so that + /// references from the outlined HIR remain valid. Block/scope counters + /// start past the cloned data to avoid ID conflicts. + pub fn for_outlined_fn(&self, fn_type: ReactFunctionType) -> Self { + Self { + // Start block counter past any existing blocks in the outlined function. + // The outlined function has BlockId(0), parent may have more. Use parent's + // counter which is guaranteed to be > any block ID in the outlined function. 
+ next_block_id_counter: self.next_block_id_counter, + // Scope counter must be consistent with scopes vec length + next_scope_id_counter: self.scopes.len() as u32, + identifiers: self.identifiers.clone(), + types: self.types.clone(), + scopes: self.scopes.clone(), + functions: self.functions.clone(), + errors: CompilerError::new(), + fn_type, + output_mode: self.output_mode, + code: self.code.clone(), + filename: self.filename.clone(), + instrument_fn_name: self.instrument_fn_name.clone(), + instrument_gating_name: self.instrument_gating_name.clone(), + hook_guard_name: self.hook_guard_name.clone(), + renames: Vec::new(), + hoisted_identifiers: HashSet::new(), + validate_preserve_existing_memoization_guarantees: self + .validate_preserve_existing_memoization_guarantees, + validate_no_set_state_in_render: self.validate_no_set_state_in_render, + enable_preserve_existing_memoization_guarantees: self + .enable_preserve_existing_memoization_guarantees, + globals: self.globals.clone(), + shapes: self.shapes.clone(), + module_types: self.module_types.clone(), + module_type_errors: self.module_type_errors.clone(), + config: self.config.clone(), + default_nonmutating_hook: self.default_nonmutating_hook.clone(), + default_mutating_hook: self.default_mutating_hook.clone(), + outlined_functions: Vec::new(), + uid_counter: self.uid_counter, + } + } + + pub fn next_block_id(&mut self) -> BlockId { + let id = BlockId(self.next_block_id_counter); + self.next_block_id_counter += 1; + id + } + + /// Allocate a new Identifier in the arena with default values, + /// returns its IdentifierId. 
+ pub fn next_identifier_id(&mut self) -> IdentifierId { + let id = IdentifierId(self.identifiers.len() as u32); + let type_id = self.make_type(); + self.identifiers.push(Identifier { + id, + declaration_id: DeclarationId(id.0), + name: None, + mutable_range: MutableRange { + start: EvaluationOrder(0), + end: EvaluationOrder(0), + }, + scope: None, + type_: type_id, + loc: None, + }); + id + } + + /// Allocate a new ReactiveScope in the arena, returns its ScopeId. + pub fn next_scope_id(&mut self) -> ScopeId { + let id = ScopeId(self.next_scope_id_counter); + self.next_scope_id_counter += 1; + self.scopes.push(ReactiveScope { + id, + range: MutableRange { + start: EvaluationOrder(0), + end: EvaluationOrder(0), + }, + dependencies: Vec::new(), + declarations: Vec::new(), + reassignments: Vec::new(), + early_return_value: None, + merged: Vec::new(), + loc: None, + }); + id + } + + /// Allocate a new Type in the arena, returns its TypeId. + pub fn next_type_id(&mut self) -> TypeId { + let id = TypeId(self.types.len() as u32); + self.types.push(Type::TypeVar { id }); + id + } + + /// Allocate a new Type (TypeVar) in the arena, returns its TypeId. 
+ pub fn make_type(&mut self) -> TypeId { + self.next_type_id() + } + + pub fn add_function(&mut self, func: HirFunction) -> FunctionId { + let id = FunctionId(self.functions.len() as u32); + self.functions.push(func); + id + } + + pub fn record_error(&mut self, detail: CompilerErrorDetail) -> Result<(), CompilerError> { + if detail.category == ErrorCategory::Invariant { + let detail_clone = detail.clone(); + self.errors.push_error_detail(detail); + let mut err = CompilerError::new(); + err.push_error_detail(detail_clone); + return Err(err); + } + self.errors.push_error_detail(detail); + Ok(()) + } + + pub fn record_diagnostic(&mut self, diagnostic: CompilerDiagnostic) { + self.errors.push_diagnostic(diagnostic); + } + + pub fn has_errors(&self) -> bool { + self.errors.has_any_errors() + } + + pub fn error_count(&self) -> usize { + self.errors.details.len() + } + + /// Check if any recorded errors have Invariant category. + /// In TS, Invariant errors throw immediately from recordError(), + /// which aborts the current operation. + pub fn has_invariant_errors(&self) -> bool { + self.errors.has_invariant_errors() + } + + pub fn errors(&self) -> &CompilerError { + &self.errors + } + + pub fn take_errors(&mut self) -> CompilerError { + let mut errors = std::mem::take(&mut self.errors); + // Mark as not thrown — these are accumulated errors returned at the end + // of the pipeline, not errors thrown by a pass. + errors.is_thrown = false; + errors + } + + /// Take errors added after position `since_count`, leaving earlier errors + /// in place. Used to detect new errors added by a specific pass. + pub fn take_errors_since(&mut self, since_count: usize) -> CompilerError { + let mut taken = CompilerError::new(); + if self.errors.details.len() > since_count { + taken.details = self.errors.details.split_off(since_count); + } + taken + } + + /// Take only the Invariant errors, leaving non-Invariant errors in place. 
+ /// In TS, Invariant errors throw as a separate CompilerError, so only + /// the Invariant error is surfaced. + pub fn take_invariant_errors(&mut self) -> CompilerError { + let mut invariant = CompilerError::new(); + let mut remaining = CompilerError::new(); + let old = std::mem::take(&mut self.errors); + for detail in old.details { + let is_invariant = match &detail { + react_compiler_diagnostics::CompilerErrorOrDiagnostic::Diagnostic(d) => { + d.category == react_compiler_diagnostics::ErrorCategory::Invariant + } + react_compiler_diagnostics::CompilerErrorOrDiagnostic::ErrorDetail(d) => { + d.category == react_compiler_diagnostics::ErrorCategory::Invariant + } + }; + if is_invariant { + invariant.details.push(detail); + } else { + remaining.details.push(detail); + } + } + self.errors = remaining; + invariant + } + + /// Check if any recorded errors have Todo category. + /// In TS, Todo errors throw immediately via CompilerError.throwTodo(). + pub fn has_todo_errors(&self) -> bool { + self.errors.details.iter().any(|d| match d { + react_compiler_diagnostics::CompilerErrorOrDiagnostic::Diagnostic(d) => { + d.category == react_compiler_diagnostics::ErrorCategory::Todo + } + react_compiler_diagnostics::CompilerErrorOrDiagnostic::ErrorDetail(d) => { + d.category == react_compiler_diagnostics::ErrorCategory::Todo + } + }) + } + + /// Take errors that would have been thrown in TS (Invariant and Todo), + /// leaving other accumulated errors in place. 
+ pub fn take_thrown_errors(&mut self) -> CompilerError { + let mut thrown = CompilerError::new(); + let mut remaining = CompilerError::new(); + let old = std::mem::take(&mut self.errors); + for detail in old.details { + let is_thrown = match &detail { + react_compiler_diagnostics::CompilerErrorOrDiagnostic::Diagnostic(d) => { + d.category == react_compiler_diagnostics::ErrorCategory::Invariant + || d.category == react_compiler_diagnostics::ErrorCategory::Todo + } + react_compiler_diagnostics::CompilerErrorOrDiagnostic::ErrorDetail(d) => { + d.category == react_compiler_diagnostics::ErrorCategory::Invariant + || d.category == react_compiler_diagnostics::ErrorCategory::Todo + } + }; + if is_thrown { + thrown.details.push(detail); + } else { + remaining.details.push(detail); + } + } + self.errors = remaining; + thrown + } + + /// Check if a binding has been hoisted (via DeclareContext) already. + pub fn is_hoisted_identifier(&self, binding_id: u32) -> bool { + self.hoisted_identifiers.contains(&binding_id) + } + + /// Mark a binding as hoisted. + pub fn add_hoisted_identifier(&mut self, binding_id: u32) { + self.hoisted_identifiers.insert(binding_id); + } + + // ========================================================================= + // Type resolution methods (ported from Environment.ts) + // ========================================================================= + + /// Resolve a non-local binding to its type. Ported from TS + /// `getGlobalDeclaration`. + /// + /// The `loc` parameter is used for error diagnostics when validating module + /// type configurations. Pass `None` if no source location is available. + pub fn get_global_declaration( + &mut self, + binding: &NonLocalBinding, + loc: Option, + ) -> Result, CompilerError> { + match binding { + NonLocalBinding::ModuleLocal { name, .. } => { + if is_hook_name(name) { + Ok(Some(self.get_custom_hook_type())) + } else { + Ok(None) + } + } + NonLocalBinding::Global { name, .. 
} => { + if let Some(ty) = self.globals.get(name) { + return Ok(Some(ty.clone())); + } + if is_hook_name(name) { + Ok(Some(self.get_custom_hook_type())) + } else { + Ok(None) + } + } + NonLocalBinding::ImportSpecifier { + name, + module, + imported, + } => { + if self.is_known_react_module(module) { + if let Some(ty) = self.globals.get(imported) { + return Ok(Some(ty.clone())); + } + if is_hook_name(imported) || is_hook_name(name) { + return Ok(Some(self.get_custom_hook_type())); + } + return Ok(None); + } + + // Try module type provider. We resolve first, then do property + // lookup on the cloned result to avoid double-borrow of self. + let module_type = self.resolve_module_type(module); + + // Check for module type validation errors (hook-name vs hook-type mismatches) + if let Some(errors) = self.module_type_errors.remove(module.as_str()) { + if let Some(first_error) = errors.into_iter().next() { + self.record_error( + CompilerErrorDetail::new( + ErrorCategory::Config, + "Invalid type configuration for module", + ) + .with_description(format!("{}", first_error)) + .with_loc(loc), + )?; + } + } + + if let Some(module_type) = module_type { + if let Some(imported_type) = + Self::get_property_type_from_shapes(&self.shapes, &module_type, imported) + { + return Ok(Some(imported_type)); + } + } + + if is_hook_name(imported) || is_hook_name(name) { + Ok(Some(self.get_custom_hook_type())) + } else { + Ok(None) + } + } + NonLocalBinding::ImportDefault { name, module } + | NonLocalBinding::ImportNamespace { name, module } => { + let is_default = matches!(binding, NonLocalBinding::ImportDefault { .. 
}); + + if self.is_known_react_module(module) { + if let Some(ty) = self.globals.get(name) { + return Ok(Some(ty.clone())); + } + if is_hook_name(name) { + return Ok(Some(self.get_custom_hook_type())); + } + return Ok(None); + } + + let module_type = self.resolve_module_type(module); + + // Check for module type validation errors (hook-name vs hook-type mismatches) + if let Some(errors) = self.module_type_errors.remove(module.as_str()) { + if let Some(first_error) = errors.into_iter().next() { + self.record_error( + CompilerErrorDetail::new( + ErrorCategory::Config, + "Invalid type configuration for module", + ) + .with_description(format!("{}", first_error)) + .with_loc(loc), + )?; + } + } + + if let Some(module_type) = module_type { + let imported_type = if is_default { + Self::get_property_type_from_shapes(&self.shapes, &module_type, "default") + } else { + Some(module_type) + }; + if let Some(imported_type) = imported_type { + // Validate hook-name vs hook-type consistency for module name + let expect_hook = is_hook_name(module); + let is_hook = self + .get_hook_kind_for_type(&imported_type) + .ok() + .flatten() + .is_some(); + if expect_hook != is_hook { + self.record_error( + CompilerErrorDetail::new( + ErrorCategory::Config, + "Invalid type configuration for module", + ) + .with_description(format!( + "Expected type for `import ... from '{}'` {} based on the \ + module name", + module, + if expect_hook { + "to be a hook" + } else { + "not to be a hook" + } + )) + .with_loc(loc), + )?; + } + return Ok(Some(imported_type)); + } + } + + if is_hook_name(name) { + Ok(Some(self.get_custom_hook_type())) + } else { + Ok(None) + } + } + } + } + + /// Static helper: resolve a property type using only the shapes registry. + /// Used internally to avoid double-borrow of `self`. Includes hook-name + /// fallback matching TS `getPropertyType`. 
+ fn get_property_type_from_shapes( + shapes: &ShapeRegistry, + receiver: &Type, + property: &str, + ) -> Option { + let shape_id = match receiver { + Type::Object { shape_id } | Type::Function { shape_id, .. } => shape_id.as_deref(), + _ => None, + }; + if let Some(shape_id) = shape_id { + let shape = shapes.get(shape_id)?; + if let Some(ty) = shape.properties.get(property) { + return Some(ty.clone()); + } + if let Some(ty) = shape.properties.get("*") { + return Some(ty.clone()); + } + // Hook-name fallback: callers that need the custom hook type + // check is_hook_name after this returns None, which produces + // the same result as the TS getPropertyType hook-name fallback. + } + None + } + + /// Get the type of a named property on a receiver type. + /// Ported from TS `getPropertyType`. + pub fn get_property_type( + &mut self, + receiver: &Type, + property: &str, + ) -> Result, CompilerDiagnostic> { + let shape_id = match receiver { + Type::Object { shape_id } | Type::Function { shape_id, .. } => shape_id.as_deref(), + _ => None, + }; + if let Some(shape_id) = shape_id { + let shape = self.shapes.get(shape_id).ok_or_else(|| { + CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "[HIR] Forget internal error: cannot resolve shape {}", + shape_id + ), + None, + ) + })?; + if let Some(ty) = shape.properties.get(property) { + return Ok(Some(ty.clone())); + } + // Fall through to wildcard + if let Some(ty) = shape.properties.get("*") { + return Ok(Some(ty.clone())); + } + // If property name looks like a hook, return custom hook type + if is_hook_name(property) { + return Ok(Some(self.get_custom_hook_type())); + } + return Ok(None); + } + // No shape ID — if property looks like a hook, return custom hook type + if is_hook_name(property) { + return Ok(Some(self.get_custom_hook_type())); + } + Ok(None) + } + + /// Get the type of a numeric property on a receiver type. + /// Ported from the numeric branch of TS `getPropertyType`. 
+ pub fn get_property_type_numeric( + &self, + receiver: &Type, + ) -> Result, CompilerDiagnostic> { + let shape_id = match receiver { + Type::Object { shape_id } | Type::Function { shape_id, .. } => shape_id.as_deref(), + _ => None, + }; + if let Some(shape_id) = shape_id { + let shape = self.shapes.get(shape_id).ok_or_else(|| { + CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "[HIR] Forget internal error: cannot resolve shape {}", + shape_id + ), + None, + ) + })?; + return Ok(shape.properties.get("*").cloned()); + } + Ok(None) + } + + /// Get the fallthrough (wildcard `*`) property type for computed property + /// access. Ported from TS `getFallthroughPropertyType`. + pub fn get_fallthrough_property_type( + &self, + receiver: &Type, + ) -> Result, CompilerDiagnostic> { + let shape_id = match receiver { + Type::Object { shape_id } | Type::Function { shape_id, .. } => shape_id.as_deref(), + _ => None, + }; + if let Some(shape_id) = shape_id { + let shape = self.shapes.get(shape_id).ok_or_else(|| { + CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "[HIR] Forget internal error: cannot resolve shape {}", + shape_id + ), + None, + ) + })?; + return Ok(shape.properties.get("*").cloned()); + } + Ok(None) + } + + /// Get the function signature for a function type. + /// Ported from TS `getFunctionSignature`. + pub fn get_function_signature( + &self, + ty: &Type, + ) -> Result, CompilerDiagnostic> { + let shape_id = match ty { + Type::Function { shape_id, .. } => shape_id.as_deref(), + _ => return Ok(None), + }; + if let Some(shape_id) = shape_id { + let shape = self.shapes.get(shape_id).ok_or_else(|| { + CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "[HIR] Forget internal error: cannot resolve shape {}", + shape_id + ), + None, + ) + })?; + return Ok(shape.function_type.as_ref()); + } + Ok(None) + } + + /// Get the hook kind for a type, if it represents a hook. + /// Ported from TS `getHookKindForType` in HIR.ts. 
+ pub fn get_hook_kind_for_type( + &self, + ty: &Type, + ) -> Result, CompilerDiagnostic> { + Ok(self + .get_function_signature(ty)? + .and_then(|sig| sig.hook_kind.as_ref())) + } + + /// Resolve the module type provider for a given module name. + /// Caches results. Checks pre-resolved provider results first, then falls + /// back to `defaultModuleTypeProvider` (hardcoded). + fn resolve_module_type(&mut self, module_name: &str) -> Option { + if let Some(cached) = self.module_types.get(module_name) { + return cached.clone(); + } + + // Check pre-resolved provider results first, then fall back to default + let module_config = self + .config + .module_type_provider + .as_ref() + .and_then(|map| map.get(module_name).cloned()) + .or_else(|| default_module_type_provider(module_name)); + + let module_type = module_config.map(|config| { + let mut type_errors: Vec = Vec::new(); + let ty = globals::install_type_config_with_errors( + &mut self.globals, + &mut self.shapes, + &config, + module_name, + (), + &mut type_errors, + ); + // Store errors for later reporting when the import is actually used + for err in type_errors { + self.module_type_errors + .entry(module_name.to_string()) + .or_default() + .push(err); + } + ty + }); + self.module_types + .insert(module_name.to_string(), module_type.clone()); + module_type + } + + fn is_known_react_module(&self, module_name: &str) -> bool { + let lower = module_name.to_lowercase(); + lower == "react" || lower == "react-dom" + } + + fn get_custom_hook_type(&mut self) -> Global { + if self.config.enable_assume_hooks_follow_rules_of_react { + if self.default_nonmutating_hook.is_none() { + self.default_nonmutating_hook = Some(default_nonmutating_hook(&mut self.shapes)); + } + self.default_nonmutating_hook.clone().unwrap() + } else { + if self.default_mutating_hook.is_none() { + self.default_mutating_hook = Some(default_mutating_hook(&mut self.shapes)); + } + self.default_mutating_hook.clone().unwrap() + } + } + + /// Public accessor 
for the custom hook type, used by InferTypes for + /// property resolution fallback when a property name looks like a hook. + pub fn get_custom_hook_type_opt(&mut self) -> Option { + Some(self.get_custom_hook_type()) + } + + /// Get a reference to the shapes registry. + pub fn shapes(&self) -> &ShapeRegistry { + &self.shapes + } + + /// Get a reference to the globals registry. + pub fn globals(&self) -> &GlobalRegistry { + &self.globals + } + + /// Generate a globally unique identifier name, analogous to TS + /// `generateGloballyUniqueIdentifierName` which delegates to Babel's + /// `scope.generateUidIdentifier`. Matches Babel's naming convention: + /// first name is `_`, subsequent are `_2`, `_3`, etc. + /// Also applies Babel's `toIdentifier` sanitization on the input name. + pub fn generate_globally_unique_identifier_name(&mut self, name: Option<&str>) -> String { + let base = name.unwrap_or("temp"); + // Apply Babel's toIdentifier sanitization: + // 1. Replace non-identifier chars with '-' + // 2. Strip leading '-' and digits + // 3. 
CamelCase: replace '-' sequences + optional following char with uppercase + // of that char + let mut dashed = String::new(); + for c in base.chars() { + if c.is_ascii_alphanumeric() || c == '_' || c == '$' { + dashed.push(c); + } else { + dashed.push('-'); + } + } + // Strip leading dashes and digits + let trimmed = dashed.trim_start_matches(|c: char| c == '-' || c.is_ascii_digit()); + // CamelCase conversion: replace sequences of '-' followed by optional char with + // uppercase + let mut camel = String::new(); + let mut chars = trimmed.chars().peekable(); + while let Some(c) = chars.next() { + if c == '-' { + while chars.peek() == Some(&'-') { + chars.next(); + } + if let Some(next) = chars.next() { + for uc in next.to_uppercase() { + camel.push(uc); + } + } + } else { + camel.push(c); + } + } + if camel.is_empty() { + camel = "temp".to_string(); + } + // Strip leading '_' and trailing digits (Babel's generateUid behavior) + let stripped = camel.trim_start_matches('_'); + let stripped = stripped.trim_end_matches(|c: char| c.is_ascii_digit()); + let uid_base = if stripped.is_empty() { + "temp" + } else { + stripped + }; + + self.uid_counter += 1; + if self.uid_counter <= 1 { + format!("_{}", uid_base) + } else { + format!("_{}{}", uid_base, self.uid_counter) + } + } + + /// Record an outlined function (extracted during outlineFunctions or + /// outlineJSX). Corresponds to TS `env.outlineFunction(fn, type)`. + pub fn outline_function(&mut self, func: HirFunction, fn_type: Option) { + self.outlined_functions + .push(OutlinedFunctionEntry { func, fn_type }); + } + + /// Get the outlined functions accumulated during compilation. + pub fn get_outlined_functions(&self) -> &[OutlinedFunctionEntry] { + &self.outlined_functions + } + + /// Take the outlined functions, leaving the vec empty. + pub fn take_outlined_functions(&mut self) -> Vec { + std::mem::take(&mut self.outlined_functions) + } + + /// Whether memoization is enabled for this compilation. 
+ /// Ported from TS `get enableMemoization()` in Environment.ts. + /// Returns true for client/lint modes, false for SSR. + pub fn enable_memoization(&self) -> bool { + match self.output_mode { + OutputMode::Client | OutputMode::Lint => true, + OutputMode::Ssr => false, + } + } + + /// Whether validations are enabled for this compilation. + /// Ported from TS `get enableValidations()` in Environment.ts. + pub fn enable_validations(&self) -> bool { + match self.output_mode { + OutputMode::Client | OutputMode::Lint | OutputMode::Ssr => true, + } + } + + // ========================================================================= + // Name resolution helpers + // ========================================================================= + + /// Get the user-visible name for an identifier. + /// + /// First checks the identifier's own name. If None, looks for another + /// identifier with the same `declaration_id` that has a name. This handles + /// SSA identifiers that don't carry names but share a declaration_id with + /// the original named identifier from lowering. + /// + /// This is analogous to `identifierName` on Babel's SourceLocation, + /// which the parser sets on every identifier node. 
+ pub fn identifier_name_for_id(&self, id: IdentifierId) -> Option { + let ident = &self.identifiers[id.0 as usize]; + if let Some(name) = &ident.name { + return Some(name.value().to_string()); + } + // Fall back: find another identifier with the same declaration_id that has a + // Named name + let decl_id = ident.declaration_id; + for other in &self.identifiers { + if other.declaration_id == decl_id { + if let Some(IdentifierName::Named(name)) = &other.name { + return Some(name.clone()); + } + } + } + None + } + + // ========================================================================= + // ID-based type helper methods + // ========================================================================= + + /// Check whether the function type for an identifier has a noAlias + /// signature. Looks up the identifier's type and checks its function + /// signature. + pub fn has_no_alias_signature(&self, identifier_id: IdentifierId) -> bool { + let ty = &self.types[self.identifiers[identifier_id.0 as usize].type_.0 as usize]; + self.get_function_signature(ty) + .ok() + .flatten() + .map_or(false, |sig| sig.no_alias) + } + + /// Get the hook kind for an identifier, if its type represents a hook. + /// Looks up the identifier's type and delegates to + /// `get_hook_kind_for_type`. + pub fn get_hook_kind_for_id( + &self, + identifier_id: IdentifierId, + ) -> Result, CompilerDiagnostic> { + let ty = &self.types[self.identifiers[identifier_id.0 as usize].type_.0 as usize]; + self.get_hook_kind_for_type(ty) + } +} + +impl Default for Environment { + fn default() -> Self { + Self::new() + } +} + +/// Check if a name matches the React hook naming convention: `use[A-Z0-9]`. +/// Ported from TS `isHookName` in Environment.ts. 
+pub fn is_hook_name(name: &str) -> bool { + if name.len() < 4 { + return false; + } + if !name.starts_with("use") { + return false; + } + let fourth_char = name.as_bytes()[3]; + fourth_char.is_ascii_uppercase() || fourth_char.is_ascii_digit() +} + +/// Returns true if the name follows React naming conventions (component or +/// hook). Components start with an uppercase letter; hooks match `use[A-Z0-9]`. +pub fn is_react_like_name(name: &str) -> bool { + if name.is_empty() { + return false; + } + let first_char = name.as_bytes()[0]; + if first_char.is_ascii_uppercase() { + return true; + } + is_hook_name(name) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_hook_name() { + assert!(is_hook_name("useState")); + assert!(is_hook_name("useEffect")); + assert!(is_hook_name("useMyHook")); + assert!(is_hook_name("use3rdParty")); + assert!(!is_hook_name("use")); + assert!(!is_hook_name("used")); + assert!(!is_hook_name("useless")); + assert!(!is_hook_name("User")); + assert!(!is_hook_name("foo")); + } + + #[test] + fn test_environment_has_globals() { + let env = Environment::new(); + assert!(env.globals().contains_key("useState")); + assert!(env.globals().contains_key("useEffect")); + assert!(env.globals().contains_key("useRef")); + assert!(env.globals().contains_key("Math")); + assert!(env.globals().contains_key("console")); + assert!(env.globals().contains_key("Array")); + assert!(env.globals().contains_key("Object")); + } + + #[test] + fn test_get_property_type_array() { + let mut env = Environment::new(); + let array_type = Type::Object { + shape_id: Some("BuiltInArray".to_string()), + }; + let map_type = env.get_property_type(&array_type, "map").unwrap(); + assert!(map_type.is_some()); + let push_type = env.get_property_type(&array_type, "push").unwrap(); + assert!(push_type.is_some()); + let nonexistent = env + .get_property_type(&array_type, "nonExistentMethod") + .unwrap(); + assert!(nonexistent.is_none()); + } + + #[test] + fn 
test_get_function_signature() { + let env = Environment::new(); + let use_state_type = env.globals().get("useState").unwrap(); + let sig = env.get_function_signature(use_state_type).unwrap(); + assert!(sig.is_some()); + let sig = sig.unwrap(); + assert!(sig.hook_kind.is_some()); + assert_eq!(sig.hook_kind.as_ref().unwrap(), &HookKind::UseState); + } + + #[test] + fn test_get_global_declaration() { + let mut env = Environment::new(); + // Global binding + let binding = NonLocalBinding::Global { + name: "Math".to_string(), + }; + let result = env.get_global_declaration(&binding, None).unwrap(); + assert!(result.is_some()); + + // Import from react + let binding = NonLocalBinding::ImportSpecifier { + name: "useState".to_string(), + module: "react".to_string(), + imported: "useState".to_string(), + }; + let result = env.get_global_declaration(&binding, None).unwrap(); + assert!(result.is_some()); + + // Unknown global + let binding = NonLocalBinding::Global { + name: "unknownThing".to_string(), + }; + let result = env.get_global_declaration(&binding, None).unwrap(); + assert!(result.is_none()); + + // Hook-like name gets default hook type + let binding = NonLocalBinding::Global { + name: "useCustom".to_string(), + }; + let result = env.get_global_declaration(&binding, None).unwrap(); + assert!(result.is_some()); + } +} diff --git a/crates/react_compiler_hir/src/environment_config.rs b/crates/react_compiler_hir/src/environment_config.rs new file mode 100644 index 000000000000..18d910ea71aa --- /dev/null +++ b/crates/react_compiler_hir/src/environment_config.rs @@ -0,0 +1,236 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Environment configuration, ported from EnvironmentConfigSchema in +//! Environment.ts. +//! +//! Contains feature flags and custom hook definitions that control compiler +//! behavior. 
+ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::{ + type_config::{TypeConfig, ValueKind}, + Effect, +}; + +/// External function reference (source module + import name). +/// Corresponds to TS `ExternalFunction`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ExternalFunctionConfig { + pub source: String, + pub import_specifier_name: String, +} + +/// Instrumentation configuration. +/// Corresponds to TS `InstrumentationSchema`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InstrumentationConfig { + #[serde(rename = "fn")] + pub fn_: ExternalFunctionConfig, + #[serde(default)] + pub gating: Option, + #[serde(default)] + pub global_gating: Option, +} + +/// Custom hook configuration, ported from TS `HookSchema`. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HookConfig { + pub effect_kind: Effect, + pub value_kind: ValueKind, + #[serde(default)] + pub no_alias: bool, + #[serde(default)] + pub transitive_mixed_data: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExhaustiveEffectDepsMode { + #[serde(rename = "off")] + Off, + #[serde(rename = "all")] + All, + #[serde(rename = "missing-only")] + MissingOnly, + #[serde(rename = "extra-only")] + ExtraOnly, +} + +impl Default for ExhaustiveEffectDepsMode { + fn default() -> Self { + Self::Off + } +} + +fn default_true() -> bool { + true +} + +/// Compiler environment configuration. Contains feature flags and settings. +/// +/// Fields that would require passing JS functions across the JS/Rust boundary +/// are omitted with TODO comments. The Rust port uses hardcoded defaults for +/// these (e.g., `defaultModuleTypeProvider`). 
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct EnvironmentConfig { + /// Custom hook type definitions, keyed by hook name. + #[serde(default)] + pub custom_hooks: HashMap, + + /// Pre-resolved module type provider results. + /// Map from module name to TypeConfig, computed by the JS shim. + #[serde(default)] + pub module_type_provider: Option>, + + /// Custom macro-like function names that should have their operands + /// memoized in the same scope (similar to fbt). + #[serde(default)] + pub custom_macros: Option>, + + /// If true, emit code to reset the memo cache on source file changes + /// (HMR/fast refresh). If null (None), HMR detection is conditionally + /// enabled based on NODE_ENV/__DEV__. + #[serde(default)] + pub enable_reset_cache_on_source_file_changes: Option, + + #[serde(default = "default_true")] + pub enable_preserve_existing_memoization_guarantees: bool, + #[serde(default = "default_true")] + pub validate_preserve_existing_memoization_guarantees: bool, + #[serde(default = "default_true")] + pub validate_exhaustive_memoization_dependencies: bool, + #[serde(default)] + pub validate_exhaustive_effect_dependencies: ExhaustiveEffectDepsMode, + + // TODO: flowTypeProvider — requires JS function callback. 
+ #[serde(default = "default_true")] + pub enable_optional_dependencies: bool, + #[serde(default)] + pub enable_name_anonymous_functions: bool, + #[serde(default = "default_true")] + pub validate_hooks_usage: bool, + #[serde(default = "default_true")] + pub validate_ref_access_during_render: bool, + #[serde(default = "default_true")] + pub validate_no_set_state_in_render: bool, + #[serde(default)] + pub enable_use_keyed_state: bool, + #[serde(default)] + pub validate_no_set_state_in_effects: bool, + #[serde(default)] + pub validate_no_derived_computations_in_effects: bool, + #[serde(default)] + #[serde(alias = "validateNoDerivedComputationsInEffects_exp")] + pub validate_no_derived_computations_in_effects_exp: bool, + #[serde(default)] + #[serde(alias = "validateNoJSXInTryStatements")] + pub validate_no_jsx_in_try_statements: bool, + #[serde(default)] + pub validate_static_components: bool, + #[serde(default)] + pub validate_no_capitalized_calls: Option>, + #[serde(default)] + #[serde(alias = "restrictedImports")] + pub validate_blocklisted_imports: Option>, + #[serde(default)] + pub validate_source_locations: bool, + #[serde(default)] + pub validate_no_impure_functions_in_render: bool, + #[serde(default)] + pub validate_no_freezing_known_mutable_functions: bool, + #[serde(default = "default_true")] + pub enable_assume_hooks_follow_rules_of_react: bool, + #[serde(default = "default_true")] + pub enable_transitively_freeze_function_expressions: bool, + + /// Hook guard configuration. When set, wraps hook calls with dispatcher + /// guard calls. + #[serde(default)] + pub enable_emit_hook_guards: Option, + + /// Instrumentation configuration. When set, emits calls to instrument + /// functions. 
+ #[serde(default)] + pub enable_emit_instrument_forget: Option, + + #[serde(default = "default_true")] + pub enable_function_outlining: bool, + #[serde(default)] + pub enable_jsx_outlining: bool, + #[serde(default)] + pub assert_valid_mutable_ranges: bool, + #[serde(default)] + #[serde(alias = "throwUnknownException__testonly")] + pub throw_unknown_exception_testonly: bool, + #[serde(default)] + pub enable_custom_type_definition_for_reanimated: bool, + #[serde(default = "default_true")] + pub enable_treat_ref_like_identifiers_as_refs: bool, + #[serde(default)] + pub enable_treat_set_identifiers_as_state_setters: bool, + #[serde(default = "default_true")] + pub validate_no_void_use_memo: bool, + #[serde(default = "default_true")] + pub enable_allow_set_state_from_refs_in_effects: bool, + #[serde(default)] + pub enable_verbose_no_set_state_in_effect: bool, + + // 🌲 + #[serde(default)] + pub enable_forest: bool, +} + +impl Default for EnvironmentConfig { + fn default() -> Self { + Self { + custom_hooks: HashMap::new(), + enable_reset_cache_on_source_file_changes: None, + module_type_provider: None, + enable_preserve_existing_memoization_guarantees: true, + validate_preserve_existing_memoization_guarantees: true, + validate_exhaustive_memoization_dependencies: true, + validate_exhaustive_effect_dependencies: ExhaustiveEffectDepsMode::Off, + enable_optional_dependencies: true, + enable_name_anonymous_functions: false, + validate_hooks_usage: true, + validate_ref_access_during_render: true, + validate_no_set_state_in_render: true, + enable_use_keyed_state: false, + validate_no_set_state_in_effects: false, + validate_no_derived_computations_in_effects: false, + validate_no_derived_computations_in_effects_exp: false, + validate_no_jsx_in_try_statements: false, + validate_static_components: false, + validate_no_capitalized_calls: None, + validate_blocklisted_imports: None, + validate_source_locations: false, + validate_no_impure_functions_in_render: false, + 
validate_no_freezing_known_mutable_functions: false, + enable_assume_hooks_follow_rules_of_react: true, + enable_transitively_freeze_function_expressions: true, + enable_emit_hook_guards: None, + enable_emit_instrument_forget: None, + enable_function_outlining: true, + enable_jsx_outlining: false, + assert_valid_mutable_ranges: false, + throw_unknown_exception_testonly: false, + enable_custom_type_definition_for_reanimated: false, + enable_treat_ref_like_identifiers_as_refs: true, + enable_treat_set_identifiers_as_state_setters: false, + validate_no_void_use_memo: true, + enable_allow_set_state_from_refs_in_effects: true, + enable_verbose_no_set_state_in_effect: false, + enable_forest: false, + custom_macros: None, + } + } +} diff --git a/crates/react_compiler_hir/src/globals.rs b/crates/react_compiler_hir/src/globals.rs new file mode 100644 index 000000000000..d9131e18b177 --- /dev/null +++ b/crates/react_compiler_hir/src/globals.rs @@ -0,0 +1,2582 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Global type registry and built-in shape definitions, ported from Globals.ts. +//! +//! Provides `DEFAULT_SHAPES` (built-in object shapes) and `DEFAULT_GLOBALS` +//! (global variable types including React hooks and JS built-ins). + +use std::{collections::HashMap, sync::LazyLock}; + +use crate::{ + object_shape::*, + type_config::{ + AliasingEffectConfig, AliasingSignatureConfig, ApplyArgConfig, ApplyArgHoleKind, + BuiltInTypeRef, TypeConfig, TypeReferenceConfig, ValueKind, ValueReason, + }, + Effect, Type, +}; + +/// Type alias matching TS `Global = BuiltInType | PolyType`. +/// In the Rust port, both map to our `Type` enum. +pub type Global = Type; + +/// Registry mapping global names to their types. 
+/// +/// Supports two modes: +/// - **Builder mode** (`base=None`): wraps a single HashMap, used during +/// `build_default_globals` to construct the static base. +/// - **Overlay mode** (`base=Some`): holds a `&'static HashMap` base plus a +/// small extras HashMap. Lookups check extras first, then base. Inserts go +/// into extras. Cloning only copies the extras map (the base pointer is +/// shared). +pub struct GlobalRegistry { + base: Option<&'static HashMap>, + entries: HashMap, +} + +impl GlobalRegistry { + /// Create an empty builder-mode registry. + pub fn new() -> Self { + Self { + base: None, + entries: HashMap::new(), + } + } + + /// Create an overlay-mode registry backed by a static base. + pub fn with_base(base: &'static HashMap) -> Self { + Self { + base: Some(base), + entries: HashMap::new(), + } + } + + pub fn get(&self, key: &str) -> Option<&Global> { + self.entries + .get(key) + .or_else(|| self.base.and_then(|b| b.get(key))) + } + + pub fn insert(&mut self, key: String, value: Global) { + self.entries.insert(key, value); + } + + pub fn contains_key(&self, key: &str) -> bool { + self.entries.contains_key(key) || self.base.map_or(false, |b| b.contains_key(key)) + } + + /// Iterate over all keys in the registry (base + extras). + /// Keys in extras that shadow base keys appear only once. + pub fn keys(&self) -> impl Iterator { + let base_keys = self + .base + .into_iter() + .flat_map(|b| b.keys()) + .filter(|k| !self.entries.contains_key(k.as_str())); + self.entries.keys().chain(base_keys) + } + + /// Consume the registry and return the inner HashMap. + /// Only valid in builder mode (no base). 
+ pub fn into_inner(self) -> HashMap { + debug_assert!( + self.base.is_none(), + "into_inner() called on overlay-mode GlobalRegistry" + ); + self.entries + } +} + +impl Clone for GlobalRegistry { + fn clone(&self) -> Self { + Self { + base: self.base, + entries: self.entries.clone(), + } + } +} + +// ============================================================================= +// Static base registries (initialized once, shared across all Environments) +// ============================================================================= + +struct BaseRegistries { + shapes: HashMap, + globals: HashMap, +} + +static BASE: LazyLock = LazyLock::new(|| { + let mut shapes = build_builtin_shapes(); + let globals = build_default_globals(&mut shapes); + BaseRegistries { + shapes: shapes.into_inner(), + globals: globals.into_inner(), + } +}); + +/// Get a reference to the static base shapes registry. +pub fn base_shapes() -> &'static HashMap { + &BASE.shapes +} + +/// Get a reference to the static base globals registry. +pub fn base_globals() -> &'static HashMap { + &BASE.globals +} + +// ============================================================================= +// installTypeConfig — converts TypeConfig to internal Type +// ============================================================================= + +/// Convert a user-provided TypeConfig into an internal Type, registering shapes +/// as needed. Ported from TS `installTypeConfig` in Globals.ts. +/// If `errors` is provided, hook-name vs hook-type consistency validation +/// errors are collected there. +pub fn install_type_config( + _globals: &mut GlobalRegistry, + shapes: &mut ShapeRegistry, + type_config: &TypeConfig, + module_name: &str, + _loc: (), +) -> Global { + install_type_config_inner(_globals, shapes, type_config, module_name, _loc, &mut None) +} + +/// Like `install_type_config` but collects validation errors. 
+pub fn install_type_config_with_errors( + _globals: &mut GlobalRegistry, + shapes: &mut ShapeRegistry, + type_config: &TypeConfig, + module_name: &str, + _loc: (), + errors: &mut Vec, +) -> Global { + install_type_config_inner( + _globals, + shapes, + type_config, + module_name, + _loc, + &mut Some(errors), + ) +} + +fn install_type_config_inner( + _globals: &mut GlobalRegistry, + shapes: &mut ShapeRegistry, + type_config: &TypeConfig, + module_name: &str, + _loc: (), + errors: &mut Option<&mut Vec>, +) -> Global { + match type_config { + TypeConfig::TypeReference(TypeReferenceConfig { name }) => match name { + BuiltInTypeRef::Array => Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + BuiltInTypeRef::MixedReadonly => Type::Object { + shape_id: Some(BUILT_IN_MIXED_READONLY_ID.to_string()), + }, + BuiltInTypeRef::Primitive => Type::Primitive, + BuiltInTypeRef::Ref => Type::Object { + shape_id: Some(BUILT_IN_USE_REF_ID.to_string()), + }, + BuiltInTypeRef::Any => Type::Poly, + }, + TypeConfig::Function(func_config) => { + // Compute return type first to avoid double-borrow of shapes + let return_type = install_type_config_inner( + _globals, + shapes, + &func_config.return_type, + module_name, + (), + errors, + ); + add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: func_config.positional_params.clone(), + rest_param: func_config.rest_param, + callee_effect: func_config.callee_effect, + return_type, + return_value_kind: func_config.return_value_kind, + no_alias: func_config.no_alias.unwrap_or(false), + mutable_only_if_operands_are_mutable: func_config + .mutable_only_if_operands_are_mutable + .unwrap_or(false), + impure: func_config.impure.unwrap_or(false), + canonical_name: func_config.canonical_name.clone(), + aliasing: func_config.aliasing.clone(), + known_incompatible: func_config.known_incompatible.clone(), + ..Default::default() + }, + None, + false, + ) + } + TypeConfig::Hook(hook_config) => { + // 
Compute return type first to avoid double-borrow of shapes + let return_type = install_type_config_inner( + _globals, + shapes, + &hook_config.return_type, + module_name, + (), + errors, + ); + add_hook( + shapes, + HookSignatureBuilder { + hook_kind: HookKind::Custom, + positional_params: hook_config.positional_params.clone().unwrap_or_default(), + rest_param: hook_config.rest_param.or(Some(Effect::Freeze)), + callee_effect: Effect::Read, + return_type, + return_value_kind: hook_config.return_value_kind.unwrap_or(ValueKind::Frozen), + no_alias: hook_config.no_alias.unwrap_or(false), + aliasing: hook_config.aliasing.clone(), + known_incompatible: hook_config.known_incompatible.clone(), + ..Default::default() + }, + None, + ) + } + TypeConfig::Object(obj_config) => { + let properties: Vec<(String, Type)> = obj_config + .properties + .as_ref() + .map(|props| { + props + .iter() + .map(|(key, value)| { + let ty = install_type_config_inner( + _globals, + shapes, + value, + module_name, + (), + errors, + ); + // Validate hook-name vs hook-type consistency (matching TS + // installTypeConfig) + if let Some(errs) = errors { + let expect_hook = crate::environment::is_hook_name(key); + let is_hook = match &ty { + Type::Function { + shape_id: Some(id), .. 
+ } => shapes + .get(id) + .and_then(|shape| shape.function_type.as_ref()) + .and_then(|ft| ft.hook_kind.as_ref()) + .is_some(), + _ => false, + }; + if expect_hook != is_hook { + errs.push(format!( + "Expected type for object property '{}' from module '{}' \ + {} based on the property name", + key, + module_name, + if expect_hook { + "to be a hook" + } else { + "not to be a hook" + } + )); + } + } + (key.clone(), ty) + }) + .collect() + }) + .unwrap_or_default(); + add_object(shapes, None, properties) + } + } +} + +// ============================================================================= +// Build built-in shapes (BUILTIN_SHAPES from ObjectShape.ts) +// ============================================================================= + +/// Build the built-in shapes registry. This corresponds to TS `BUILTIN_SHAPES` +/// defined at module level in ObjectShape.ts. +pub fn build_builtin_shapes() -> ShapeRegistry { + let mut shapes = ShapeRegistry::new(); + + // BuiltInProps: { ref: UseRefType } + add_object( + &mut shapes, + Some(BUILT_IN_PROPS_ID), + vec![( + "ref".to_string(), + Type::Object { + shape_id: Some(BUILT_IN_USE_REF_ID.to_string()), + }, + )], + ); + + build_array_shape(&mut shapes); + build_set_shape(&mut shapes); + build_map_shape(&mut shapes); + build_weak_set_shape(&mut shapes); + build_weak_map_shape(&mut shapes); + build_object_shape(&mut shapes); + build_ref_shapes(&mut shapes); + build_state_shapes(&mut shapes); + build_hook_shapes(&mut shapes); + build_misc_shapes(&mut shapes); + + shapes +} + +fn simple_function( + shapes: &mut ShapeRegistry, + positional_params: Vec, + rest_param: Option, + return_type: Type, + return_value_kind: ValueKind, +) -> Type { + add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params, + rest_param, + return_type, + return_value_kind, + ..Default::default() + }, + None, + false, + ) +} + +/// Shorthand for a pure function returning Primitive. 
+fn pure_primitive_fn(shapes: &mut ShapeRegistry) -> Type { + simple_function( + shapes, + Vec::new(), + Some(Effect::Read), + Type::Primitive, + ValueKind::Primitive, + ) +} + +fn build_array_shape(shapes: &mut ShapeRegistry) { + let index_of = pure_primitive_fn(shapes); + let includes = pure_primitive_fn(shapes); + let pop = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Store, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let at = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let concat = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Capture), + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + callee_effect: Effect::Capture, + ..Default::default() + }, + None, + false, + ); + let join = pure_primitive_fn(shapes); + let flat = simple_function( + shapes, + Vec::new(), + Some(Effect::Read), + Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + ValueKind::Mutable, + ); + let to_reversed = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Capture, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let slice = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + callee_effect: Effect::Capture, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let map = 
add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + callee_effect: Effect::ConditionallyMutate, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + aliasing: Some(AliasingSignatureConfig { + receiver: "@receiver".to_string(), + params: vec!["@callback".to_string()], + rest: None, + returns: "@returns".to_string(), + temporaries: vec![ + "@item".to_string(), + "@callbackReturn".to_string(), + "@thisArg".to_string(), + ], + effects: vec![ + // Map creates a new mutable array + AliasingEffectConfig::Create { + into: "@returns".to_string(), + value: ValueKind::Mutable, + reason: ValueReason::KnownReturnSignature, + }, + // The first arg to the callback is an item extracted from the receiver array + AliasingEffectConfig::CreateFrom { + from: "@receiver".to_string(), + into: "@item".to_string(), + }, + // The undefined this for the callback + AliasingEffectConfig::Create { + into: "@thisArg".to_string(), + value: ValueKind::Primitive, + reason: ValueReason::KnownReturnSignature, + }, + // Calls the callback, returning the result into a temporary + AliasingEffectConfig::Apply { + receiver: "@thisArg".to_string(), + function: "@callback".to_string(), + mutates_function: false, + args: vec![ + ApplyArgConfig::Place("@item".to_string()), + ApplyArgConfig::Hole { + kind: ApplyArgHoleKind::Hole, + }, + ApplyArgConfig::Place("@receiver".to_string()), + ], + into: "@callbackReturn".to_string(), + }, + // Captures the result of the callback into the return array + AliasingEffectConfig::Capture { + from: "@callbackReturn".to_string(), + into: "@returns".to_string(), + }, + ], + }), + ..Default::default() + }, + None, + false, + ); + let filter = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + callee_effect: 
Effect::ConditionallyMutate, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let find = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + callee_effect: Effect::ConditionallyMutate, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let find_index = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + callee_effect: Effect::ConditionallyMutate, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let find_last = find.clone(); + let find_last_index = find_index.clone(); + let reduce = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + callee_effect: Effect::ConditionallyMutate, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let reduce_right = reduce.clone(); + let for_each = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + callee_effect: Effect::ConditionallyMutate, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let every = for_each.clone(); + let some = for_each.clone(); + let flat_map = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: 
Some(Effect::ConditionallyMutate), + callee_effect: Effect::ConditionallyMutate, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let sort = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + rest_param: None, + callee_effect: Effect::Store, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let to_sorted = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + rest_param: None, + callee_effect: Effect::Capture, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let to_spliced = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Capture), + callee_effect: Effect::Capture, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let push = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Capture), + callee_effect: Effect::Store, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + aliasing: Some(AliasingSignatureConfig { + receiver: "@receiver".to_string(), + params: Vec::new(), + rest: Some("@rest".to_string()), + returns: "@returns".to_string(), + temporaries: Vec::new(), + effects: vec![ + // Push directly mutates the array itself + AliasingEffectConfig::Mutate { + value: "@receiver".to_string(), + }, + // The arguments are captured into the array + AliasingEffectConfig::Capture { + from: 
"@rest".to_string(), + into: "@receiver".to_string(), + }, + // Returns the new length, a primitive + AliasingEffectConfig::Create { + into: "@returns".to_string(), + value: ValueKind::Primitive, + reason: ValueReason::KnownReturnSignature, + }, + ], + }), + ..Default::default() + }, + None, + false, + ); + let length = Type::Primitive; + let reverse = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Store, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let fill = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Capture), + callee_effect: Effect::Store, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let splice = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Capture), + callee_effect: Effect::Store, + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let unshift = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Capture), + callee_effect: Effect::Store, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let keys = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let values = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, 
+ None, + false, + ); + let entries = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let to_string = pure_primitive_fn(shapes); + let last_index_of = pure_primitive_fn(shapes); + + add_object( + shapes, + Some(BUILT_IN_ARRAY_ID), + vec![ + ("indexOf".to_string(), index_of), + ("includes".to_string(), includes), + ("pop".to_string(), pop), + ("at".to_string(), at), + ("concat".to_string(), concat), + ("join".to_string(), join), + ("flat".to_string(), flat), + ("toReversed".to_string(), to_reversed), + ("slice".to_string(), slice), + ("map".to_string(), map), + ("filter".to_string(), filter), + ("find".to_string(), find), + ("findIndex".to_string(), find_index), + ("findLast".to_string(), find_last), + ("findLastIndex".to_string(), find_last_index), + ("reduce".to_string(), reduce), + ("reduceRight".to_string(), reduce_right), + ("forEach".to_string(), for_each), + ("every".to_string(), every), + ("some".to_string(), some), + ("flatMap".to_string(), flat_map), + ("sort".to_string(), sort), + ("toSorted".to_string(), to_sorted), + ("toSpliced".to_string(), to_spliced), + ("push".to_string(), push), + ("length".to_string(), length), + ("reverse".to_string(), reverse), + ("fill".to_string(), fill), + ("splice".to_string(), splice), + ("unshift".to_string(), unshift), + ("keys".to_string(), keys), + ("values".to_string(), values), + ("entries".to_string(), entries), + ("toString".to_string(), to_string), + ("lastIndexOf".to_string(), last_index_of), + ], + ); +} + +fn build_set_shape(shapes: &mut ShapeRegistry) { + let has = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let add = add_function( + shapes, + 
Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Capture], + callee_effect: Effect::Store, + return_type: Type::Object { + shape_id: Some(BUILT_IN_SET_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + aliasing: Some(AliasingSignatureConfig { + receiver: "@receiver".to_string(), + params: Vec::new(), + rest: Some("@rest".to_string()), + returns: "@returns".to_string(), + temporaries: Vec::new(), + effects: vec![ + // Set.add returns the receiver Set + AliasingEffectConfig::Assign { + from: "@receiver".to_string(), + into: "@returns".to_string(), + }, + // Set.add mutates the set itself + AliasingEffectConfig::Mutate { + value: "@receiver".to_string(), + }, + // Captures the rest params into the set + AliasingEffectConfig::Capture { + from: "@rest".to_string(), + into: "@receiver".to_string(), + }, + ], + }), + ..Default::default() + }, + None, + false, + ); + let clear = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Store, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let delete = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + callee_effect: Effect::Store, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let size = Type::Primitive; + let difference = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Capture], + callee_effect: Effect::Capture, + return_type: Type::Object { + shape_id: Some(BUILT_IN_SET_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let union = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Capture], + callee_effect: Effect::Capture, + return_type: Type::Object { + shape_id: 
Some(BUILT_IN_SET_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let symmetrical_difference = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Capture], + callee_effect: Effect::Capture, + return_type: Type::Object { + shape_id: Some(BUILT_IN_SET_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let is_subset_of = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + callee_effect: Effect::Read, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let is_superset_of = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + callee_effect: Effect::Read, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let for_each = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + callee_effect: Effect::ConditionallyMutate, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let values = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let keys = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let entries = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: 
Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + + add_object( + shapes, + Some(BUILT_IN_SET_ID), + vec![ + ("add".to_string(), add), + ("clear".to_string(), clear), + ("delete".to_string(), delete), + ("has".to_string(), has), + ("size".to_string(), size), + ("difference".to_string(), difference), + ("union".to_string(), union), + ("symmetricalDifference".to_string(), symmetrical_difference), + ("isSubsetOf".to_string(), is_subset_of), + ("isSupersetOf".to_string(), is_superset_of), + ("forEach".to_string(), for_each), + ("values".to_string(), values), + ("keys".to_string(), keys), + ("entries".to_string(), entries), + ], + ); +} + +fn build_map_shape(shapes: &mut ShapeRegistry) { + let has = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let get = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let clear = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Store, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let set = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Capture, Effect::Capture], + callee_effect: Effect::Store, + return_type: Type::Object { + shape_id: Some(BUILT_IN_MAP_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let delete = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: 
vec![Effect::Read], + callee_effect: Effect::Store, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let size = Type::Primitive; + let for_each = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + callee_effect: Effect::ConditionallyMutate, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let values = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let keys = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let entries = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + + add_object( + shapes, + Some(BUILT_IN_MAP_ID), + vec![ + ("has".to_string(), has), + ("get".to_string(), get), + ("set".to_string(), set), + ("clear".to_string(), clear), + ("delete".to_string(), delete), + ("size".to_string(), size), + ("forEach".to_string(), for_each), + ("values".to_string(), values), + ("keys".to_string(), keys), + ("entries".to_string(), entries), + ], + ); +} + +fn build_weak_set_shape(shapes: &mut ShapeRegistry) { + let has = pure_primitive_fn(shapes); + let add = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Capture], + callee_effect: Effect::Store, + return_type: Type::Object { + shape_id: Some(BUILT_IN_WEAK_SET_ID.to_string()), + 
}, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let delete = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + callee_effect: Effect::Store, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + + add_object( + shapes, + Some(BUILT_IN_WEAK_SET_ID), + vec![ + ("has".to_string(), has), + ("add".to_string(), add), + ("delete".to_string(), delete), + ], + ); +} + +fn build_weak_map_shape(shapes: &mut ShapeRegistry) { + let has = pure_primitive_fn(shapes); + let get = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + callee_effect: Effect::Capture, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let set = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Capture, Effect::Capture], + callee_effect: Effect::Store, + return_type: Type::Object { + shape_id: Some(BUILT_IN_WEAK_MAP_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let delete = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + callee_effect: Effect::Store, + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + + add_object( + shapes, + Some(BUILT_IN_WEAK_MAP_ID), + vec![ + ("has".to_string(), has), + ("get".to_string(), get), + ("set".to_string(), set), + ("delete".to_string(), delete), + ], + ); +} + +fn build_object_shape(shapes: &mut ShapeRegistry) { + // BuiltInObject: has toString() returning Primitive (matches TS BuiltInObjectId + // shape) + let to_string = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + return_type: 
Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + add_object( + shapes, + Some(BUILT_IN_OBJECT_ID), + vec![("toString".to_string(), to_string)], + ); + // BuiltInFunction: empty shape + add_object(shapes, Some(BUILT_IN_FUNCTION_ID), Vec::new()); + // BuiltInJsx: empty shape + add_object(shapes, Some(BUILT_IN_JSX_ID), Vec::new()); + // BuiltInMixedReadonly: has explicit method types + wildcard returning + // MixedReadonly (matches TS BuiltInMixedReadonlyId shape) + let mixed_to_string = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let mixed_index_of = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let mixed_includes = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let mixed_at = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + return_type: Type::Object { + shape_id: Some(BUILT_IN_MIXED_READONLY_ID.to_string()), + }, + callee_effect: Effect::Capture, + return_value_kind: ValueKind::Frozen, + ..Default::default() + }, + None, + false, + ); + let mixed_map = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + callee_effect: Effect::ConditionallyMutate, + return_value_kind: ValueKind::Mutable, + no_alias: true, + ..Default::default() + }, + None, + false, + ); + let 
mixed_flat_map = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + callee_effect: Effect::ConditionallyMutate, + return_value_kind: ValueKind::Mutable, + no_alias: true, + ..Default::default() + }, + None, + false, + ); + let mixed_filter = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + callee_effect: Effect::ConditionallyMutate, + return_value_kind: ValueKind::Mutable, + no_alias: true, + ..Default::default() + }, + None, + false, + ); + let mixed_concat = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Capture), + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + callee_effect: Effect::Capture, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let mixed_slice = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + callee_effect: Effect::Capture, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let mixed_every = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + return_type: Type::Primitive, + callee_effect: Effect::ConditionallyMutate, + return_value_kind: ValueKind::Primitive, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let mixed_some = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + return_type: Type::Primitive, + callee_effect: Effect::ConditionallyMutate, + 
return_value_kind: ValueKind::Primitive, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let mixed_find = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + return_type: Type::Object { + shape_id: Some(BUILT_IN_MIXED_READONLY_ID.to_string()), + }, + callee_effect: Effect::ConditionallyMutate, + return_value_kind: ValueKind::Frozen, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let mixed_find_index = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + return_type: Type::Primitive, + callee_effect: Effect::ConditionallyMutate, + return_value_kind: ValueKind::Primitive, + no_alias: true, + mutable_only_if_operands_are_mutable: true, + ..Default::default() + }, + None, + false, + ); + let mixed_join = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let mut mixed_props = HashMap::new(); + mixed_props.insert("toString".to_string(), mixed_to_string); + mixed_props.insert("indexOf".to_string(), mixed_index_of); + mixed_props.insert("includes".to_string(), mixed_includes); + mixed_props.insert("at".to_string(), mixed_at); + mixed_props.insert("map".to_string(), mixed_map); + mixed_props.insert("flatMap".to_string(), mixed_flat_map); + mixed_props.insert("filter".to_string(), mixed_filter); + mixed_props.insert("concat".to_string(), mixed_concat); + mixed_props.insert("slice".to_string(), mixed_slice); + mixed_props.insert("every".to_string(), mixed_every); + mixed_props.insert("some".to_string(), mixed_some); + mixed_props.insert("find".to_string(), mixed_find); + mixed_props.insert("findIndex".to_string(), mixed_find_index); + 
mixed_props.insert("join".to_string(), mixed_join); + mixed_props.insert( + "*".to_string(), + Type::Object { + shape_id: Some(BUILT_IN_MIXED_READONLY_ID.to_string()), + }, + ); + shapes.insert( + BUILT_IN_MIXED_READONLY_ID.to_string(), + ObjectShape { + properties: mixed_props, + function_type: None, + }, + ); +} + +fn build_ref_shapes(shapes: &mut ShapeRegistry) { + // BuiltInUseRefId: { current: Object { shapeId: BuiltInRefValue } } + add_object( + shapes, + Some(BUILT_IN_USE_REF_ID), + vec![( + "current".to_string(), + Type::Object { + shape_id: Some(BUILT_IN_REF_VALUE_ID.to_string()), + }, + )], + ); + // BuiltInRefValue: { *: Object { shapeId: BuiltInRefValue } } + // (self-referencing) + add_object( + shapes, + Some(BUILT_IN_REF_VALUE_ID), + vec![( + "*".to_string(), + Type::Object { + shape_id: Some(BUILT_IN_REF_VALUE_ID.to_string()), + }, + )], + ); +} + +fn build_state_shapes(shapes: &mut ShapeRegistry) { + // BuiltInSetState: function that freezes its argument + let set_state = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + Some(BUILT_IN_SET_STATE_ID), + false, + ); + + // BuiltInUseState: object with [0] = Poly (state), [1] = setState function + add_object( + shapes, + Some(BUILT_IN_USE_STATE_ID), + vec![("0".to_string(), Type::Poly), ("1".to_string(), set_state)], + ); + + // BuiltInSetActionState + let set_action_state = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + Some(BUILT_IN_SET_ACTION_STATE_ID), + false, + ); + + // BuiltInUseActionState: [0] = Poly, [1] = setActionState function + add_object( + shapes, + Some(BUILT_IN_USE_ACTION_STATE_ID), + vec![ + ("0".to_string(), Type::Poly), + ("1".to_string(), set_action_state), + ], + ); 
+ + // BuiltInDispatch + let dispatch = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + Some(BUILT_IN_DISPATCH_ID), + false, + ); + + // BuiltInUseReducer: [0] = Poly, [1] = dispatch function + add_object( + shapes, + Some(BUILT_IN_USE_REDUCER_ID), + vec![("0".to_string(), Type::Poly), ("1".to_string(), dispatch)], + ); + + // BuiltInStartTransition + let start_transition = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + // Note: TS uses restParam: null for startTransition + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + Some(BUILT_IN_START_TRANSITION_ID), + false, + ); + + // BuiltInUseTransition: [0] = Primitive (isPending), [1] = startTransition + // function + add_object( + shapes, + Some(BUILT_IN_USE_TRANSITION_ID), + vec![ + ("0".to_string(), Type::Primitive), + ("1".to_string(), start_transition), + ], + ); + + // BuiltInSetOptimistic + let set_optimistic = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + Some(BUILT_IN_SET_OPTIMISTIC_ID), + false, + ); + + // BuiltInUseOptimistic: [0] = Poly, [1] = setOptimistic function + add_object( + shapes, + Some(BUILT_IN_USE_OPTIMISTIC_ID), + vec![ + ("0".to_string(), Type::Poly), + ("1".to_string(), set_optimistic), + ], + ); +} + +fn build_hook_shapes(shapes: &mut ShapeRegistry) { + // BuiltInEffectEvent function shape (the return value of useEffectEvent) + add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + callee_effect: Effect::ConditionallyMutate, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + 
Some(BUILT_IN_EFFECT_EVENT_ID), + false, + ); +} + +fn build_misc_shapes(shapes: &mut ShapeRegistry) { + // ReanimatedSharedValue: empty properties (matching TS) + add_object(shapes, Some(REANIMATED_SHARED_VALUE_ID), Vec::new()); +} + +/// Build the reanimated module type. Ported from TS `getReanimatedModuleType`. +pub fn get_reanimated_module_type(shapes: &mut ShapeRegistry) -> Type { + let mut reanimated_type: Vec<(String, Type)> = Vec::new(); + + // hooks that freeze args and return frozen value + let frozen_hooks = [ + "useFrameCallback", + "useAnimatedStyle", + "useAnimatedProps", + "useAnimatedScrollHandler", + "useAnimatedReaction", + "useWorkletCallback", + ]; + for hook in &frozen_hooks { + let hook_type = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + no_alias: true, + hook_kind: HookKind::Custom, + ..Default::default() + }, + None, + ); + reanimated_type.push((hook.to_string(), hook_type)); + } + + // hooks that return a mutable value (modelled as shared value) + let mutable_hooks = ["useSharedValue", "useDerivedValue"]; + for hook in &mutable_hooks { + let hook_type = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Object { + shape_id: Some(REANIMATED_SHARED_VALUE_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + no_alias: true, + hook_kind: HookKind::Custom, + ..Default::default() + }, + None, + ); + reanimated_type.push((hook.to_string(), hook_type)); + } + + // functions that return mutable value + let funcs = [ + "withTiming", + "withSpring", + "createAnimatedPropAdapter", + "withDecay", + "withRepeat", + "runOnUI", + "executeOnUIRuntimeSync", + ]; + for func_name in &funcs { + let func_type = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + no_alias: true, + 
..Default::default() + }, + None, + false, + ); + reanimated_type.push((func_name.to_string(), func_type)); + } + + add_object(shapes, None, reanimated_type) +} + +// ============================================================================= +// Build default globals (DEFAULT_GLOBALS from Globals.ts) +// ============================================================================= + +/// Build the default globals registry. This corresponds to TS +/// `DEFAULT_GLOBALS`. +/// +/// Requires a mutable reference to the shapes registry because some globals +/// (like Object.keys, Array.isArray) register new shapes. +pub fn build_default_globals(shapes: &mut ShapeRegistry) -> GlobalRegistry { + let mut globals = GlobalRegistry::new(); + + // React APIs — returns the list so we can reuse them for the React namespace + let react_apis = build_react_apis(shapes, &mut globals); + + // Untyped globals (treated as Poly) — must come before typed globals + // so typed definitions take priority (matching TS ordering) + for name in UNTYPED_GLOBALS { + globals.insert(name.to_string(), Type::Poly); + } + + // Typed JS globals (overwrites Poly entries from UNTYPED_GLOBALS). + // Returns the list of typed globals for use as globalThis/global properties. 
+ let typed_globals = build_typed_globals(shapes, &mut globals, react_apis); + + // globalThis and global — populated with all typed globals as properties + // (matching TS: `addObject(DEFAULT_SHAPES, 'globalThis', TYPED_GLOBALS)`) + globals.insert( + "globalThis".to_string(), + add_object(shapes, Some("globalThis"), typed_globals.clone()), + ); + globals.insert( + "global".to_string(), + add_object(shapes, Some("global"), typed_globals), + ); + + globals +} + +const UNTYPED_GLOBALS: &[&str] = &[ + "Object", + "Function", + "RegExp", + "Date", + "Error", + "TypeError", + "RangeError", + "ReferenceError", + "SyntaxError", + "URIError", + "EvalError", + "DataView", + "Float32Array", + "Float64Array", + "Int8Array", + "Int16Array", + "Int32Array", + "WeakMap", + "Uint8Array", + "Uint8ClampedArray", + "Uint16Array", + "Uint32Array", + "ArrayBuffer", + "JSON", + "console", + "eval", +]; + +/// Build the React API types (REACT_APIS from TS). Returns the list of (name, +/// type) pairs so they can be reused as properties of the React namespace +/// object (matching TS behavior where the SAME type objects are used in both +/// DEFAULT_GLOBALS and the React namespace). 
+fn build_react_apis( + shapes: &mut ShapeRegistry, + globals: &mut GlobalRegistry, +) -> Vec<(String, Type)> { + let mut react_apis: Vec<(String, Type)> = Vec::new(); + + // useContext + let use_context = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + return_value_reason: Some(ValueReason::Context), + hook_kind: HookKind::UseContext, + ..Default::default() + }, + Some(BUILT_IN_USE_CONTEXT_HOOK_ID), + ); + react_apis.push(("useContext".to_string(), use_context)); + + // useState + let use_state = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Object { + shape_id: Some(BUILT_IN_USE_STATE_ID.to_string()), + }, + return_value_kind: ValueKind::Frozen, + return_value_reason: Some(ValueReason::State), + hook_kind: HookKind::UseState, + ..Default::default() + }, + None, + ); + react_apis.push(("useState".to_string(), use_state)); + + // useActionState + let use_action_state = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Object { + shape_id: Some(BUILT_IN_USE_ACTION_STATE_ID.to_string()), + }, + return_value_kind: ValueKind::Frozen, + return_value_reason: Some(ValueReason::State), + hook_kind: HookKind::UseActionState, + ..Default::default() + }, + None, + ); + react_apis.push(("useActionState".to_string(), use_action_state)); + + // useReducer + let use_reducer = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Object { + shape_id: Some(BUILT_IN_USE_REDUCER_ID.to_string()), + }, + return_value_kind: ValueKind::Frozen, + return_value_reason: Some(ValueReason::ReducerState), + hook_kind: HookKind::UseReducer, + ..Default::default() + }, + None, + ); + react_apis.push(("useReducer".to_string(), use_reducer)); + + // useRef + let use_ref = add_hook( + shapes, + HookSignatureBuilder { + rest_param: 
Some(Effect::Capture), + return_type: Type::Object { + shape_id: Some(BUILT_IN_USE_REF_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + hook_kind: HookKind::UseRef, + ..Default::default() + }, + None, + ); + react_apis.push(("useRef".to_string(), use_ref)); + + // useImperativeHandle + let use_imperative_handle = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Primitive, + return_value_kind: ValueKind::Frozen, + hook_kind: HookKind::UseImperativeHandle, + ..Default::default() + }, + None, + ); + react_apis.push(("useImperativeHandle".to_string(), use_imperative_handle)); + + // useMemo + let use_memo = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + hook_kind: HookKind::UseMemo, + ..Default::default() + }, + None, + ); + react_apis.push(("useMemo".to_string(), use_memo)); + + // useCallback + let use_callback = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + hook_kind: HookKind::UseCallback, + ..Default::default() + }, + None, + ); + react_apis.push(("useCallback".to_string(), use_callback)); + + // useEffect (with aliasing signature) + let use_effect = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Primitive, + return_value_kind: ValueKind::Frozen, + hook_kind: HookKind::UseEffect, + aliasing: Some(AliasingSignatureConfig { + receiver: "@receiver".to_string(), + params: Vec::new(), + rest: Some("@rest".to_string()), + returns: "@returns".to_string(), + temporaries: vec!["@effect".to_string()], + effects: vec![ + AliasingEffectConfig::Freeze { + value: "@rest".to_string(), + reason: ValueReason::Effect, + }, + AliasingEffectConfig::Create { + into: "@effect".to_string(), + value: ValueKind::Frozen, + reason: ValueReason::KnownReturnSignature, 
+ }, + AliasingEffectConfig::Capture { + from: "@rest".to_string(), + into: "@effect".to_string(), + }, + AliasingEffectConfig::Create { + into: "@returns".to_string(), + value: ValueKind::Primitive, + reason: ValueReason::KnownReturnSignature, + }, + ], + }), + ..Default::default() + }, + Some(BUILT_IN_USE_EFFECT_HOOK_ID), + ); + react_apis.push(("useEffect".to_string(), use_effect)); + + // useLayoutEffect + let use_layout_effect = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + hook_kind: HookKind::UseLayoutEffect, + ..Default::default() + }, + Some(BUILT_IN_USE_LAYOUT_EFFECT_HOOK_ID), + ); + react_apis.push(("useLayoutEffect".to_string(), use_layout_effect)); + + // useInsertionEffect + let use_insertion_effect = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + hook_kind: HookKind::UseInsertionEffect, + ..Default::default() + }, + Some(BUILT_IN_USE_INSERTION_EFFECT_HOOK_ID), + ); + react_apis.push(("useInsertionEffect".to_string(), use_insertion_effect)); + + // useTransition + let use_transition = add_hook( + shapes, + HookSignatureBuilder { + rest_param: None, + return_type: Type::Object { + shape_id: Some(BUILT_IN_USE_TRANSITION_ID.to_string()), + }, + return_value_kind: ValueKind::Frozen, + hook_kind: HookKind::UseTransition, + ..Default::default() + }, + None, + ); + react_apis.push(("useTransition".to_string(), use_transition)); + + // useOptimistic + let use_optimistic = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Object { + shape_id: Some(BUILT_IN_USE_OPTIMISTIC_ID.to_string()), + }, + return_value_kind: ValueKind::Frozen, + return_value_reason: Some(ValueReason::State), + hook_kind: HookKind::UseOptimistic, + ..Default::default() + }, + None, + ); + react_apis.push(("useOptimistic".to_string(), 
use_optimistic)); + + // use (not a hook, it's a function) + let use_fn = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + ..Default::default() + }, + Some(BUILT_IN_USE_OPERATOR_ID), + false, + ); + react_apis.push(("use".to_string(), use_fn)); + + // useEffectEvent + let use_effect_event = add_hook( + shapes, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Function { + shape_id: Some(BUILT_IN_EFFECT_EVENT_ID.to_string()), + return_type: Box::new(Type::Poly), + is_constructor: false, + }, + return_value_kind: ValueKind::Frozen, + hook_kind: HookKind::UseEffectEvent, + ..Default::default() + }, + Some(BUILT_IN_USE_EFFECT_EVENT_ID), + ); + react_apis.push(("useEffectEvent".to_string(), use_effect_event)); + + // Insert all React APIs as standalone globals + for (name, ty) in &react_apis { + globals.insert(name.clone(), ty.clone()); + } + + react_apis +} + +/// Build typed globals and return them as a list for use as globalThis/global +/// properties. 
+fn build_typed_globals( + shapes: &mut ShapeRegistry, + globals: &mut GlobalRegistry, + react_apis: Vec<(String, Type)>, +) -> Vec<(String, Type)> { + let mut typed_globals: Vec<(String, Type)> = Vec::new(); + // Object + let obj_keys = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + aliasing: Some(AliasingSignatureConfig { + receiver: "@receiver".to_string(), + params: vec!["@object".to_string()], + rest: None, + returns: "@returns".to_string(), + temporaries: Vec::new(), + effects: vec![ + AliasingEffectConfig::Create { + into: "@returns".to_string(), + value: ValueKind::Mutable, + reason: ValueReason::KnownReturnSignature, + }, + // Only keys are captured, and keys are immutable + AliasingEffectConfig::ImmutableCapture { + from: "@object".to_string(), + into: "@returns".to_string(), + }, + ], + }), + ..Default::default() + }, + None, + false, + ); + let obj_from_entries = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::ConditionallyMutate], + return_type: Type::Object { + shape_id: Some(BUILT_IN_OBJECT_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let obj_entries = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Capture], + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + aliasing: Some(AliasingSignatureConfig { + receiver: "@receiver".to_string(), + params: vec!["@object".to_string()], + rest: None, + returns: "@returns".to_string(), + temporaries: Vec::new(), + effects: vec![ + AliasingEffectConfig::Create { + into: "@returns".to_string(), + value: ValueKind::Mutable, + reason: ValueReason::KnownReturnSignature, + }, + // Object 
values are captured into the return + AliasingEffectConfig::Capture { + from: "@object".to_string(), + into: "@returns".to_string(), + }, + ], + }), + ..Default::default() + }, + None, + false, + ); + let obj_values = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Capture], + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + aliasing: Some(AliasingSignatureConfig { + receiver: "@receiver".to_string(), + params: vec!["@object".to_string()], + rest: None, + returns: "@returns".to_string(), + temporaries: Vec::new(), + effects: vec![ + AliasingEffectConfig::Create { + into: "@returns".to_string(), + value: ValueKind::Mutable, + reason: ValueReason::KnownReturnSignature, + }, + // Object values are captured into the return + AliasingEffectConfig::Capture { + from: "@object".to_string(), + into: "@returns".to_string(), + }, + ], + }), + ..Default::default() + }, + None, + false, + ); + let object_global = add_object( + shapes, + Some("Object"), + vec![ + ("keys".to_string(), obj_keys), + ("fromEntries".to_string(), obj_from_entries), + ("entries".to_string(), obj_entries), + ("values".to_string(), obj_values), + ], + ); + typed_globals.push(("Object".to_string(), object_global.clone())); + globals.insert("Object".to_string(), object_global); + + // Array + let array_is_array = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::Read], + return_type: Type::Primitive, + return_value_kind: ValueKind::Primitive, + ..Default::default() + }, + None, + false, + ); + let array_from = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![ + Effect::ConditionallyMutateIterator, + Effect::ConditionallyMutate, + Effect::ConditionallyMutate, + ], + rest_param: Some(Effect::Read), + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + 
return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let array_of = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + let array_global = add_object( + shapes, + Some("Array"), + vec![ + ("isArray".to_string(), array_is_array), + ("from".to_string(), array_from), + ("of".to_string(), array_of), + ], + ); + typed_globals.push(("Array".to_string(), array_global.clone())); + globals.insert("Array".to_string(), array_global); + + // Math + let math_fns: Vec<(String, Type)> = [ + "max", "min", "trunc", "ceil", "floor", "pow", "round", "sqrt", "abs", "sign", "log", + "log2", "log10", + ] + .iter() + .map(|name| (name.to_string(), pure_primitive_fn(shapes))) + .collect(); + let mut math_props = math_fns; + math_props.push(("PI".to_string(), Type::Primitive)); + // Math.random is impure + let math_random = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + impure: true, + canonical_name: Some("Math.random".to_string()), + ..Default::default() + }, + None, + false, + ); + math_props.push(("random".to_string(), math_random)); + let math_global = add_object(shapes, Some("Math"), math_props); + typed_globals.push(("Math".to_string(), math_global.clone())); + globals.insert("Math".to_string(), math_global); + + // performance + let perf_now = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + impure: true, + canonical_name: Some("performance.now".to_string()), + ..Default::default() + }, + None, + false, + ); + let perf_global = add_object( + shapes, + Some("performance"), + vec![("now".to_string(), perf_now)], + ); 
+ typed_globals.push(("performance".to_string(), perf_global.clone())); + globals.insert("performance".to_string(), perf_global); + + // Date + let date_now = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Read), + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + impure: true, + canonical_name: Some("Date.now".to_string()), + ..Default::default() + }, + None, + false, + ); + let date_global = add_object(shapes, Some("Date"), vec![("now".to_string(), date_now)]); + typed_globals.push(("Date".to_string(), date_global.clone())); + globals.insert("Date".to_string(), date_global); + + // console + let console_methods: Vec<(String, Type)> = ["error", "info", "log", "table", "trace", "warn"] + .iter() + .map(|name| (name.to_string(), pure_primitive_fn(shapes))) + .collect(); + let console_global = add_object(shapes, Some("console"), console_methods); + typed_globals.push(("console".to_string(), console_global.clone())); + globals.insert("console".to_string(), console_global); + + // Simple global functions returning Primitive + for name in &[ + "Boolean", + "Number", + "String", + "parseInt", + "parseFloat", + "isNaN", + "isFinite", + "encodeURI", + "encodeURIComponent", + "decodeURI", + "decodeURIComponent", + ] { + let f = pure_primitive_fn(shapes); + typed_globals.push((name.to_string(), f.clone())); + globals.insert(name.to_string(), f); + } + + // Primitive globals + typed_globals.push(("Infinity".to_string(), Type::Primitive)); + globals.insert("Infinity".to_string(), Type::Primitive); + typed_globals.push(("NaN".to_string(), Type::Primitive)); + globals.insert("NaN".to_string(), Type::Primitive); + + // Map, Set, WeakMap, WeakSet constructors + let map_ctor = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::ConditionallyMutateIterator], + return_type: Type::Object { + shape_id: Some(BUILT_IN_MAP_ID.to_string()), + }, + return_value_kind: 
ValueKind::Mutable, + ..Default::default() + }, + None, + true, + ); + typed_globals.push(("Map".to_string(), map_ctor.clone())); + globals.insert("Map".to_string(), map_ctor); + + let set_ctor = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::ConditionallyMutateIterator], + return_type: Type::Object { + shape_id: Some(BUILT_IN_SET_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + true, + ); + typed_globals.push(("Set".to_string(), set_ctor.clone())); + globals.insert("Set".to_string(), set_ctor); + + let weak_map_ctor = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::ConditionallyMutateIterator], + return_type: Type::Object { + shape_id: Some(BUILT_IN_WEAK_MAP_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + true, + ); + typed_globals.push(("WeakMap".to_string(), weak_map_ctor.clone())); + globals.insert("WeakMap".to_string(), weak_map_ctor); + + let weak_set_ctor = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + positional_params: vec![Effect::ConditionallyMutateIterator], + return_type: Type::Object { + shape_id: Some(BUILT_IN_WEAK_SET_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + true, + ); + typed_globals.push(("WeakSet".to_string(), weak_set_ctor.clone())); + globals.insert("WeakSet".to_string(), weak_set_ctor); + + // React global object — reuses the same REACT_APIS types (matching TS behavior + // where the same type objects are used as both standalone globals and React.* + // properties) + let react_create_element = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + ..Default::default() + }, + None, + false, + ); + let react_clone_element = add_function( + 
shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + ..Default::default() + }, + None, + false, + ); + let react_create_ref = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Capture), + return_type: Type::Object { + shape_id: Some(BUILT_IN_USE_REF_ID.to_string()), + }, + return_value_kind: ValueKind::Mutable, + ..Default::default() + }, + None, + false, + ); + + // Build React namespace properties from react_apis + React-specific functions + let mut react_props: Vec<(String, Type)> = react_apis; + react_props.push(("createElement".to_string(), react_create_element)); + react_props.push(("cloneElement".to_string(), react_clone_element)); + react_props.push(("createRef".to_string(), react_create_ref)); + + let react_global = add_object(shapes, None, react_props); + typed_globals.push(("React".to_string(), react_global.clone())); + globals.insert("React".to_string(), react_global); + + // _jsx (used by JSX transform) + let jsx_fn = add_function( + shapes, + Vec::new(), + FunctionSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + ..Default::default() + }, + None, + false, + ); + typed_globals.push(("_jsx".to_string(), jsx_fn.clone())); + globals.insert("_jsx".to_string(), jsx_fn); + + typed_globals +} diff --git a/crates/react_compiler_hir/src/lib.rs b/crates/react_compiler_hir/src/lib.rs new file mode 100644 index 000000000000..517638a8d1c8 --- /dev/null +++ b/crates/react_compiler_hir/src/lib.rs @@ -0,0 +1,1532 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. 
+#![allow(clippy::all)] + +pub mod default_module_type_provider; +pub mod dominator; +pub mod environment; +pub mod environment_config; +pub mod globals; +pub mod object_shape; +pub mod print; +pub mod reactive; +pub mod type_config; +pub mod visitors; + +use indexmap::{IndexMap, IndexSet}; +pub use react_compiler_diagnostics::{ + CompilerDiagnostic, ErrorCategory, Position, SourceLocation, GENERATED_SOURCE, +}; +pub use reactive::*; + +// ============================================================================= +// ID newtypes +// ============================================================================= + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct BlockId(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct IdentifierId(pub u32); + +/// Index into the flat instruction table on HirFunction. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct InstructionId(pub u32); + +/// Evaluation order assigned to instructions and terminals during numbering. +/// This was previously called InstructionId in the TypeScript compiler. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct EvaluationOrder(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct DeclarationId(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ScopeId(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct TypeId(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct FunctionId(pub u32); + +// ============================================================================= +// FloatValue wrapper +// ============================================================================= + +/// Wrapper around f64 that stores raw bytes for deterministic equality and +/// hashing. 
This allows use in HashMap keys and ensures NaN == NaN (bitwise +/// comparison). +#[derive(Debug, Clone, Copy)] +pub struct FloatValue(u64); + +impl FloatValue { + pub fn new(value: f64) -> Self { + FloatValue(value.to_bits()) + } + + pub fn value(self) -> f64 { + f64::from_bits(self.0) + } +} + +impl From for FloatValue { + fn from(value: f64) -> Self { + FloatValue::new(value) + } +} + +impl From for f64 { + fn from(value: FloatValue) -> Self { + value.value() + } +} + +impl PartialEq for FloatValue { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for FloatValue {} + +impl std::hash::Hash for FloatValue { + fn hash(&self, state: &mut H) { + self.0.hash(state); + } +} + +impl std::fmt::Display for FloatValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.value()) + } +} + +// ============================================================================= +// Core HIR types +// ============================================================================= + +/// A function lowered to HIR form +#[derive(Debug, Clone)] +pub struct HirFunction { + pub loc: Option, + pub id: Option, + pub name_hint: Option, + pub fn_type: ReactFunctionType, + pub params: Vec, + pub return_type_annotation: Option, + pub returns: Place, + pub context: Vec, + pub body: HIR, + pub instructions: Vec, + pub generator: bool, + pub is_async: bool, + pub directives: Vec, + pub aliasing_effects: Option>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReactFunctionType { + Component, + Hook, + Other, +} + +#[derive(Debug, Clone)] +pub enum ParamPattern { + Place(Place), + Spread(SpreadPattern), +} + +/// The HIR control-flow graph +#[derive(Debug, Clone)] +pub struct HIR { + pub entry: BlockId, + pub blocks: IndexMap, +} + +/// Block kinds +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BlockKind { + Block, + Value, + Loop, + Sequence, + Catch, +} + +impl std::fmt::Display for BlockKind { + fn 
fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BlockKind::Block => write!(f, "block"), + BlockKind::Value => write!(f, "value"), + BlockKind::Loop => write!(f, "loop"), + BlockKind::Sequence => write!(f, "sequence"), + BlockKind::Catch => write!(f, "catch"), + } + } +} + +/// A basic block in the CFG +#[derive(Debug, Clone)] +pub struct BasicBlock { + pub kind: BlockKind, + pub id: BlockId, + pub instructions: Vec, + pub terminal: Terminal, + pub preds: IndexSet, + pub phis: Vec, +} + +/// Phi node for SSA +#[derive(Debug, Clone)] +pub struct Phi { + pub place: Place, + pub operands: IndexMap, +} + +// ============================================================================= +// Terminal enum +// ============================================================================= + +#[derive(Debug, Clone)] +pub enum Terminal { + Unsupported { + id: EvaluationOrder, + loc: Option, + }, + Unreachable { + id: EvaluationOrder, + loc: Option, + }, + Throw { + value: Place, + id: EvaluationOrder, + loc: Option, + }, + Return { + value: Place, + return_variant: ReturnVariant, + id: EvaluationOrder, + loc: Option, + effects: Option>, + }, + Goto { + block: BlockId, + variant: GotoVariant, + id: EvaluationOrder, + loc: Option, + }, + If { + test: Place, + consequent: BlockId, + alternate: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + Branch { + test: Place, + consequent: BlockId, + alternate: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + Switch { + test: Place, + cases: Vec, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + DoWhile { + loop_block: BlockId, + test: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + While { + test: BlockId, + loop_block: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + For { + init: BlockId, + test: BlockId, + update: Option, + loop_block: BlockId, + fallthrough: BlockId, + 
id: EvaluationOrder, + loc: Option, + }, + ForOf { + init: BlockId, + test: BlockId, + loop_block: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + ForIn { + init: BlockId, + loop_block: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + Logical { + operator: LogicalOperator, + test: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + Ternary { + test: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + Optional { + optional: bool, + test: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + Label { + block: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + Sequence { + block: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + MaybeThrow { + continuation: BlockId, + handler: Option, + id: EvaluationOrder, + loc: Option, + effects: Option>, + }, + Try { + block: BlockId, + handler_binding: Option, + handler: BlockId, + fallthrough: BlockId, + id: EvaluationOrder, + loc: Option, + }, + Scope { + fallthrough: BlockId, + block: BlockId, + scope: ScopeId, + id: EvaluationOrder, + loc: Option, + }, + PrunedScope { + fallthrough: BlockId, + block: BlockId, + scope: ScopeId, + id: EvaluationOrder, + loc: Option, + }, +} + +impl Terminal { + /// Get the evaluation order of this terminal + pub fn evaluation_order(&self) -> EvaluationOrder { + match self { + Terminal::Unsupported { id, .. } + | Terminal::Unreachable { id, .. } + | Terminal::Throw { id, .. } + | Terminal::Return { id, .. } + | Terminal::Goto { id, .. } + | Terminal::If { id, .. } + | Terminal::Branch { id, .. } + | Terminal::Switch { id, .. } + | Terminal::DoWhile { id, .. } + | Terminal::While { id, .. } + | Terminal::For { id, .. } + | Terminal::ForOf { id, .. } + | Terminal::ForIn { id, .. } + | Terminal::Logical { id, .. } + | Terminal::Ternary { id, .. } + | Terminal::Optional { id, .. } + | Terminal::Label { id, .. 
} + | Terminal::Sequence { id, .. } + | Terminal::MaybeThrow { id, .. } + | Terminal::Try { id, .. } + | Terminal::Scope { id, .. } + | Terminal::PrunedScope { id, .. } => *id, + } + } + + /// Get the source location of this terminal + pub fn loc(&self) -> Option<&SourceLocation> { + match self { + Terminal::Unsupported { loc, .. } + | Terminal::Unreachable { loc, .. } + | Terminal::Throw { loc, .. } + | Terminal::Return { loc, .. } + | Terminal::Goto { loc, .. } + | Terminal::If { loc, .. } + | Terminal::Branch { loc, .. } + | Terminal::Switch { loc, .. } + | Terminal::DoWhile { loc, .. } + | Terminal::While { loc, .. } + | Terminal::For { loc, .. } + | Terminal::ForOf { loc, .. } + | Terminal::ForIn { loc, .. } + | Terminal::Logical { loc, .. } + | Terminal::Ternary { loc, .. } + | Terminal::Optional { loc, .. } + | Terminal::Label { loc, .. } + | Terminal::Sequence { loc, .. } + | Terminal::MaybeThrow { loc, .. } + | Terminal::Try { loc, .. } + | Terminal::Scope { loc, .. } + | Terminal::PrunedScope { loc, .. } => loc.as_ref(), + } + } + + /// Set the evaluation order of this terminal + pub fn set_evaluation_order(&mut self, new_id: EvaluationOrder) { + match self { + Terminal::Unsupported { id, .. } + | Terminal::Unreachable { id, .. } + | Terminal::Throw { id, .. } + | Terminal::Return { id, .. } + | Terminal::Goto { id, .. } + | Terminal::If { id, .. } + | Terminal::Branch { id, .. } + | Terminal::Switch { id, .. } + | Terminal::DoWhile { id, .. } + | Terminal::While { id, .. } + | Terminal::For { id, .. } + | Terminal::ForOf { id, .. } + | Terminal::ForIn { id, .. } + | Terminal::Logical { id, .. } + | Terminal::Ternary { id, .. } + | Terminal::Optional { id, .. } + | Terminal::Label { id, .. } + | Terminal::Sequence { id, .. } + | Terminal::MaybeThrow { id, .. } + | Terminal::Try { id, .. } + | Terminal::Scope { id, .. } + | Terminal::PrunedScope { id, .. 
} => *id = new_id, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReturnVariant { + Void, + Implicit, + Explicit, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum GotoVariant { + Break, + Continue, + Try, +} + +#[derive(Debug, Clone)] +pub struct Case { + pub test: Option, + pub block: BlockId, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LogicalOperator { + And, + Or, + NullishCoalescing, +} + +impl std::fmt::Display for LogicalOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LogicalOperator::And => write!(f, "&&"), + LogicalOperator::Or => write!(f, "||"), + LogicalOperator::NullishCoalescing => write!(f, "??"), + } + } +} + +// ============================================================================= +// Instruction types +// ============================================================================= + +#[derive(Debug, Clone)] +pub struct Instruction { + pub id: EvaluationOrder, + pub lvalue: Place, + pub value: InstructionValue, + pub loc: Option, + pub effects: Option>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum InstructionKind { + Const, + Let, + Reassign, + Catch, + HoistedConst, + HoistedLet, + HoistedFunction, + Function, +} + +#[derive(Debug, Clone)] +pub struct LValue { + pub place: Place, + pub kind: InstructionKind, +} + +#[derive(Debug, Clone)] +pub struct LValuePattern { + pub pattern: Pattern, + pub kind: InstructionKind, +} + +#[derive(Debug, Clone)] +pub enum Pattern { + Array(ArrayPattern), + Object(ObjectPattern), +} + +// ============================================================================= +// InstructionValue enum +// ============================================================================= + +#[derive(Debug, Clone)] +pub enum InstructionValue { + LoadLocal { + place: Place, + loc: Option, + }, + LoadContext { + place: Place, + loc: Option, + }, + DeclareLocal { + lvalue: LValue, + type_annotation: Option, + 
loc: Option, + }, + DeclareContext { + lvalue: LValue, + loc: Option, + }, + StoreLocal { + lvalue: LValue, + value: Place, + type_annotation: Option, + loc: Option, + }, + StoreContext { + lvalue: LValue, + value: Place, + loc: Option, + }, + Destructure { + lvalue: LValuePattern, + value: Place, + loc: Option, + }, + Primitive { + value: PrimitiveValue, + loc: Option, + }, + JSXText { + value: String, + loc: Option, + }, + BinaryExpression { + operator: BinaryOperator, + left: Place, + right: Place, + loc: Option, + }, + NewExpression { + callee: Place, + args: Vec, + loc: Option, + }, + CallExpression { + callee: Place, + args: Vec, + loc: Option, + }, + MethodCall { + receiver: Place, + property: Place, + args: Vec, + loc: Option, + }, + UnaryExpression { + operator: UnaryOperator, + value: Place, + loc: Option, + }, + TypeCastExpression { + value: Place, + type_: Type, + type_annotation_name: Option, + type_annotation_kind: Option, + /// The original AST type annotation node, preserved for codegen. 
+ /// For Flow: the inner type from TypeAnnotation.typeAnnotation + /// For TS: the TSType node from TSAsExpression/TSSatisfiesExpression + type_annotation: Option>, + loc: Option, + }, + JsxExpression { + tag: JsxTag, + props: Vec, + children: Option>, + loc: Option, + opening_loc: Option, + closing_loc: Option, + }, + ObjectExpression { + properties: Vec, + loc: Option, + }, + ObjectMethod { + loc: Option, + lowered_func: LoweredFunction, + }, + ArrayExpression { + elements: Vec, + loc: Option, + }, + JsxFragment { + children: Vec, + loc: Option, + }, + RegExpLiteral { + pattern: String, + flags: String, + loc: Option, + }, + MetaProperty { + meta: String, + property: String, + loc: Option, + }, + PropertyStore { + object: Place, + property: PropertyLiteral, + value: Place, + loc: Option, + }, + PropertyLoad { + object: Place, + property: PropertyLiteral, + loc: Option, + }, + PropertyDelete { + object: Place, + property: PropertyLiteral, + loc: Option, + }, + ComputedStore { + object: Place, + property: Place, + value: Place, + loc: Option, + }, + ComputedLoad { + object: Place, + property: Place, + loc: Option, + }, + ComputedDelete { + object: Place, + property: Place, + loc: Option, + }, + LoadGlobal { + binding: NonLocalBinding, + loc: Option, + }, + StoreGlobal { + name: String, + value: Place, + loc: Option, + }, + FunctionExpression { + name: Option, + name_hint: Option, + lowered_func: LoweredFunction, + expr_type: FunctionExpressionType, + loc: Option, + }, + TaggedTemplateExpression { + tag: Place, + value: TemplateQuasi, + loc: Option, + }, + TemplateLiteral { + subexprs: Vec, + quasis: Vec, + loc: Option, + }, + Await { + value: Place, + loc: Option, + }, + GetIterator { + collection: Place, + loc: Option, + }, + IteratorNext { + iterator: Place, + collection: Place, + loc: Option, + }, + NextPropertyOf { + value: Place, + loc: Option, + }, + PrefixUpdate { + lvalue: Place, + operation: UpdateOperator, + value: Place, + loc: Option, + }, + 
PostfixUpdate { + lvalue: Place, + operation: UpdateOperator, + value: Place, + loc: Option, + }, + Debugger { + loc: Option, + }, + StartMemoize { + manual_memo_id: u32, + deps: Option>, + deps_loc: Option>, + has_invalid_deps: bool, + loc: Option, + }, + FinishMemoize { + manual_memo_id: u32, + decl: Place, + pruned: bool, + loc: Option, + }, + UnsupportedNode { + node_type: Option, + /// The original AST node serialized as JSON, so codegen can emit it + /// verbatim. + original_node: Option, + loc: Option, + }, +} + +impl InstructionValue { + pub fn loc(&self) -> Option<&SourceLocation> { + match self { + InstructionValue::LoadLocal { loc, .. } + | InstructionValue::LoadContext { loc, .. } + | InstructionValue::DeclareLocal { loc, .. } + | InstructionValue::DeclareContext { loc, .. } + | InstructionValue::StoreLocal { loc, .. } + | InstructionValue::StoreContext { loc, .. } + | InstructionValue::Destructure { loc, .. } + | InstructionValue::Primitive { loc, .. } + | InstructionValue::JSXText { loc, .. } + | InstructionValue::BinaryExpression { loc, .. } + | InstructionValue::NewExpression { loc, .. } + | InstructionValue::CallExpression { loc, .. } + | InstructionValue::MethodCall { loc, .. } + | InstructionValue::UnaryExpression { loc, .. } + | InstructionValue::TypeCastExpression { loc, .. } + | InstructionValue::JsxExpression { loc, .. } + | InstructionValue::ObjectExpression { loc, .. } + | InstructionValue::ObjectMethod { loc, .. } + | InstructionValue::ArrayExpression { loc, .. } + | InstructionValue::JsxFragment { loc, .. } + | InstructionValue::RegExpLiteral { loc, .. } + | InstructionValue::MetaProperty { loc, .. } + | InstructionValue::PropertyStore { loc, .. } + | InstructionValue::PropertyLoad { loc, .. } + | InstructionValue::PropertyDelete { loc, .. } + | InstructionValue::ComputedStore { loc, .. } + | InstructionValue::ComputedLoad { loc, .. } + | InstructionValue::ComputedDelete { loc, .. } + | InstructionValue::LoadGlobal { loc, .. 
} + | InstructionValue::StoreGlobal { loc, .. } + | InstructionValue::FunctionExpression { loc, .. } + | InstructionValue::TaggedTemplateExpression { loc, .. } + | InstructionValue::TemplateLiteral { loc, .. } + | InstructionValue::Await { loc, .. } + | InstructionValue::GetIterator { loc, .. } + | InstructionValue::IteratorNext { loc, .. } + | InstructionValue::NextPropertyOf { loc, .. } + | InstructionValue::PrefixUpdate { loc, .. } + | InstructionValue::PostfixUpdate { loc, .. } + | InstructionValue::Debugger { loc, .. } + | InstructionValue::StartMemoize { loc, .. } + | InstructionValue::FinishMemoize { loc, .. } + | InstructionValue::UnsupportedNode { loc, .. } => loc.as_ref(), + } + } +} + +// ============================================================================= +// Supporting types +// ============================================================================= + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum PrimitiveValue { + Null, + Undefined, + Boolean(bool), + Number(FloatValue), + String(String), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BinaryOperator { + Equal, + NotEqual, + StrictEqual, + StrictNotEqual, + LessThan, + LessEqual, + GreaterThan, + GreaterEqual, + ShiftLeft, + ShiftRight, + UnsignedShiftRight, + Add, + Subtract, + Multiply, + Divide, + Modulo, + Exponent, + BitwiseOr, + BitwiseXor, + BitwiseAnd, + In, + InstanceOf, +} + +impl std::fmt::Display for BinaryOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + BinaryOperator::Equal => write!(f, "=="), + BinaryOperator::NotEqual => write!(f, "!="), + BinaryOperator::StrictEqual => write!(f, "==="), + BinaryOperator::StrictNotEqual => write!(f, "!=="), + BinaryOperator::LessThan => write!(f, "<"), + BinaryOperator::LessEqual => write!(f, "<="), + BinaryOperator::GreaterThan => write!(f, ">"), + BinaryOperator::GreaterEqual => write!(f, ">="), + BinaryOperator::ShiftLeft => write!(f, "<<"), + 
BinaryOperator::ShiftRight => write!(f, ">>"), + BinaryOperator::UnsignedShiftRight => write!(f, ">>>"), + BinaryOperator::Add => write!(f, "+"), + BinaryOperator::Subtract => write!(f, "-"), + BinaryOperator::Multiply => write!(f, "*"), + BinaryOperator::Divide => write!(f, "/"), + BinaryOperator::Modulo => write!(f, "%"), + BinaryOperator::Exponent => write!(f, "**"), + BinaryOperator::BitwiseOr => write!(f, "|"), + BinaryOperator::BitwiseXor => write!(f, "^"), + BinaryOperator::BitwiseAnd => write!(f, "&"), + BinaryOperator::In => write!(f, "in"), + BinaryOperator::InstanceOf => write!(f, "instanceof"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum UnaryOperator { + Minus, + Plus, + Not, + BitwiseNot, + TypeOf, + Void, +} + +impl std::fmt::Display for UnaryOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UnaryOperator::Minus => write!(f, "-"), + UnaryOperator::Plus => write!(f, "+"), + UnaryOperator::Not => write!(f, "!"), + UnaryOperator::BitwiseNot => write!(f, "~"), + UnaryOperator::TypeOf => write!(f, "typeof"), + UnaryOperator::Void => write!(f, "void"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum UpdateOperator { + Increment, + Decrement, +} + +impl std::fmt::Display for UpdateOperator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + UpdateOperator::Increment => write!(f, "++"), + UpdateOperator::Decrement => write!(f, "--"), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum FunctionExpressionType { + ArrowFunctionExpression, + FunctionExpression, + FunctionDeclaration, +} + +#[derive(Debug, Clone)] +pub struct TemplateQuasi { + pub raw: String, + pub cooked: Option, +} + +#[derive(Debug, Clone)] +pub struct ManualMemoDependency { + pub root: ManualMemoDependencyRoot, + pub path: Vec, + pub loc: Option, +} + +#[derive(Debug, Clone)] +pub enum ManualMemoDependencyRoot { + NamedLocal { value: Place, 
constant: bool }, + Global { identifier_name: String }, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DependencyPathEntry { + pub property: PropertyLiteral, + pub optional: bool, + pub loc: Option, +} + +// ============================================================================= +// Place, Identifier, and related types +// ============================================================================= + +#[derive(Debug, Clone)] +pub struct Place { + pub identifier: IdentifierId, + pub effect: Effect, + pub reactive: bool, + pub loc: Option, +} + +#[derive(Debug, Clone)] +pub struct Identifier { + pub id: IdentifierId, + pub declaration_id: DeclarationId, + pub name: Option, + pub mutable_range: MutableRange, + pub scope: Option, + pub type_: TypeId, + pub loc: Option, +} + +#[derive(Debug, Clone)] +pub struct MutableRange { + pub start: EvaluationOrder, + pub end: EvaluationOrder, +} + +impl MutableRange { + /// Returns true if the given evaluation order falls within this mutable + /// range. Corresponds to TS `inRange({id}, range)` / `isMutable(instr, + /// place)`. + pub fn contains(&self, id: EvaluationOrder) -> bool { + id >= self.start && id < self.end + } +} + +#[derive(Debug, Clone)] +pub enum IdentifierName { + Named(String), + Promoted(String), +} + +impl IdentifierName { + pub fn value(&self) -> &str { + match self { + IdentifierName::Named(v) | IdentifierName::Promoted(v) => v, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub enum Effect { + #[serde(rename = "")] + Unknown, + #[serde(rename = "freeze")] + Freeze, + #[serde(rename = "read")] + Read, + #[serde(rename = "capture")] + Capture, + #[serde(rename = "mutate-iterator?")] + ConditionallyMutateIterator, + #[serde(rename = "mutate?")] + ConditionallyMutate, + #[serde(rename = "mutate")] + Mutate, + #[serde(rename = "store")] + Store, +} + +impl Effect { + /// Returns true if this effect represents a mutable operation. 
+ /// Mutable effects are: Capture, Store, ConditionallyMutate, + /// ConditionallyMutateIterator, and Mutate. + pub fn is_mutable(&self) -> bool { + matches!( + self, + Effect::Capture + | Effect::Store + | Effect::ConditionallyMutate + | Effect::ConditionallyMutateIterator + | Effect::Mutate + ) + } +} + +impl std::fmt::Display for Effect { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Effect::Unknown => write!(f, ""), + Effect::Freeze => write!(f, "freeze"), + Effect::Read => write!(f, "read"), + Effect::Capture => write!(f, "capture"), + Effect::ConditionallyMutateIterator => write!(f, "mutate-iterator?"), + Effect::ConditionallyMutate => write!(f, "mutate?"), + Effect::Mutate => write!(f, "mutate"), + Effect::Store => write!(f, "store"), + } + } +} + +#[derive(Debug, Clone)] +pub struct SpreadPattern { + pub place: Place, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Hole { + Hole, +} + +#[derive(Debug, Clone)] +pub struct ArrayPattern { + pub items: Vec, + pub loc: Option, +} + +#[derive(Debug, Clone)] +pub enum ArrayPatternElement { + Place(Place), + Spread(SpreadPattern), + Hole, +} + +#[derive(Debug, Clone)] +pub struct ObjectPattern { + pub properties: Vec, + pub loc: Option, +} + +#[derive(Debug, Clone)] +pub enum ObjectPropertyOrSpread { + Property(ObjectProperty), + Spread(SpreadPattern), +} + +#[derive(Debug, Clone)] +pub struct ObjectProperty { + pub key: ObjectPropertyKey, + pub property_type: ObjectPropertyType, + pub place: Place, +} + +#[derive(Debug, Clone)] +pub enum ObjectPropertyKey { + String { name: String }, + Identifier { name: String }, + Computed { name: Place }, + Number { name: FloatValue }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ObjectPropertyType { + Property, + Method, +} + +impl std::fmt::Display for ObjectPropertyType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ObjectPropertyType::Property => write!(f, 
"property"), + ObjectPropertyType::Method => write!(f, "method"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum PropertyLiteral { + String(String), + Number(FloatValue), +} + +impl std::fmt::Display for PropertyLiteral { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + PropertyLiteral::String(s) => write!(f, "{}", s), + PropertyLiteral::Number(n) => write!(f, "{}", n), + } + } +} + +#[derive(Debug, Clone)] +pub enum PlaceOrSpread { + Place(Place), + Spread(SpreadPattern), +} + +#[derive(Debug, Clone)] +pub enum ArrayElement { + Place(Place), + Spread(SpreadPattern), + Hole, +} + +#[derive(Debug, Clone)] +pub struct LoweredFunction { + pub func: FunctionId, +} + +#[derive(Debug, Clone)] +pub struct BuiltinTag { + pub name: String, + pub loc: Option, +} + +#[derive(Debug, Clone)] +pub enum JsxTag { + Place(Place), + Builtin(BuiltinTag), +} + +#[derive(Debug, Clone)] +pub enum JsxAttribute { + SpreadAttribute { argument: Place }, + Attribute { name: String, place: Place }, +} + +// ============================================================================= +// Variable Binding types +// ============================================================================= + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BindingKind { + Var, + Let, + Const, + Param, + Module, + Hoisted, + Local, + Unknown, +} + +#[derive(Debug, Clone)] +pub enum VariableBinding { + Identifier { + identifier: IdentifierId, + binding_kind: BindingKind, + }, + Global { + name: String, + }, + ImportDefault { + name: String, + module: String, + }, + ImportSpecifier { + name: String, + module: String, + imported: String, + }, + ImportNamespace { + name: String, + module: String, + }, + ModuleLocal { + name: String, + }, +} + +#[derive(Debug, Clone)] +pub enum NonLocalBinding { + ImportDefault { + name: String, + module: String, + }, + ImportSpecifier { + name: String, + module: String, + imported: String, + }, + ImportNamespace 
{ + name: String, + module: String, + }, + ModuleLocal { + name: String, + }, + Global { + name: String, + }, +} + +impl NonLocalBinding { + /// Returns the `name` field common to all variants. + pub fn name(&self) -> &str { + match self { + NonLocalBinding::ImportDefault { name, .. } + | NonLocalBinding::ImportSpecifier { name, .. } + | NonLocalBinding::ImportNamespace { name, .. } + | NonLocalBinding::ModuleLocal { name, .. } + | NonLocalBinding::Global { name, .. } => name, + } + } +} + +// ============================================================================= +// Type system (from Types.ts) +// ============================================================================= + +#[derive(Debug, Clone)] +pub enum Type { + Primitive, + Function { + shape_id: Option, + return_type: Box, + is_constructor: bool, + }, + Object { + shape_id: Option, + }, + TypeVar { + id: TypeId, + }, + Poly, + Phi { + operands: Vec, + }, + Property { + object_type: Box, + object_name: String, + property_name: PropertyNameKind, + }, + ObjectMethod, +} + +#[derive(Debug, Clone)] +pub enum PropertyNameKind { + Literal { value: PropertyLiteral }, + Computed { value: Box }, +} + +// ============================================================================= +// ReactiveScope +// ============================================================================= + +#[derive(Debug, Clone)] +pub struct ReactiveScope { + pub id: ScopeId, + pub range: MutableRange, + + /// The inputs to this reactive scope (populated by later passes) + pub dependencies: Vec, + + /// The set of values produced by this scope (populated by later passes) + pub declarations: Vec<(IdentifierId, ReactiveScopeDeclaration)>, + + /// Identifiers which are reassigned by this scope (populated by later + /// passes) + pub reassignments: Vec, + + /// If the scope contains an early return, this stores info about it + /// (populated by later passes) + pub early_return_value: Option, + + /// Scopes that were merged into this one 
(populated by later passes) + pub merged: Vec, + + /// Source location spanning the scope + pub loc: Option, +} + +/// A dependency of a reactive scope. +#[derive(Debug, Clone)] +pub struct ReactiveScopeDependency { + pub identifier: IdentifierId, + pub reactive: bool, + pub path: Vec, + pub loc: Option, +} + +/// A declaration produced by a reactive scope. +#[derive(Debug, Clone)] +pub struct ReactiveScopeDeclaration { + pub identifier: IdentifierId, + pub scope: ScopeId, +} + +/// Early return value info for a reactive scope. +#[derive(Debug, Clone)] +pub struct ReactiveScopeEarlyReturn { + pub value: IdentifierId, + pub loc: Option, + pub label: BlockId, +} + +// ============================================================================= +// Aliasing effects (runtime types, from AliasingEffects.ts) +// ============================================================================= + +use crate::{ + object_shape::FunctionSignature, + type_config::{ValueKind, ValueReason}, +}; + +/// Reason for a mutation, used for generating hints (e.g. rename to "Ref"). +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum MutationReason { + AssignCurrentProperty, +} + +/// Describes the aliasing/mutation/data-flow effects of an instruction or +/// terminal. Ported from TS `AliasingEffect` in `AliasingEffects.ts`. +#[derive(Debug, Clone)] +pub enum AliasingEffect { + /// Marks the given value and its direct aliases as frozen. + Freeze { value: Place, reason: ValueReason }, + /// Mutate the value and any direct aliases. + Mutate { + value: Place, + reason: Option, + }, + /// Mutate the value conditionally (only if mutable). + MutateConditionally { value: Place }, + /// Mutate the value and transitive captures. + MutateTransitive { value: Place }, + /// Mutate the value and transitive captures conditionally. + MutateTransitiveConditionally { value: Place }, + /// Information flow from `from` to `into` (non-aliasing capture). 
+ Capture { from: Place, into: Place }, + /// Direct aliasing: mutation of `into` implies mutation of `from`. + Alias { from: Place, into: Place }, + /// Potential aliasing relationship. + MaybeAlias { from: Place, into: Place }, + /// Direct assignment: `into = from`. + Assign { from: Place, into: Place }, + /// Creates a value of the given kind at the given place. + Create { + into: Place, + value: ValueKind, + reason: ValueReason, + }, + /// Creates a new value with the same kind as the source. + CreateFrom { from: Place, into: Place }, + /// Immutable data flow (escape analysis only, no mutable range influence). + ImmutableCapture { from: Place, into: Place }, + /// Function call application. + Apply { + receiver: Place, + function: Place, + mutates_function: bool, + args: Vec, + into: Place, + signature: Option, + loc: Option, + }, + /// Function expression creation with captures. + CreateFunction { + captures: Vec, + function_id: FunctionId, + into: Place, + }, + /// Mutation of a value known to be frozen (error). + MutateFrozen { + place: Place, + error: CompilerDiagnostic, + }, + /// Mutation of a global value (error). + MutateGlobal { + place: Place, + error: CompilerDiagnostic, + }, + /// Side-effect not safe during render. + Impure { + place: Place, + error: CompilerDiagnostic, + }, + /// Value is accessed during render. + Render { place: Place }, +} + +/// Combined Place/Spread/Hole for Apply args. +#[derive(Debug, Clone)] +pub enum PlaceOrSpreadOrHole { + Place(Place), + Spread(SpreadPattern), + Hole, +} + +/// Aliasing signature for function calls. +/// Ported from TS `AliasingSignature` in `AliasingEffects.ts`. 
+#[derive(Debug, Clone)] +pub struct AliasingSignature { + pub receiver: IdentifierId, + pub params: Vec, + pub rest: Option, + pub returns: IdentifierId, + pub effects: Vec, + pub temporaries: Vec, +} + +// ============================================================================= +// Type helper functions (ported from HIR.ts) +// ============================================================================= + +use crate::object_shape::{ + BUILT_IN_ARRAY_ID, BUILT_IN_JSX_ID, BUILT_IN_MAP_ID, BUILT_IN_PROPS_ID, BUILT_IN_REF_VALUE_ID, + BUILT_IN_SET_ID, BUILT_IN_USE_OPERATOR_ID, BUILT_IN_USE_REF_ID, +}; + +/// Returns true if the type (looked up via identifier) is primitive. +pub fn is_primitive_type(ty: &Type) -> bool { + matches!(ty, Type::Primitive) +} + +/// Returns true if the type is the props object. +pub fn is_props_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == BUILT_IN_PROPS_ID) +} + +/// Returns true if the type is an array. +pub fn is_array_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == BUILT_IN_ARRAY_ID) +} + +/// Returns true if the type is a Set. +pub fn is_set_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == BUILT_IN_SET_ID) +} + +/// Returns true if the type is a Map. +pub fn is_map_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == BUILT_IN_MAP_ID) +} + +/// Returns true if the type is JSX. +pub fn is_jsx_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == BUILT_IN_JSX_ID) +} + +/// Returns true if the identifier type is a ref value. +pub fn is_ref_value_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == BUILT_IN_REF_VALUE_ID) +} + +/// Returns true if the identifier type is useRef. 
+pub fn is_use_ref_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == BUILT_IN_USE_REF_ID) +} + +/// Returns true if the type is a ref or ref value. +pub fn is_ref_or_ref_value(ty: &Type) -> bool { + is_use_ref_type(ty) || is_ref_value_type(ty) +} + +/// Returns true if the type is a useState result (BuiltInUseState). +pub fn is_use_state_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == object_shape::BUILT_IN_USE_STATE_ID) +} + +/// Returns true if the type is a setState function (BuiltInSetState). +pub fn is_set_state_type(ty: &Type) -> bool { + matches!(ty, Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_SET_STATE_ID) +} + +/// Returns true if the type is a useEffect hook. +pub fn is_use_effect_hook_type(ty: &Type) -> bool { + matches!(ty, Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_USE_EFFECT_HOOK_ID) +} + +/// Returns true if the type is a useLayoutEffect hook. +pub fn is_use_layout_effect_hook_type(ty: &Type) -> bool { + matches!(ty, Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_USE_LAYOUT_EFFECT_HOOK_ID) +} + +/// Returns true if the type is a useInsertionEffect hook. +pub fn is_use_insertion_effect_hook_type(ty: &Type) -> bool { + matches!(ty, Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_USE_INSERTION_EFFECT_HOOK_ID) +} + +/// Returns true if the type is a useEffectEvent function. +pub fn is_use_effect_event_type(ty: &Type) -> bool { + matches!(ty, Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_USE_EFFECT_EVENT_ID) +} + +/// Returns true if the type is a ref or ref-like mutable type (e.g. Reanimated +/// shared values). 
+pub fn is_ref_or_ref_like_mutable_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } + if id == object_shape::BUILT_IN_USE_REF_ID || id == object_shape::REANIMATED_SHARED_VALUE_ID) +} + +/// Returns true if the type is the `use()` operator (React.use). +pub fn is_use_operator_type(ty: &Type) -> bool { + matches!( + ty, + Type::Function { shape_id: Some(id), .. } + if id == BUILT_IN_USE_OPERATOR_ID + ) +} + +/// Returns true if the type is a plain object (BuiltInObject). +pub fn is_plain_object_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == object_shape::BUILT_IN_OBJECT_ID) +} + +/// Returns true if the type is a startTransition function +/// (BuiltInStartTransition). +pub fn is_start_transition_type(ty: &Type) -> bool { + matches!(ty, Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_START_TRANSITION_ID) +} diff --git a/crates/react_compiler_hir/src/object_shape.rs b/crates/react_compiler_hir/src/object_shape.rs new file mode 100644 index 000000000000..3cbbecde36ee --- /dev/null +++ b/crates/react_compiler_hir/src/object_shape.rs @@ -0,0 +1,437 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Object shapes and function signatures, ported from ObjectShape.ts. +//! +//! Defines the shape registry used by Environment to resolve property types +//! and function call signatures for built-in objects, hooks, and user-defined +//! types. 
+ +use std::collections::HashMap; + +use crate::{ + type_config::{AliasingEffectConfig, AliasingSignatureConfig, ValueKind, ValueReason}, + Effect, Type, +}; + +// ============================================================================= +// Shape ID constants (matching TS ObjectShape.ts) +// ============================================================================= + +pub const BUILT_IN_PROPS_ID: &str = "BuiltInProps"; +pub const BUILT_IN_ARRAY_ID: &str = "BuiltInArray"; +pub const BUILT_IN_SET_ID: &str = "BuiltInSet"; +pub const BUILT_IN_MAP_ID: &str = "BuiltInMap"; +pub const BUILT_IN_WEAK_SET_ID: &str = "BuiltInWeakSet"; +pub const BUILT_IN_WEAK_MAP_ID: &str = "BuiltInWeakMap"; +pub const BUILT_IN_FUNCTION_ID: &str = "BuiltInFunction"; +pub const BUILT_IN_JSX_ID: &str = "BuiltInJsx"; +pub const BUILT_IN_OBJECT_ID: &str = "BuiltInObject"; +pub const BUILT_IN_USE_STATE_ID: &str = "BuiltInUseState"; +pub const BUILT_IN_SET_STATE_ID: &str = "BuiltInSetState"; +pub const BUILT_IN_USE_ACTION_STATE_ID: &str = "BuiltInUseActionState"; +pub const BUILT_IN_SET_ACTION_STATE_ID: &str = "BuiltInSetActionState"; +pub const BUILT_IN_USE_REF_ID: &str = "BuiltInUseRefId"; +pub const BUILT_IN_REF_VALUE_ID: &str = "BuiltInRefValue"; +pub const BUILT_IN_MIXED_READONLY_ID: &str = "BuiltInMixedReadonly"; +pub const BUILT_IN_USE_EFFECT_HOOK_ID: &str = "BuiltInUseEffectHook"; +pub const BUILT_IN_USE_LAYOUT_EFFECT_HOOK_ID: &str = "BuiltInUseLayoutEffectHook"; +pub const BUILT_IN_USE_INSERTION_EFFECT_HOOK_ID: &str = "BuiltInUseInsertionEffectHook"; +pub const BUILT_IN_USE_OPERATOR_ID: &str = "BuiltInUseOperator"; +pub const BUILT_IN_USE_REDUCER_ID: &str = "BuiltInUseReducer"; +pub const BUILT_IN_DISPATCH_ID: &str = "BuiltInDispatch"; +pub const BUILT_IN_USE_CONTEXT_HOOK_ID: &str = "BuiltInUseContextHook"; +pub const BUILT_IN_USE_TRANSITION_ID: &str = "BuiltInUseTransition"; +pub const BUILT_IN_USE_OPTIMISTIC_ID: &str = "BuiltInUseOptimistic"; +pub const 
BUILT_IN_SET_OPTIMISTIC_ID: &str = "BuiltInSetOptimistic"; +pub const BUILT_IN_START_TRANSITION_ID: &str = "BuiltInStartTransition"; +pub const BUILT_IN_USE_EFFECT_EVENT_ID: &str = "BuiltInUseEffectEvent"; +pub const BUILT_IN_EFFECT_EVENT_ID: &str = "BuiltInEffectEventFunction"; +pub const REANIMATED_SHARED_VALUE_ID: &str = "ReanimatedSharedValueId"; + +// ============================================================================= +// Core types +// ============================================================================= + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum HookKind { + UseContext, + UseState, + UseActionState, + UseReducer, + UseRef, + UseEffect, + UseLayoutEffect, + UseInsertionEffect, + UseMemo, + UseCallback, + UseTransition, + UseImperativeHandle, + UseEffectEvent, + UseOptimistic, + Custom, +} + +impl std::fmt::Display for HookKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + HookKind::UseContext => write!(f, "useContext"), + HookKind::UseState => write!(f, "useState"), + HookKind::UseActionState => write!(f, "useActionState"), + HookKind::UseReducer => write!(f, "useReducer"), + HookKind::UseRef => write!(f, "useRef"), + HookKind::UseEffect => write!(f, "useEffect"), + HookKind::UseLayoutEffect => write!(f, "useLayoutEffect"), + HookKind::UseInsertionEffect => write!(f, "useInsertionEffect"), + HookKind::UseMemo => write!(f, "useMemo"), + HookKind::UseCallback => write!(f, "useCallback"), + HookKind::UseTransition => write!(f, "useTransition"), + HookKind::UseImperativeHandle => write!(f, "useImperativeHandle"), + HookKind::UseEffectEvent => write!(f, "useEffectEvent"), + HookKind::UseOptimistic => write!(f, "useOptimistic"), + HookKind::Custom => write!(f, "Custom"), + } + } +} + +/// Call signature of a function, used for type and effect inference. +/// Ported from TS `FunctionSignature`. 
+#[derive(Debug, Clone)] +pub struct FunctionSignature { + pub positional_params: Vec, + pub rest_param: Option, + pub return_type: Type, + pub return_value_kind: ValueKind, + pub return_value_reason: Option, + pub callee_effect: Effect, + pub hook_kind: Option, + pub no_alias: bool, + pub mutable_only_if_operands_are_mutable: bool, + pub impure: bool, + pub known_incompatible: Option, + pub canonical_name: Option, + /// Aliasing signature in config form. Full parsing into AliasingSignature + /// with Place values is deferred until the aliasing effects system is + /// ported. + pub aliasing: Option, +} + +/// Shape of an object or function type. +/// Ported from TS `ObjectShape`. +#[derive(Debug, Clone)] +pub struct ObjectShape { + pub properties: HashMap, + pub function_type: Option, +} + +/// Registry mapping shape IDs to their ObjectShape definitions. +/// +/// Supports two modes: +/// - **Builder mode** (`base=None`): wraps a single HashMap, used during +/// `build_builtin_shapes` / `build_default_globals` to construct the static +/// base. +/// - **Overlay mode** (`base=Some`): holds a `&'static HashMap` base plus a +/// small extras HashMap. Lookups check extras first, then base. Inserts go +/// into extras. Cloning only copies the extras map (the base pointer is +/// shared). +pub struct ShapeRegistry { + base: Option<&'static HashMap>, + entries: HashMap, +} + +impl ShapeRegistry { + /// Create an empty builder-mode registry. + pub fn new() -> Self { + Self { + base: None, + entries: HashMap::new(), + } + } + + /// Create an overlay-mode registry backed by a static base. 
+ pub fn with_base(base: &'static HashMap) -> Self { + Self { + base: Some(base), + entries: HashMap::new(), + } + } + + pub fn get(&self, key: &str) -> Option<&ObjectShape> { + self.entries + .get(key) + .or_else(|| self.base.and_then(|b| b.get(key))) + } + + pub fn insert(&mut self, key: String, value: ObjectShape) { + self.entries.insert(key, value); + } + + /// Consume the registry and return the inner HashMap. + /// Only valid in builder mode (no base). + pub fn into_inner(self) -> HashMap { + debug_assert!( + self.base.is_none(), + "into_inner() called on overlay-mode ShapeRegistry" + ); + self.entries + } +} + +impl Clone for ShapeRegistry { + fn clone(&self) -> Self { + Self { + base: self.base, + entries: self.entries.clone(), + } + } +} + +// ============================================================================= +// Counter for anonymous shape IDs +// ============================================================================= + +/// Thread-local counter for generating unique anonymous shape IDs. +/// Mirrors TS `nextAnonId` in ObjectShape.ts. +fn next_anon_id() -> String { + use std::sync::atomic::{AtomicU32, Ordering}; + static COUNTER: AtomicU32 = AtomicU32::new(0); + let id = COUNTER.fetch_add(1, Ordering::Relaxed); + format!("", id) +} + +// ============================================================================= +// Builder functions (matching TS addFunction, addHook, addObject) +// ============================================================================= + +/// Add a non-hook function to a ShapeRegistry. +/// Returns a `Type::Function` representing the added function. 
+pub fn add_function( + registry: &mut ShapeRegistry, + properties: Vec<(String, Type)>, + sig: FunctionSignatureBuilder, + id: Option<&str>, + is_constructor: bool, +) -> Type { + let shape_id = id.map(|s| s.to_string()).unwrap_or_else(next_anon_id); + let return_type = sig.return_type.clone(); + add_shape( + registry, + &shape_id, + properties, + Some(FunctionSignature { + positional_params: sig.positional_params, + rest_param: sig.rest_param, + return_type: sig.return_type, + return_value_kind: sig.return_value_kind, + return_value_reason: sig.return_value_reason, + callee_effect: sig.callee_effect, + hook_kind: None, + no_alias: sig.no_alias, + mutable_only_if_operands_are_mutable: sig.mutable_only_if_operands_are_mutable, + impure: sig.impure, + known_incompatible: sig.known_incompatible, + canonical_name: sig.canonical_name, + aliasing: sig.aliasing, + }), + ); + Type::Function { + shape_id: Some(shape_id), + return_type: Box::new(return_type), + is_constructor, + } +} + +/// Add a hook to a ShapeRegistry. +/// Returns a `Type::Function` representing the added hook. 
+pub fn add_hook(registry: &mut ShapeRegistry, sig: HookSignatureBuilder, id: Option<&str>) -> Type { + let shape_id = id.map(|s| s.to_string()).unwrap_or_else(next_anon_id); + let return_type = sig.return_type.clone(); + add_shape( + registry, + &shape_id, + Vec::new(), + Some(FunctionSignature { + positional_params: sig.positional_params, + rest_param: sig.rest_param, + return_type: sig.return_type, + return_value_kind: sig.return_value_kind, + return_value_reason: sig.return_value_reason, + callee_effect: sig.callee_effect, + hook_kind: Some(sig.hook_kind), + no_alias: sig.no_alias, + mutable_only_if_operands_are_mutable: false, + impure: false, + known_incompatible: sig.known_incompatible, + canonical_name: None, + aliasing: sig.aliasing, + }), + ); + Type::Function { + shape_id: Some(shape_id), + return_type: Box::new(return_type), + is_constructor: false, + } +} + +/// Add an object to a ShapeRegistry. +/// Returns a `Type::Object` representing the added object. +pub fn add_object( + registry: &mut ShapeRegistry, + id: Option<&str>, + properties: Vec<(String, Type)>, +) -> Type { + let shape_id = id.map(|s| s.to_string()).unwrap_or_else(next_anon_id); + add_shape(registry, &shape_id, properties, None); + Type::Object { + shape_id: Some(shape_id), + } +} + +fn add_shape( + registry: &mut ShapeRegistry, + id: &str, + properties: Vec<(String, Type)>, + function_type: Option, +) { + let shape = ObjectShape { + properties: properties.into_iter().collect(), + function_type, + }; + // Note: TS has an invariant that the id doesn't already exist. We use + // insert which overwrites. In practice duplicates don't occur for built-in + // shapes, and for user configs we want last-write-wins behavior. 
+ registry.insert(id.to_string(), shape); +} + +// ============================================================================= +// Builder structs (to avoid large parameter lists) +// ============================================================================= + +/// Builder for non-hook function signatures. +pub struct FunctionSignatureBuilder { + pub positional_params: Vec, + pub rest_param: Option, + pub return_type: Type, + pub return_value_kind: ValueKind, + pub return_value_reason: Option, + pub callee_effect: Effect, + pub no_alias: bool, + pub mutable_only_if_operands_are_mutable: bool, + pub impure: bool, + pub known_incompatible: Option, + pub canonical_name: Option, + pub aliasing: Option, +} + +impl Default for FunctionSignatureBuilder { + fn default() -> Self { + Self { + positional_params: Vec::new(), + rest_param: None, + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + return_value_reason: None, + callee_effect: Effect::Read, + no_alias: false, + mutable_only_if_operands_are_mutable: false, + impure: false, + known_incompatible: None, + canonical_name: None, + aliasing: None, + } + } +} + +/// Builder for hook signatures. 
+pub struct HookSignatureBuilder { + pub positional_params: Vec, + pub rest_param: Option, + pub return_type: Type, + pub return_value_kind: ValueKind, + pub return_value_reason: Option, + pub callee_effect: Effect, + pub hook_kind: HookKind, + pub no_alias: bool, + pub known_incompatible: Option, + pub aliasing: Option, +} + +impl Default for HookSignatureBuilder { + fn default() -> Self { + Self { + positional_params: Vec::new(), + rest_param: None, + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + return_value_reason: None, + callee_effect: Effect::Read, + hook_kind: HookKind::Custom, + no_alias: false, + known_incompatible: None, + aliasing: None, + } + } +} + +// ============================================================================= +// Default hook types used for unknown hooks +// ============================================================================= + +/// Default type for hooks when enableAssumeHooksFollowRulesOfReact is true. +/// Matches TS `DefaultNonmutatingHook`. 
+pub fn default_nonmutating_hook(registry: &mut ShapeRegistry) -> Type { + add_hook( + registry, + HookSignatureBuilder { + rest_param: Some(Effect::Freeze), + return_type: Type::Poly, + return_value_kind: ValueKind::Frozen, + hook_kind: HookKind::Custom, + aliasing: Some(AliasingSignatureConfig { + receiver: "@receiver".to_string(), + params: Vec::new(), + rest: Some("@rest".to_string()), + returns: "@returns".to_string(), + temporaries: Vec::new(), + effects: vec![ + // Freeze the arguments + AliasingEffectConfig::Freeze { + value: "@rest".to_string(), + reason: ValueReason::HookCaptured, + }, + // Returns a frozen value + AliasingEffectConfig::Create { + into: "@returns".to_string(), + value: ValueKind::Frozen, + reason: ValueReason::HookReturn, + }, + // May alias any arguments into the return + AliasingEffectConfig::Alias { + from: "@rest".to_string(), + into: "@returns".to_string(), + }, + ], + }), + ..Default::default() + }, + Some("DefaultNonmutatingHook"), + ) +} + +/// Default type for hooks when enableAssumeHooksFollowRulesOfReact is false. +/// Matches TS `DefaultMutatingHook`. +pub fn default_mutating_hook(registry: &mut ShapeRegistry) -> Type { + add_hook( + registry, + HookSignatureBuilder { + rest_param: Some(Effect::ConditionallyMutate), + return_type: Type::Poly, + return_value_kind: ValueKind::Mutable, + hook_kind: HookKind::Custom, + ..Default::default() + }, + Some("DefaultMutatingHook"), + ) +} diff --git a/crates/react_compiler_hir/src/print.rs b/crates/react_compiler_hir/src/print.rs new file mode 100644 index 000000000000..b4d218e43ec6 --- /dev/null +++ b/crates/react_compiler_hir/src/print.rs @@ -0,0 +1,1489 @@ +//! Shared formatting utilities for HIR debug printing. +//! +//! This module provides `PrintFormatter` — a stateful formatter that both +//! `react_compiler::debug_print` (HIR printer) and +//! `react_compiler_reactive_scopes::print_reactive_function` (reactive printer) +//! delegate to for shared formatting logic. +//! +//! 
It also exports standalone formatting functions (format_loc, +//! format_primitive, etc.) that require no state. + +use std::collections::HashSet; + +use react_compiler_diagnostics::{CompilerError, CompilerErrorOrDiagnostic, SourceLocation}; + +use crate::{ + environment::Environment, + type_config::{ValueKind, ValueReason}, + AliasingEffect, HirFunction, IdentifierId, IdentifierName, InstructionValue, LValue, + MutationReason, Pattern, Place, PlaceOrSpreadOrHole, ScopeId, Type, +}; + +// ============================================================================= +// Standalone formatting functions (no state needed) +// ============================================================================= + +pub fn format_loc(loc: &Option) -> String { + match loc { + Some(l) => format_loc_value(l), + None => "generated".to_string(), + } +} + +pub fn format_loc_value(loc: &SourceLocation) -> String { + format!( + "{}:{}-{}:{}", + loc.start.line, loc.start.column, loc.end.line, loc.end.column + ) +} + +pub fn format_primitive(prim: &crate::PrimitiveValue) -> String { + match prim { + crate::PrimitiveValue::Null => "null".to_string(), + crate::PrimitiveValue::Undefined => "undefined".to_string(), + crate::PrimitiveValue::Boolean(b) => format!("{}", b), + crate::PrimitiveValue::Number(n) => { + let v = n.value(); + // Match JS String(-0) === "0" behavior + if v == 0.0 && v.is_sign_negative() { + "0".to_string() + } else { + format!("{}", v) + } + } + crate::PrimitiveValue::String(s) => { + // Format like JS JSON.stringify: escape control chars and quotes but NOT + // non-ASCII unicode + let mut result = String::with_capacity(s.len() + 2); + result.push('"'); + for c in s.chars() { + match c { + '"' => result.push_str("\\\""), + '\\' => result.push_str("\\\\"), + '\n' => result.push_str("\\n"), + '\r' => result.push_str("\\r"), + '\t' => result.push_str("\\t"), + c if c.is_control() => { + result.push_str(&format!("\\u{{{:04x}}}", c as u32)); + } + c => result.push(c), + } + } 
+ result.push('"'); + result + } + } +} + +pub fn format_property_literal(prop: &crate::PropertyLiteral) -> String { + match prop { + crate::PropertyLiteral::String(s) => s.clone(), + crate::PropertyLiteral::Number(n) => format!("{}", n.value()), + } +} + +pub fn format_object_property_key(key: &crate::ObjectPropertyKey) -> String { + match key { + crate::ObjectPropertyKey::String { name } => format!("String(\"{}\")", name), + crate::ObjectPropertyKey::Identifier { name } => { + format!("Identifier(\"{}\")", name) + } + crate::ObjectPropertyKey::Computed { name } => { + format!("Computed({})", name.identifier.0) + } + crate::ObjectPropertyKey::Number { name } => { + format!("Number({})", name.value()) + } + } +} + +pub fn format_non_local_binding(binding: &crate::NonLocalBinding) -> String { + match binding { + crate::NonLocalBinding::Global { name } => { + format!("Global {{ name: \"{}\" }}", name) + } + crate::NonLocalBinding::ModuleLocal { name } => { + format!("ModuleLocal {{ name: \"{}\" }}", name) + } + crate::NonLocalBinding::ImportDefault { name, module } => { + format!( + "ImportDefault {{ name: \"{}\", module: \"{}\" }}", + name, module + ) + } + crate::NonLocalBinding::ImportNamespace { name, module } => { + format!( + "ImportNamespace {{ name: \"{}\", module: \"{}\" }}", + name, module + ) + } + crate::NonLocalBinding::ImportSpecifier { + name, + module, + imported, + } => { + format!( + "ImportSpecifier {{ name: \"{}\", module: \"{}\", imported: \"{}\" }}", + name, module, imported + ) + } + } +} + +pub fn format_value_kind(kind: ValueKind) -> &'static str { + match kind { + ValueKind::Mutable => "mutable", + ValueKind::Frozen => "frozen", + ValueKind::Primitive => "primitive", + ValueKind::MaybeFrozen => "maybe-frozen", + ValueKind::Global => "global", + ValueKind::Context => "context", + } +} + +pub fn format_value_reason(reason: ValueReason) -> &'static str { + match reason { + ValueReason::KnownReturnSignature => "known-return-signature", + 
ValueReason::State => "state", + ValueReason::ReducerState => "reducer-state", + ValueReason::Context => "context", + ValueReason::Effect => "effect", + ValueReason::HookCaptured => "hook-captured", + ValueReason::HookReturn => "hook-return", + ValueReason::Global => "global", + ValueReason::JsxCaptured => "jsx-captured", + ValueReason::StoreLocal => "store-local", + ValueReason::ReactiveFunctionArgument => "reactive-function-argument", + ValueReason::Other => "other", + } +} + +// ============================================================================= +// PrintFormatter — shared stateful formatter +// ============================================================================= + +/// Shared formatter state used by both HIR and reactive printers. +/// +/// Both `DebugPrinter` structs delegate to this for formatting shared +/// constructs like Places, Identifiers, Scopes, Types, InstructionValues, etc. +pub struct PrintFormatter<'a> { + pub env: &'a Environment, + pub seen_identifiers: HashSet, + pub seen_scopes: HashSet, + pub output: Vec, + pub indent_level: usize, +} + +impl<'a> PrintFormatter<'a> { + pub fn new(env: &'a Environment) -> Self { + Self { + env, + seen_identifiers: HashSet::new(), + seen_scopes: HashSet::new(), + output: Vec::new(), + indent_level: 0, + } + } + + pub fn line(&mut self, text: &str) { + let indent = " ".repeat(self.indent_level); + self.output.push(format!("{}{}", indent, text)); + } + + /// Write a line without adding indentation (used when copying pre-formatted + /// output) + pub fn line_raw(&mut self, text: &str) { + self.output.push(text.to_string()); + } + + pub fn indent(&mut self) { + self.indent_level += 1; + } + + pub fn dedent(&mut self) { + self.indent_level -= 1; + } + + pub fn to_string_output(&self) -> String { + self.output.join("\n") + } + + // ========================================================================= + // AliasingEffect + // 
========================================================================= + + pub fn format_effect(&self, effect: &AliasingEffect) -> String { + match effect { + AliasingEffect::Freeze { value, reason } => { + format!( + "Freeze {{ value: {}, reason: {} }}", + value.identifier.0, + format_value_reason(*reason) + ) + } + AliasingEffect::Mutate { value, reason } => match reason { + Some(MutationReason::AssignCurrentProperty) => { + format!( + "Mutate {{ value: {}, reason: AssignCurrentProperty }}", + value.identifier.0 + ) + } + None => format!("Mutate {{ value: {} }}", value.identifier.0), + }, + AliasingEffect::MutateConditionally { value } => { + format!("MutateConditionally {{ value: {} }}", value.identifier.0) + } + AliasingEffect::MutateTransitive { value } => { + format!("MutateTransitive {{ value: {} }}", value.identifier.0) + } + AliasingEffect::MutateTransitiveConditionally { value } => { + format!( + "MutateTransitiveConditionally {{ value: {} }}", + value.identifier.0 + ) + } + AliasingEffect::Capture { from, into } => { + format!( + "Capture {{ into: {}, from: {} }}", + into.identifier.0, from.identifier.0 + ) + } + AliasingEffect::Alias { from, into } => { + format!( + "Alias {{ into: {}, from: {} }}", + into.identifier.0, from.identifier.0 + ) + } + AliasingEffect::MaybeAlias { from, into } => { + format!( + "MaybeAlias {{ into: {}, from: {} }}", + into.identifier.0, from.identifier.0 + ) + } + AliasingEffect::Assign { from, into } => { + format!( + "Assign {{ into: {}, from: {} }}", + into.identifier.0, from.identifier.0 + ) + } + AliasingEffect::Create { + into, + value, + reason, + } => { + format!( + "Create {{ into: {}, value: {}, reason: {} }}", + into.identifier.0, + format_value_kind(*value), + format_value_reason(*reason) + ) + } + AliasingEffect::CreateFrom { from, into } => { + format!( + "CreateFrom {{ into: {}, from: {} }}", + into.identifier.0, from.identifier.0 + ) + } + AliasingEffect::ImmutableCapture { from, into } => { + format!( + 
"ImmutableCapture {{ into: {}, from: {} }}", + into.identifier.0, from.identifier.0 + ) + } + AliasingEffect::Apply { + receiver, + function, + mutates_function, + args, + into, + .. + } => { + let args_str: Vec = args + .iter() + .map(|a| match a { + PlaceOrSpreadOrHole::Hole => "hole".to_string(), + PlaceOrSpreadOrHole::Place(p) => p.identifier.0.to_string(), + PlaceOrSpreadOrHole::Spread(s) => format!("...{}", s.place.identifier.0), + }) + .collect(); + format!( + "Apply {{ into: {}, receiver: {}, function: {}, mutatesFunction: {}, args: \ + [{}] }}", + into.identifier.0, + receiver.identifier.0, + function.identifier.0, + mutates_function, + args_str.join(", ") + ) + } + AliasingEffect::CreateFunction { + captures, + function_id: _, + into, + } => { + let cap_str: Vec = captures + .iter() + .map(|p| p.identifier.0.to_string()) + .collect(); + format!( + "CreateFunction {{ into: {}, captures: [{}] }}", + into.identifier.0, + cap_str.join(", ") + ) + } + AliasingEffect::MutateFrozen { place, error } => { + format!( + "MutateFrozen {{ place: {}, reason: {:?} }}", + place.identifier.0, error.reason + ) + } + AliasingEffect::MutateGlobal { place, error } => { + format!( + "MutateGlobal {{ place: {}, reason: {:?} }}", + place.identifier.0, error.reason + ) + } + AliasingEffect::Impure { place, error } => { + format!( + "Impure {{ place: {}, reason: {:?} }}", + place.identifier.0, error.reason + ) + } + AliasingEffect::Render { place } => { + format!("Render {{ place: {} }}", place.identifier.0) + } + } + } + + // ========================================================================= + // Place (with identifier deduplication) + // ========================================================================= + + pub fn format_place_field(&mut self, field_name: &str, place: &Place) { + let is_seen = self.seen_identifiers.contains(&place.identifier); + if is_seen { + self.line(&format!( + "{}: Place {{ identifier: Identifier({}), effect: {}, reactive: {}, loc: {} }}", + 
field_name, + place.identifier.0, + place.effect, + place.reactive, + format_loc(&place.loc) + )); + } else { + self.line(&format!("{}: Place {{", field_name)); + self.indent(); + self.line("identifier:"); + self.indent(); + self.format_identifier(place.identifier); + self.dedent(); + self.line(&format!("effect: {}", place.effect)); + self.line(&format!("reactive: {}", place.reactive)); + self.line(&format!("loc: {}", format_loc(&place.loc))); + self.dedent(); + self.line("}"); + } + } + + // ========================================================================= + // Identifier (first-seen expansion) + // ========================================================================= + + pub fn format_identifier(&mut self, id: IdentifierId) { + self.seen_identifiers.insert(id); + let ident = &self.env.identifiers[id.0 as usize]; + self.line("Identifier {"); + self.indent(); + self.line(&format!("id: {}", ident.id.0)); + self.line(&format!("declarationId: {}", ident.declaration_id.0)); + match &ident.name { + Some(name) => { + let (kind, value) = match name { + IdentifierName::Named(n) => ("named", n.as_str()), + IdentifierName::Promoted(n) => ("promoted", n.as_str()), + }; + self.line(&format!( + "name: {{ kind: \"{}\", value: \"{}\" }}", + kind, value + )); + } + None => self.line("name: null"), + } + // Print the identifier's mutable_range directly, matching the TS + // DebugPrintHIR which prints `identifier.mutableRange`. In TS, + // InferReactiveScopeVariables sets identifier.mutableRange = scope.range + // (shared reference), and AlignReactiveScopesToBlockScopesHIR syncs them. + // After MergeOverlappingReactiveScopesHIR repoints scopes, the TS + // identifier.mutableRange still references the OLD scope's range (stale), + // so we match by using ident.mutable_range directly (which is synced + // at the AlignReactiveScopesToBlockScopesHIR step but not re-synced + // after scope repointing in merge passes). 
+ self.line(&format!( + "mutableRange: [{}:{}]", + ident.mutable_range.start.0, ident.mutable_range.end.0 + )); + match ident.scope { + Some(scope_id) => self.format_scope_field("scope", scope_id), + None => self.line("scope: null"), + } + self.line(&format!("type: {}", self.format_type(ident.type_))); + self.line(&format!("loc: {}", format_loc(&ident.loc))); + self.dedent(); + self.line("}"); + } + + // ========================================================================= + // Scope (with deduplication) + // ========================================================================= + + pub fn format_scope_field(&mut self, field_name: &str, scope_id: ScopeId) { + let is_seen = self.seen_scopes.contains(&scope_id); + if is_seen { + self.line(&format!("{}: Scope({})", field_name, scope_id.0)); + } else { + self.seen_scopes.insert(scope_id); + if let Some(scope) = self.env.scopes.iter().find(|s| s.id == scope_id) { + let range_start = scope.range.start.0; + let range_end = scope.range.end.0; + let dependencies = scope.dependencies.clone(); + let declarations = scope.declarations.clone(); + let reassignments = scope.reassignments.clone(); + let early_return_value = scope.early_return_value.clone(); + let merged = scope.merged.clone(); + let loc = scope.loc; + + self.line(&format!("{}: Scope {{", field_name)); + self.indent(); + self.line(&format!("id: {}", scope_id.0)); + self.line(&format!("range: [{}:{}]", range_start, range_end)); + + // dependencies + self.line("dependencies:"); + self.indent(); + for (i, dep) in dependencies.iter().enumerate() { + let path_str: String = dep + .path + .iter() + .map(|p| { + let prop = match &p.property { + crate::PropertyLiteral::String(s) => s.clone(), + crate::PropertyLiteral::Number(n) => { + format!("{}", n.value()) + } + }; + format!("{}{}", if p.optional { "?." } else { "." 
}, prop) + }) + .collect(); + self.line(&format!( + "[{}] {{ identifier: {}, reactive: {}, path: \"{}\" }}", + i, dep.identifier.0, dep.reactive, path_str + )); + } + self.dedent(); + + // declarations + self.line("declarations:"); + self.indent(); + for (ident_id, decl) in &declarations { + self.line(&format!( + "{}: {{ identifier: {}, scope: {} }}", + ident_id.0, decl.identifier.0, decl.scope.0 + )); + } + self.dedent(); + + // reassignments + self.line("reassignments:"); + self.indent(); + for ident_id in &reassignments { + self.line(&format!("{}", ident_id.0)); + } + self.dedent(); + + // earlyReturnValue + if let Some(early_return) = &early_return_value { + self.line("earlyReturnValue:"); + self.indent(); + self.line(&format!("value: {}", early_return.value.0)); + self.line(&format!("loc: {}", format_loc(&early_return.loc))); + self.line(&format!("label: bb{}", early_return.label.0)); + self.dedent(); + } else { + self.line("earlyReturnValue: null"); + } + + // merged + let merged_str: Vec = merged.iter().map(|s| s.0.to_string()).collect(); + self.line(&format!("merged: [{}]", merged_str.join(", "))); + + // loc + self.line(&format!("loc: {}", format_loc(&loc))); + + self.dedent(); + self.line("}"); + } else { + self.line(&format!("{}: Scope({})", field_name, scope_id.0)); + } + } + } + + // ========================================================================= + // Type + // ========================================================================= + + pub fn format_type(&self, type_id: crate::TypeId) -> String { + if let Some(ty) = self.env.types.get(type_id.0 as usize) { + self.format_type_value(ty) + } else { + format!("Type({})", type_id.0) + } + } + + pub fn format_type_value(&self, ty: &Type) -> String { + match ty { + Type::Primitive => "Primitive".to_string(), + Type::Function { + shape_id, + return_type, + is_constructor, + } => { + format!( + "Function {{ shapeId: {}, return: {}, isConstructor: {} }}", + match shape_id { + Some(s) => 
format!("\"{}\"", s), + None => "null".to_string(), + }, + self.format_type_value(return_type), + is_constructor + ) + } + Type::Object { shape_id } => { + format!( + "Object {{ shapeId: {} }}", + match shape_id { + Some(s) => format!("\"{}\"", s), + None => "null".to_string(), + } + ) + } + Type::TypeVar { id } => format!("Type({})", id.0), + Type::Poly => "Poly".to_string(), + Type::Phi { operands } => { + let ops: Vec = operands + .iter() + .map(|op| self.format_type_value(op)) + .collect(); + format!("Phi {{ operands: [{}] }}", ops.join(", ")) + } + Type::Property { + object_type, + object_name, + property_name, + } => { + let prop_str = match property_name { + crate::PropertyNameKind::Literal { value } => { + format!("\"{}\"", format_property_literal(value)) + } + crate::PropertyNameKind::Computed { value } => { + format!("computed({})", self.format_type_value(value)) + } + }; + format!( + "Property {{ objectType: {}, objectName: \"{}\", propertyName: {} }}", + self.format_type_value(object_type), + object_name, + prop_str + ) + } + Type::ObjectMethod => "ObjectMethod".to_string(), + } + } + + // ========================================================================= + // LValue + // ========================================================================= + + pub fn format_lvalue(&mut self, field_name: &str, lv: &LValue) { + self.line(&format!("{}:", field_name)); + self.indent(); + self.line(&format!("kind: {:?}", lv.kind)); + self.format_place_field("place", &lv.place); + self.dedent(); + } + + // ========================================================================= + // Pattern + // ========================================================================= + + pub fn format_pattern(&mut self, pattern: &Pattern) { + match pattern { + Pattern::Array(arr) => { + self.line("pattern: ArrayPattern {"); + self.indent(); + self.line("items:"); + self.indent(); + for (i, item) in arr.items.iter().enumerate() { + match item { + crate::ArrayPatternElement::Hole 
=> { + self.line(&format!("[{}] Hole", i)); + } + crate::ArrayPatternElement::Place(p) => { + self.format_place_field(&format!("[{}]", i), p); + } + crate::ArrayPatternElement::Spread(s) => { + self.line(&format!("[{}] Spread:", i)); + self.indent(); + self.format_place_field("place", &s.place); + self.dedent(); + } + } + } + self.dedent(); + self.line(&format!("loc: {}", format_loc(&arr.loc))); + self.dedent(); + self.line("}"); + } + Pattern::Object(obj) => { + self.line("pattern: ObjectPattern {"); + self.indent(); + self.line("properties:"); + self.indent(); + for (i, prop) in obj.properties.iter().enumerate() { + match prop { + crate::ObjectPropertyOrSpread::Property(p) => { + self.line(&format!("[{}] ObjectProperty {{", i)); + self.indent(); + self.line(&format!("key: {}", format_object_property_key(&p.key))); + self.line(&format!("type: \"{}\"", p.property_type)); + self.format_place_field("place", &p.place); + self.dedent(); + self.line("}"); + } + crate::ObjectPropertyOrSpread::Spread(s) => { + self.line(&format!("[{}] Spread:", i)); + self.indent(); + self.format_place_field("place", &s.place); + self.dedent(); + } + } + } + self.dedent(); + self.line(&format!("loc: {}", format_loc(&obj.loc))); + self.dedent(); + self.line("}"); + } + } + } + + // ========================================================================= + // Arguments + // ========================================================================= + + pub fn format_argument(&mut self, arg: &crate::PlaceOrSpread, index: usize) { + match arg { + crate::PlaceOrSpread::Place(p) => { + self.format_place_field(&format!("[{}]", index), p); + } + crate::PlaceOrSpread::Spread(s) => { + self.line(&format!("[{}] Spread:", index)); + self.indent(); + self.format_place_field("place", &s.place); + self.dedent(); + } + } + } + + // ========================================================================= + // InstructionValue + // ========================================================================= + 
+ /// Format an InstructionValue. The `inner_func_formatter` callback is + /// invoked for FunctionExpression/ObjectMethod to format the inner + /// HirFunction. If None, a placeholder is printed instead. + pub fn format_instruction_value( + &mut self, + value: &InstructionValue, + inner_func_formatter: Option<&dyn Fn(&mut PrintFormatter, &HirFunction)>, + ) { + match value { + InstructionValue::ArrayExpression { elements, loc } => { + self.line("ArrayExpression {"); + self.indent(); + self.line("elements:"); + self.indent(); + for (i, elem) in elements.iter().enumerate() { + match elem { + crate::ArrayElement::Place(p) => { + self.format_place_field(&format!("[{}]", i), p); + } + crate::ArrayElement::Hole => { + self.line(&format!("[{}] Hole", i)); + } + crate::ArrayElement::Spread(s) => { + self.line(&format!("[{}] Spread:", i)); + self.indent(); + self.format_place_field("place", &s.place); + self.dedent(); + } + } + } + self.dedent(); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::ObjectExpression { properties, loc } => { + self.line("ObjectExpression {"); + self.indent(); + self.line("properties:"); + self.indent(); + for (i, prop) in properties.iter().enumerate() { + match prop { + crate::ObjectPropertyOrSpread::Property(p) => { + self.line(&format!("[{}] ObjectProperty {{", i)); + self.indent(); + self.line(&format!("key: {}", format_object_property_key(&p.key))); + self.line(&format!("type: \"{}\"", p.property_type)); + self.format_place_field("place", &p.place); + self.dedent(); + self.line("}"); + } + crate::ObjectPropertyOrSpread::Spread(s) => { + self.line(&format!("[{}] Spread:", i)); + self.indent(); + self.format_place_field("place", &s.place); + self.dedent(); + } + } + } + self.dedent(); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::UnaryExpression { + operator, + value: val, + loc, + } => { + self.line("UnaryExpression 
{"); + self.indent(); + self.line(&format!("operator: \"{}\"", operator)); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::BinaryExpression { + operator, + left, + right, + loc, + } => { + self.line("BinaryExpression {"); + self.indent(); + self.line(&format!("operator: \"{}\"", operator)); + self.format_place_field("left", left); + self.format_place_field("right", right); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::NewExpression { callee, args, loc } => { + self.line("NewExpression {"); + self.indent(); + self.format_place_field("callee", callee); + self.line("args:"); + self.indent(); + for (i, arg) in args.iter().enumerate() { + self.format_argument(arg, i); + } + self.dedent(); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::CallExpression { callee, args, loc } => { + self.line("CallExpression {"); + self.indent(); + self.format_place_field("callee", callee); + self.line("args:"); + self.indent(); + for (i, arg) in args.iter().enumerate() { + self.format_argument(arg, i); + } + self.dedent(); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::MethodCall { + receiver, + property, + args, + loc, + } => { + self.line("MethodCall {"); + self.indent(); + self.format_place_field("receiver", receiver); + self.format_place_field("property", property); + self.line("args:"); + self.indent(); + for (i, arg) in args.iter().enumerate() { + self.format_argument(arg, i); + } + self.dedent(); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::JSXText { value: val, loc } => { + self.line(&format!( + "JSXText {{ value: {:?}, loc: {} }}", + val, + format_loc(loc) + )); + } + InstructionValue::Primitive { value: prim, loc } => { + 
self.line(&format!( + "Primitive {{ value: {}, loc: {} }}", + format_primitive(prim), + format_loc(loc) + )); + } + InstructionValue::TypeCastExpression { + value: val, + type_, + type_annotation_name, + type_annotation_kind, + type_annotation: _, + loc, + } => { + self.line("TypeCastExpression {"); + self.indent(); + self.format_place_field("value", val); + self.line(&format!("type: {}", self.format_type_value(type_))); + if let Some(annotation_name) = type_annotation_name { + self.line(&format!("typeAnnotation: {}", annotation_name)); + } + if let Some(annotation_kind) = type_annotation_kind { + self.line(&format!("typeAnnotationKind: \"{}\"", annotation_kind)); + } + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::JsxExpression { + tag, + props, + children, + loc, + opening_loc, + closing_loc, + } => { + self.line("JsxExpression {"); + self.indent(); + match tag { + crate::JsxTag::Place(p) => { + self.format_place_field("tag", p); + } + crate::JsxTag::Builtin(b) => { + self.line(&format!("tag: BuiltinTag(\"{}\")", b.name)); + } + } + self.line("props:"); + self.indent(); + for (i, prop) in props.iter().enumerate() { + match prop { + crate::JsxAttribute::Attribute { name, place } => { + self.line(&format!("[{}] JsxAttribute {{", i)); + self.indent(); + self.line(&format!("name: \"{}\"", name)); + self.format_place_field("place", place); + self.dedent(); + self.line("}"); + } + crate::JsxAttribute::SpreadAttribute { argument } => { + self.line(&format!("[{}] JsxSpreadAttribute:", i)); + self.indent(); + self.format_place_field("argument", argument); + self.dedent(); + } + } + } + self.dedent(); + match children { + Some(c) => { + self.line("children:"); + self.indent(); + for (i, child) in c.iter().enumerate() { + self.format_place_field(&format!("[{}]", i), child); + } + self.dedent(); + } + None => self.line("children: null"), + } + self.line(&format!("openingLoc: {}", format_loc(opening_loc))); + 
self.line(&format!("closingLoc: {}", format_loc(closing_loc))); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::JsxFragment { children, loc } => { + self.line("JsxFragment {"); + self.indent(); + self.line("children:"); + self.indent(); + for (i, child) in children.iter().enumerate() { + self.format_place_field(&format!("[{}]", i), child); + } + self.dedent(); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::UnsupportedNode { node_type, loc, .. } => match node_type { + Some(t) => self.line(&format!( + "UnsupportedNode {{ type: {:?}, loc: {} }}", + t, + format_loc(loc) + )), + None => self.line(&format!("UnsupportedNode {{ loc: {} }}", format_loc(loc))), + }, + InstructionValue::LoadLocal { place, loc } => { + self.line("LoadLocal {"); + self.indent(); + self.format_place_field("place", place); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::DeclareLocal { + lvalue, + type_annotation, + loc, + } => { + self.line("DeclareLocal {"); + self.indent(); + self.format_lvalue("lvalue", lvalue); + self.line(&format!( + "type: {}", + match type_annotation { + Some(t) => t.clone(), + None => "null".to_string(), + } + )); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::DeclareContext { lvalue, loc } => { + self.line("DeclareContext {"); + self.indent(); + self.line("lvalue:"); + self.indent(); + self.line(&format!("kind: {:?}", lvalue.kind)); + self.format_place_field("place", &lvalue.place); + self.dedent(); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::StoreLocal { + lvalue, + value: val, + type_annotation, + loc, + } => { + self.line("StoreLocal {"); + self.indent(); + self.format_lvalue("lvalue", lvalue); + self.format_place_field("value", val); + self.line(&format!( + 
"type: {}", + match type_annotation { + Some(t) => t.clone(), + None => "null".to_string(), + } + )); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::LoadContext { place, loc } => { + self.line("LoadContext {"); + self.indent(); + self.format_place_field("place", place); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::StoreContext { + lvalue, + value: val, + loc, + } => { + self.line("StoreContext {"); + self.indent(); + self.line("lvalue:"); + self.indent(); + self.line(&format!("kind: {:?}", lvalue.kind)); + self.format_place_field("place", &lvalue.place); + self.dedent(); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::Destructure { + lvalue, + value: val, + loc, + } => { + self.line("Destructure {"); + self.indent(); + self.line("lvalue:"); + self.indent(); + self.line(&format!("kind: {:?}", lvalue.kind)); + self.format_pattern(&lvalue.pattern); + self.dedent(); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::PropertyLoad { + object, + property, + loc, + } => { + self.line("PropertyLoad {"); + self.indent(); + self.format_place_field("object", object); + self.line(&format!( + "property: \"{}\"", + format_property_literal(property) + )); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::PropertyStore { + object, + property, + value: val, + loc, + } => { + self.line("PropertyStore {"); + self.indent(); + self.format_place_field("object", object); + self.line(&format!( + "property: \"{}\"", + format_property_literal(property) + )); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + 
InstructionValue::PropertyDelete { + object, + property, + loc, + } => { + self.line("PropertyDelete {"); + self.indent(); + self.format_place_field("object", object); + self.line(&format!( + "property: \"{}\"", + format_property_literal(property) + )); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::ComputedLoad { + object, + property, + loc, + } => { + self.line("ComputedLoad {"); + self.indent(); + self.format_place_field("object", object); + self.format_place_field("property", property); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::ComputedStore { + object, + property, + value: val, + loc, + } => { + self.line("ComputedStore {"); + self.indent(); + self.format_place_field("object", object); + self.format_place_field("property", property); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::ComputedDelete { + object, + property, + loc, + } => { + self.line("ComputedDelete {"); + self.indent(); + self.format_place_field("object", object); + self.format_place_field("property", property); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::LoadGlobal { binding, loc } => { + self.line("LoadGlobal {"); + self.indent(); + self.line(&format!("binding: {}", format_non_local_binding(binding))); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::StoreGlobal { + name, + value: val, + loc, + } => { + self.line("StoreGlobal {"); + self.indent(); + self.line(&format!("name: \"{}\"", name)); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::FunctionExpression { + name, + name_hint, + lowered_func, + expr_type, + loc, + } => { + 
self.line("FunctionExpression {"); + self.indent(); + self.line(&format!( + "name: {}", + match name { + Some(n) => format!("\"{}\"", n), + None => "null".to_string(), + } + )); + self.line(&format!( + "nameHint: {}", + match name_hint { + Some(h) => format!("\"{}\"", h), + None => "null".to_string(), + } + )); + self.line(&format!("type: \"{:?}\"", expr_type)); + self.line("loweredFunc:"); + let inner_func = &self.env.functions[lowered_func.func.0 as usize]; + if let Some(formatter) = inner_func_formatter { + formatter(self, inner_func); + } else { + self.line(&format!(" ", lowered_func.func.0)); + } + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::ObjectMethod { loc, lowered_func } => { + self.line("ObjectMethod {"); + self.indent(); + self.line("loweredFunc:"); + let inner_func = &self.env.functions[lowered_func.func.0 as usize]; + if let Some(formatter) = inner_func_formatter { + formatter(self, inner_func); + } else { + self.line(&format!(" ", lowered_func.func.0)); + } + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::TaggedTemplateExpression { + tag, + value: val, + loc, + } => { + self.line("TaggedTemplateExpression {"); + self.indent(); + self.format_place_field("tag", tag); + self.line(&format!("raw: {:?}", val.raw)); + self.line(&format!( + "cooked: {}", + match &val.cooked { + Some(c) => format!("{:?}", c), + None => "undefined".to_string(), + } + )); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::TemplateLiteral { + subexprs, + quasis, + loc, + } => { + self.line("TemplateLiteral {"); + self.indent(); + self.line("subexprs:"); + self.indent(); + for (i, sub) in subexprs.iter().enumerate() { + self.format_place_field(&format!("[{}]", i), sub); + } + self.dedent(); + self.line("quasis:"); + self.indent(); + for (i, q) in quasis.iter().enumerate() { + self.line(&format!( + 
"[{}] {{ raw: {:?}, cooked: {} }}", + i, + q.raw, + match &q.cooked { + Some(c) => format!("{:?}", c), + None => "undefined".to_string(), + } + )); + } + self.dedent(); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::RegExpLiteral { + pattern, + flags, + loc, + } => { + self.line(&format!( + "RegExpLiteral {{ pattern: \"{}\", flags: \"{}\", loc: {} }}", + pattern, + flags, + format_loc(loc) + )); + } + InstructionValue::MetaProperty { + meta, + property, + loc, + } => { + self.line(&format!( + "MetaProperty {{ meta: \"{}\", property: \"{}\", loc: {} }}", + meta, + property, + format_loc(loc) + )); + } + InstructionValue::Await { value: val, loc } => { + self.line("Await {"); + self.indent(); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::GetIterator { collection, loc } => { + self.line("GetIterator {"); + self.indent(); + self.format_place_field("collection", collection); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::IteratorNext { + iterator, + collection, + loc, + } => { + self.line("IteratorNext {"); + self.indent(); + self.format_place_field("iterator", iterator); + self.format_place_field("collection", collection); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::NextPropertyOf { value: val, loc } => { + self.line("NextPropertyOf {"); + self.indent(); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::Debugger { loc } => { + self.line(&format!("Debugger {{ loc: {} }}", format_loc(loc))); + } + InstructionValue::PostfixUpdate { + lvalue, + operation, + value: val, + loc, + } => { + self.line("PostfixUpdate {"); + self.indent(); + self.format_place_field("lvalue", lvalue); + 
self.line(&format!("operation: \"{}\"", operation)); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::PrefixUpdate { + lvalue, + operation, + value: val, + loc, + } => { + self.line("PrefixUpdate {"); + self.indent(); + self.format_place_field("lvalue", lvalue); + self.line(&format!("operation: \"{}\"", operation)); + self.format_place_field("value", val); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::StartMemoize { + manual_memo_id, + deps, + deps_loc: _, + has_invalid_deps: _, + loc, + } => { + self.line("StartMemoize {"); + self.indent(); + self.line(&format!("manualMemoId: {}", manual_memo_id)); + match deps { + Some(d) => { + self.line("deps:"); + self.indent(); + for (i, dep) in d.iter().enumerate() { + let root_str = match &dep.root { + crate::ManualMemoDependencyRoot::Global { identifier_name } => { + format!("Global(\"{}\")", identifier_name) + } + crate::ManualMemoDependencyRoot::NamedLocal { + value: val, + constant, + } => { + format!( + "NamedLocal({}, constant={})", + val.identifier.0, constant + ) + } + }; + let path_str: String = dep + .path + .iter() + .map(|p| { + format!( + "{}.{}", + if p.optional { "?" 
} else { "" }, + format_property_literal(&p.property) + ) + }) + .collect(); + self.line(&format!("[{}] {}{}", i, root_str, path_str)); + } + self.dedent(); + } + None => self.line("deps: null"), + } + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + InstructionValue::FinishMemoize { + manual_memo_id, + decl, + pruned, + loc, + } => { + self.line("FinishMemoize {"); + self.indent(); + self.line(&format!("manualMemoId: {}", manual_memo_id)); + self.format_place_field("decl", decl); + self.line(&format!("pruned: {}", pruned)); + self.line(&format!("loc: {}", format_loc(loc))); + self.dedent(); + self.line("}"); + } + } + } + + // ========================================================================= + // Errors + // ========================================================================= + + pub fn format_errors(&mut self, error: &CompilerError) { + if error.details.is_empty() { + self.line("Errors: []"); + return; + } + self.line("Errors:"); + self.indent(); + for (i, detail) in error.details.iter().enumerate() { + self.line(&format!("[{}] {{", i)); + self.indent(); + match detail { + CompilerErrorOrDiagnostic::Diagnostic(d) => { + self.line(&format!("severity: {:?}", d.severity())); + self.line(&format!("reason: {:?}", d.reason)); + self.line(&format!( + "description: {}", + match &d.description { + Some(desc) => format!("{:?}", desc), + None => "null".to_string(), + } + )); + self.line(&format!("category: {:?}", d.category)); + let loc = d.primary_location(); + self.line(&format!( + "loc: {}", + match loc { + Some(l) => format_loc_value(l), + None => "null".to_string(), + } + )); + } + CompilerErrorOrDiagnostic::ErrorDetail(d) => { + self.line(&format!("severity: {:?}", d.severity())); + self.line(&format!("reason: {:?}", d.reason)); + self.line(&format!( + "description: {}", + match &d.description { + Some(desc) => format!("{:?}", desc), + None => "null".to_string(), + } + )); + self.line(&format!("category: {:?}", 
d.category)); + self.line(&format!( + "loc: {}", + match &d.loc { + Some(l) => format_loc_value(l), + None => "null".to_string(), + } + )); + } + } + self.dedent(); + self.line("}"); + } + self.dedent(); + } +} diff --git a/crates/react_compiler_hir/src/reactive.rs b/crates/react_compiler_hir/src/reactive.rs new file mode 100644 index 000000000000..12d17dfb1e39 --- /dev/null +++ b/crates/react_compiler_hir/src/reactive.rs @@ -0,0 +1,248 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Reactive function types — tree representation of a compiled function. +//! +//! `ReactiveFunction` is derived from the HIR CFG by `BuildReactiveFunction`. +//! Control flow constructs (if/switch/loops/try) and reactive scopes become +//! nested blocks rather than block references. +//! +//! Corresponds to the reactive types in `HIR.ts`. + +use react_compiler_diagnostics::SourceLocation; + +use crate::{ + AliasingEffect, BlockId, EvaluationOrder, InstructionValue, LogicalOperator, ParamPattern, + Place, ScopeId, +}; + +// ============================================================================= +// ReactiveFunction +// ============================================================================= + +/// Tree representation of a compiled function, converted from the CFG-based +/// HIR. 
TS: ReactiveFunction in HIR.ts +#[derive(Debug, Clone)] +pub struct ReactiveFunction { + pub loc: Option, + pub id: Option, + pub name_hint: Option, + pub params: Vec, + pub generator: bool, + pub is_async: bool, + pub body: ReactiveBlock, + pub directives: Vec, + // No env field — passed separately per established Rust convention +} + +// ============================================================================= +// ReactiveBlock and ReactiveStatement +// ============================================================================= + +/// TS: ReactiveBlock = Array +pub type ReactiveBlock = Vec; + +/// TS: ReactiveStatement (discriminated union with 'kind' field) +#[derive(Debug, Clone)] +pub enum ReactiveStatement { + Instruction(ReactiveInstruction), + Terminal(ReactiveTerminalStatement), + Scope(ReactiveScopeBlock), + PrunedScope(PrunedReactiveScopeBlock), +} + +// ============================================================================= +// ReactiveInstruction and ReactiveValue +// ============================================================================= + +/// TS: ReactiveInstruction +#[derive(Debug, Clone)] +pub struct ReactiveInstruction { + pub id: EvaluationOrder, + pub lvalue: Option, + pub value: ReactiveValue, + pub effects: Option>, + pub loc: Option, +} + +/// Extends InstructionValue with compound expression types that were +/// separate blocks+terminals in HIR but become nested expressions here. +/// TS: ReactiveValue = InstructionValue | ReactiveLogicalValue | ... 
+#[derive(Debug, Clone)] +pub enum ReactiveValue { + /// All ~35 base instruction value kinds + Instruction(InstructionValue), + + /// TS: ReactiveLogicalValue + LogicalExpression { + operator: LogicalOperator, + left: Box, + right: Box, + loc: Option, + }, + + /// TS: ReactiveTernaryValue + ConditionalExpression { + test: Box, + consequent: Box, + alternate: Box, + loc: Option, + }, + + /// TS: ReactiveSequenceValue + SequenceExpression { + instructions: Vec, + id: EvaluationOrder, + value: Box, + loc: Option, + }, + + /// TS: ReactiveOptionalCallValue + OptionalExpression { + id: EvaluationOrder, + value: Box, + optional: bool, + loc: Option, + }, +} + +// ============================================================================= +// Terminals +// ============================================================================= + +#[derive(Debug, Clone)] +pub struct ReactiveTerminalStatement { + pub terminal: ReactiveTerminal, + pub label: Option, +} + +#[derive(Debug, Clone)] +pub struct ReactiveLabel { + pub id: BlockId, + pub implicit: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ReactiveTerminalTargetKind { + Implicit, + Labeled, + Unlabeled, +} + +impl std::fmt::Display for ReactiveTerminalTargetKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ReactiveTerminalTargetKind::Implicit => write!(f, "implicit"), + ReactiveTerminalTargetKind::Labeled => write!(f, "labeled"), + ReactiveTerminalTargetKind::Unlabeled => write!(f, "unlabeled"), + } + } +} + +#[derive(Debug, Clone)] +pub enum ReactiveTerminal { + Break { + target: BlockId, + id: EvaluationOrder, + target_kind: ReactiveTerminalTargetKind, + loc: Option, + }, + Continue { + target: BlockId, + id: EvaluationOrder, + target_kind: ReactiveTerminalTargetKind, + loc: Option, + }, + Return { + value: Place, + id: EvaluationOrder, + loc: Option, + }, + Throw { + value: Place, + id: EvaluationOrder, + loc: Option, + }, + Switch { + test: Place, + cases: 
Vec, + id: EvaluationOrder, + loc: Option, + }, + DoWhile { + loop_block: ReactiveBlock, + test: ReactiveValue, + id: EvaluationOrder, + loc: Option, + }, + While { + test: ReactiveValue, + loop_block: ReactiveBlock, + id: EvaluationOrder, + loc: Option, + }, + For { + init: ReactiveValue, + test: ReactiveValue, + update: Option, + loop_block: ReactiveBlock, + id: EvaluationOrder, + loc: Option, + }, + ForOf { + init: ReactiveValue, + test: ReactiveValue, + loop_block: ReactiveBlock, + id: EvaluationOrder, + loc: Option, + }, + ForIn { + init: ReactiveValue, + loop_block: ReactiveBlock, + id: EvaluationOrder, + loc: Option, + }, + If { + test: Place, + consequent: ReactiveBlock, + alternate: Option, + id: EvaluationOrder, + loc: Option, + }, + Label { + block: ReactiveBlock, + id: EvaluationOrder, + loc: Option, + }, + Try { + block: ReactiveBlock, + handler_binding: Option, + handler: ReactiveBlock, + id: EvaluationOrder, + loc: Option, + }, +} + +#[derive(Debug, Clone)] +pub struct ReactiveSwitchCase { + pub test: Option, + pub block: Option, +} + +// ============================================================================= +// Scope Blocks +// ============================================================================= + +#[derive(Debug, Clone)] +pub struct ReactiveScopeBlock { + pub scope: ScopeId, + pub instructions: ReactiveBlock, +} + +#[derive(Debug, Clone)] +pub struct PrunedReactiveScopeBlock { + pub scope: ScopeId, + pub instructions: ReactiveBlock, +} diff --git a/crates/react_compiler_hir/src/type_config.rs b/crates/react_compiler_hir/src/type_config.rs new file mode 100644 index 000000000000..06554b82ff57 --- /dev/null +++ b/crates/react_compiler_hir/src/type_config.rs @@ -0,0 +1,212 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Type configuration types, ported from TypeSchema.ts. +//! +//! 
These are the JSON-serializable config types used by `moduleTypeProvider` +//! and `installTypeConfig` to describe module/function/hook types. + +use indexmap::IndexMap; + +use crate::Effect; + +/// Mirrors TS `ValueKind` enum for use in config. +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ValueKind { + Mutable, + Frozen, + Primitive, + #[serde(rename = "maybefrozen")] + MaybeFrozen, + Global, + Context, +} + +/// Mirrors TS `ValueReason` enum for use in config. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] +pub enum ValueReason { + #[serde(rename = "known-return-signature")] + KnownReturnSignature, + #[serde(rename = "state")] + State, + #[serde(rename = "reducer-state")] + ReducerState, + #[serde(rename = "context")] + Context, + #[serde(rename = "effect")] + Effect, + #[serde(rename = "hook-captured")] + HookCaptured, + #[serde(rename = "hook-return")] + HookReturn, + #[serde(rename = "global")] + Global, + #[serde(rename = "jsx-captured")] + JsxCaptured, + #[serde(rename = "store-local")] + StoreLocal, + #[serde(rename = "reactive-function-argument")] + ReactiveFunctionArgument, + #[serde(rename = "other")] + Other, +} + +// ============================================================================= +// Aliasing effect config types (from TypeSchema.ts) +// ============================================================================= + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(tag = "kind")] +pub enum AliasingEffectConfig { + Freeze { + value: String, + reason: ValueReason, + }, + Create { + into: String, + value: ValueKind, + reason: ValueReason, + }, + CreateFrom { + from: String, + into: String, + }, + Assign { + from: String, + into: String, + }, + Alias { + from: String, + into: String, + }, + Capture { + from: String, + into: String, + }, + ImmutableCapture { + from: String, + into: String, + }, + 
Impure { + place: String, + }, + Mutate { + value: String, + }, + MutateTransitiveConditionally { + value: String, + }, + Apply { + receiver: String, + function: String, + #[serde(rename = "mutatesFunction")] + mutates_function: bool, + args: Vec, + into: String, + }, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(untagged)] +pub enum ApplyArgConfig { + Place(String), + Spread { + #[allow(dead_code)] + kind: ApplyArgSpreadKind, + place: String, + }, + Hole { + #[allow(dead_code)] + kind: ApplyArgHoleKind, + }, +} + +/// Helper enum for tagged serde of `ApplyArgConfig::Spread`. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub enum ApplyArgSpreadKind { + Spread, +} + +/// Helper enum for tagged serde of `ApplyArgConfig::Hole`. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub enum ApplyArgHoleKind { + Hole, +} + +/// Aliasing signature config, the JSON-serializable form. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct AliasingSignatureConfig { + pub receiver: String, + pub params: Vec, + pub rest: Option, + pub returns: String, + pub temporaries: Vec, + pub effects: Vec, +} + +// ============================================================================= +// Type config (from TypeSchema.ts) +// ============================================================================= + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(tag = "kind")] +pub enum TypeConfig { + #[serde(rename = "object")] + Object(ObjectTypeConfig), + #[serde(rename = "function")] + Function(FunctionTypeConfig), + #[serde(rename = "hook")] + Hook(HookTypeConfig), + #[serde(rename = "type")] + TypeReference(TypeReferenceConfig), +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ObjectTypeConfig { + pub properties: Option>, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionTypeConfig 
{ + pub positional_params: Vec, + pub rest_param: Option, + pub callee_effect: Effect, + pub return_type: Box, + pub return_value_kind: ValueKind, + pub no_alias: Option, + pub mutable_only_if_operands_are_mutable: Option, + pub impure: Option, + pub canonical_name: Option, + pub aliasing: Option, + pub known_incompatible: Option, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct HookTypeConfig { + pub positional_params: Option>, + pub rest_param: Option, + pub return_type: Box, + pub return_value_kind: Option, + pub no_alias: Option, + pub aliasing: Option, + pub known_incompatible: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub enum BuiltInTypeRef { + Any, + Ref, + Array, + Primitive, + MixedReadonly, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct TypeReferenceConfig { + pub name: BuiltInTypeRef, +} diff --git a/crates/react_compiler_hir/src/visitors.rs b/crates/react_compiler_hir/src/visitors.rs new file mode 100644 index 000000000000..689b8e789ad3 --- /dev/null +++ b/crates/react_compiler_hir/src/visitors.rs @@ -0,0 +1,1781 @@ +/** + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +use std::collections::HashMap; + +use crate::{ + environment::Environment, ArrayElement, ArrayPatternElement, BasicBlock, BlockId, HirFunction, + IdentifierId, Instruction, InstructionKind, InstructionValue, JsxAttribute, JsxTag, + ManualMemoDependencyRoot, ObjectPropertyKey, ObjectPropertyOrSpread, Pattern, Place, + PlaceOrSpread, ScopeId, Terminal, +}; + +// ============================================================================= +// Iterator functions (return Vec instead of generators) +// ============================================================================= + +/// Yields `instr.lvalue` plus the value's lvalues. +/// Equivalent to TS `eachInstructionLValue`. +pub fn each_instruction_lvalue(instr: &Instruction) -> Vec { + let mut result = Vec::new(); + result.push(instr.lvalue.clone()); + result.extend(each_instruction_value_lvalue(&instr.value)); + result +} + +/// Yields lvalues from +/// DeclareLocal/StoreLocal/DeclareContext/StoreContext/Destructure/ +/// PostfixUpdate/PrefixUpdate. Equivalent to TS `eachInstructionValueLValue`. +pub fn each_instruction_value_lvalue(value: &InstructionValue) -> Vec { + let mut result = Vec::new(); + match value { + InstructionValue::DeclareContext { lvalue, .. } + | InstructionValue::StoreContext { lvalue, .. } + | InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::StoreLocal { lvalue, .. } => { + result.push(lvalue.place.clone()); + } + InstructionValue::Destructure { lvalue, .. } => { + result.extend(each_pattern_operand(&lvalue.pattern)); + } + InstructionValue::PostfixUpdate { lvalue, .. } + | InstructionValue::PrefixUpdate { lvalue, .. } => { + result.push(lvalue.clone()); + } + // All other variants have no lvalues + InstructionValue::LoadLocal { .. } + | InstructionValue::LoadContext { .. } + | InstructionValue::Primitive { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::BinaryExpression { .. } + | InstructionValue::NewExpression { .. 
} + | InstructionValue::CallExpression { .. } + | InstructionValue::MethodCall { .. } + | InstructionValue::UnaryExpression { .. } + | InstructionValue::TypeCastExpression { .. } + | InstructionValue::JsxExpression { .. } + | InstructionValue::ObjectExpression { .. } + | InstructionValue::ObjectMethod { .. } + | InstructionValue::ArrayExpression { .. } + | InstructionValue::JsxFragment { .. } + | InstructionValue::RegExpLiteral { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::PropertyStore { .. } + | InstructionValue::PropertyLoad { .. } + | InstructionValue::PropertyDelete { .. } + | InstructionValue::ComputedStore { .. } + | InstructionValue::ComputedLoad { .. } + | InstructionValue::ComputedDelete { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::StoreGlobal { .. } + | InstructionValue::FunctionExpression { .. } + | InstructionValue::TaggedTemplateExpression { .. } + | InstructionValue::TemplateLiteral { .. } + | InstructionValue::Await { .. } + | InstructionValue::GetIterator { .. } + | InstructionValue::IteratorNext { .. } + | InstructionValue::NextPropertyOf { .. } + | InstructionValue::Debugger { .. } + | InstructionValue::StartMemoize { .. } + | InstructionValue::FinishMemoize { .. } + | InstructionValue::UnsupportedNode { .. } => {} + } + result +} + +/// Yields lvalues with their InstructionKind. +/// Equivalent to TS `eachInstructionLValueWithKind`. +pub fn each_instruction_lvalue_with_kind( + value: &InstructionValue, +) -> Vec<(Place, InstructionKind)> { + let mut result = Vec::new(); + match value { + InstructionValue::DeclareContext { lvalue, .. } + | InstructionValue::StoreContext { lvalue, .. } + | InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::StoreLocal { lvalue, .. } => { + result.push((lvalue.place.clone(), lvalue.kind)); + } + InstructionValue::Destructure { lvalue, .. 
} => { + let kind = lvalue.kind; + for place in each_pattern_operand(&lvalue.pattern) { + result.push((place, kind)); + } + } + InstructionValue::PostfixUpdate { lvalue, .. } + | InstructionValue::PrefixUpdate { lvalue, .. } => { + result.push((lvalue.clone(), InstructionKind::Reassign)); + } + // All other variants have no lvalues with kind + InstructionValue::LoadLocal { .. } + | InstructionValue::LoadContext { .. } + | InstructionValue::Primitive { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::BinaryExpression { .. } + | InstructionValue::NewExpression { .. } + | InstructionValue::CallExpression { .. } + | InstructionValue::MethodCall { .. } + | InstructionValue::UnaryExpression { .. } + | InstructionValue::TypeCastExpression { .. } + | InstructionValue::JsxExpression { .. } + | InstructionValue::ObjectExpression { .. } + | InstructionValue::ObjectMethod { .. } + | InstructionValue::ArrayExpression { .. } + | InstructionValue::JsxFragment { .. } + | InstructionValue::RegExpLiteral { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::PropertyStore { .. } + | InstructionValue::PropertyLoad { .. } + | InstructionValue::PropertyDelete { .. } + | InstructionValue::ComputedStore { .. } + | InstructionValue::ComputedLoad { .. } + | InstructionValue::ComputedDelete { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::StoreGlobal { .. } + | InstructionValue::FunctionExpression { .. } + | InstructionValue::TaggedTemplateExpression { .. } + | InstructionValue::TemplateLiteral { .. } + | InstructionValue::Await { .. } + | InstructionValue::GetIterator { .. } + | InstructionValue::IteratorNext { .. } + | InstructionValue::NextPropertyOf { .. } + | InstructionValue::Debugger { .. } + | InstructionValue::StartMemoize { .. } + | InstructionValue::FinishMemoize { .. } + | InstructionValue::UnsupportedNode { .. } => {} + } + result +} + +/// Delegates to each_instruction_value_operand. 
+/// Equivalent to TS `eachInstructionOperand`. +pub fn each_instruction_operand(instr: &Instruction, env: &Environment) -> Vec { + each_instruction_value_operand(&instr.value, env) +} + +/// Like `each_instruction_operand` but takes `functions` directly instead of +/// `env`. Useful when borrow splitting prevents passing the full `Environment`. +pub fn each_instruction_operand_with_functions( + instr: &Instruction, + functions: &[HirFunction], +) -> Vec { + each_instruction_value_operand_with_functions(&instr.value, functions) +} + +/// Yields operand places from an InstructionValue. +/// Equivalent to TS `eachInstructionValueOperand`. +pub fn each_instruction_value_operand(value: &InstructionValue, env: &Environment) -> Vec { + each_instruction_value_operand_with_functions(value, &env.functions) +} + +/// Like `each_instruction_value_operand` but takes `functions` directly instead +/// of `env`. Useful when borrow splitting prevents passing the full +/// `Environment`. +pub fn each_instruction_value_operand_with_functions( + value: &InstructionValue, + functions: &[HirFunction], +) -> Vec { + let mut result = Vec::new(); + match value { + InstructionValue::NewExpression { callee, args, .. } + | InstructionValue::CallExpression { callee, args, .. } => { + result.push(callee.clone()); + result.extend(each_call_argument(args)); + } + InstructionValue::BinaryExpression { left, right, .. } => { + result.push(left.clone()); + result.push(right.clone()); + } + InstructionValue::MethodCall { + receiver, + property, + args, + .. + } => { + result.push(receiver.clone()); + result.push(property.clone()); + result.extend(each_call_argument(args)); + } + InstructionValue::DeclareContext { .. } | InstructionValue::DeclareLocal { .. } => { + // no operands + } + InstructionValue::LoadLocal { place, .. } | InstructionValue::LoadContext { place, .. } => { + result.push(place.clone()); + } + InstructionValue::StoreLocal { value: val, .. 
} => { + result.push(val.clone()); + } + InstructionValue::StoreContext { + lvalue, value: val, .. + } => { + result.push(lvalue.place.clone()); + result.push(val.clone()); + } + InstructionValue::StoreGlobal { value: val, .. } => { + result.push(val.clone()); + } + InstructionValue::Destructure { value: val, .. } => { + result.push(val.clone()); + } + InstructionValue::PropertyLoad { object, .. } => { + result.push(object.clone()); + } + InstructionValue::PropertyDelete { object, .. } => { + result.push(object.clone()); + } + InstructionValue::PropertyStore { + object, value: val, .. + } => { + result.push(object.clone()); + result.push(val.clone()); + } + InstructionValue::ComputedLoad { + object, property, .. + } => { + result.push(object.clone()); + result.push(property.clone()); + } + InstructionValue::ComputedDelete { + object, property, .. + } => { + result.push(object.clone()); + result.push(property.clone()); + } + InstructionValue::ComputedStore { + object, + property, + value: val, + .. + } => { + result.push(object.clone()); + result.push(property.clone()); + result.push(val.clone()); + } + InstructionValue::UnaryExpression { value: val, .. } => { + result.push(val.clone()); + } + InstructionValue::JsxExpression { + tag, + props, + children, + .. + } => { + if let JsxTag::Place(place) = tag { + result.push(place.clone()); + } + for attribute in props { + match attribute { + JsxAttribute::Attribute { place, .. } => { + result.push(place.clone()); + } + JsxAttribute::SpreadAttribute { argument, .. } => { + result.push(argument.clone()); + } + } + } + if let Some(children) = children { + for child in children { + result.push(child.clone()); + } + } + } + InstructionValue::JsxFragment { children, .. } => { + for child in children { + result.push(child.clone()); + } + } + InstructionValue::ObjectExpression { properties, .. 
} => { + for property in properties { + match property { + ObjectPropertyOrSpread::Property(prop) => { + if let ObjectPropertyKey::Computed { name } = &prop.key { + result.push(name.clone()); + } + result.push(prop.place.clone()); + } + ObjectPropertyOrSpread::Spread(spread) => { + result.push(spread.place.clone()); + } + } + } + } + InstructionValue::ArrayExpression { elements, .. } => { + for element in elements { + match element { + ArrayElement::Place(place) => { + result.push(place.clone()); + } + ArrayElement::Spread(spread) => { + result.push(spread.place.clone()); + } + ArrayElement::Hole => {} + } + } + } + InstructionValue::ObjectMethod { lowered_func, .. } + | InstructionValue::FunctionExpression { lowered_func, .. } => { + let func = &functions[lowered_func.func.0 as usize]; + for ctx_place in &func.context { + result.push(ctx_place.clone()); + } + } + InstructionValue::TaggedTemplateExpression { tag, .. } => { + result.push(tag.clone()); + } + InstructionValue::TypeCastExpression { value: val, .. } => { + result.push(val.clone()); + } + InstructionValue::TemplateLiteral { subexprs, .. } => { + for subexpr in subexprs { + result.push(subexpr.clone()); + } + } + InstructionValue::Await { value: val, .. } => { + result.push(val.clone()); + } + InstructionValue::GetIterator { collection, .. } => { + result.push(collection.clone()); + } + InstructionValue::IteratorNext { + iterator, + collection, + .. + } => { + result.push(iterator.clone()); + result.push(collection.clone()); + } + InstructionValue::NextPropertyOf { value: val, .. } => { + result.push(val.clone()); + } + InstructionValue::PostfixUpdate { value: val, .. } + | InstructionValue::PrefixUpdate { value: val, .. } => { + result.push(val.clone()); + } + InstructionValue::StartMemoize { deps, .. } => { + if let Some(deps) = deps { + for dep in deps { + if let ManualMemoDependencyRoot::NamedLocal { value, .. 
} = &dep.root { + result.push(value.clone()); + } + } + } + } + InstructionValue::FinishMemoize { decl, .. } => { + result.push(decl.clone()); + } + InstructionValue::Debugger { .. } + | InstructionValue::RegExpLiteral { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::UnsupportedNode { .. } + | InstructionValue::Primitive { .. } + | InstructionValue::JSXText { .. } => { + // no operands + } + } + result +} + +/// Yields each arg's place. +/// Equivalent to TS `eachCallArgument`. +pub fn each_call_argument(args: &[PlaceOrSpread]) -> Vec { + let mut result = Vec::new(); + for arg in args { + match arg { + PlaceOrSpread::Place(place) => { + result.push(place.clone()); + } + PlaceOrSpread::Spread(spread) => { + result.push(spread.place.clone()); + } + } + } + result +} + +/// Yields places from array/object patterns. +/// Equivalent to TS `eachPatternOperand`. +pub fn each_pattern_operand(pattern: &Pattern) -> Vec { + let mut result = Vec::new(); + match pattern { + Pattern::Array(arr) => { + for item in &arr.items { + match item { + ArrayPatternElement::Place(place) => { + result.push(place.clone()); + } + ArrayPatternElement::Spread(spread) => { + result.push(spread.place.clone()); + } + ArrayPatternElement::Hole => {} + } + } + } + Pattern::Object(obj) => { + for property in &obj.properties { + match property { + ObjectPropertyOrSpread::Property(prop) => { + result.push(prop.place.clone()); + } + ObjectPropertyOrSpread::Spread(spread) => { + result.push(spread.place.clone()); + } + } + } + } + } + result +} + +/// Returns true if the pattern contains a spread element. +/// Equivalent to TS `doesPatternContainSpreadElement`. 
+pub fn does_pattern_contain_spread_element(pattern: &Pattern) -> bool { + match pattern { + Pattern::Array(arr) => { + for item in &arr.items { + if matches!(item, ArrayPatternElement::Spread(_)) { + return true; + } + } + } + Pattern::Object(obj) => { + for property in &obj.properties { + if matches!(property, ObjectPropertyOrSpread::Spread(_)) { + return true; + } + } + } + } + false +} + +/// Yields successor block IDs (NOT fallthroughs, this is intentional). +/// Equivalent to TS `eachTerminalSuccessor`. +pub fn each_terminal_successor(terminal: &Terminal) -> Vec { + let mut result = Vec::new(); + match terminal { + Terminal::Goto { block, .. } => { + result.push(*block); + } + Terminal::If { + consequent, + alternate, + .. + } => { + result.push(*consequent); + result.push(*alternate); + } + Terminal::Branch { + consequent, + alternate, + .. + } => { + result.push(*consequent); + result.push(*alternate); + } + Terminal::Switch { cases, .. } => { + for case in cases { + result.push(case.block); + } + } + Terminal::Optional { test, .. } + | Terminal::Ternary { test, .. } + | Terminal::Logical { test, .. } => { + result.push(*test); + } + Terminal::Return { .. } => {} + Terminal::Throw { .. } => {} + Terminal::DoWhile { loop_block, .. } => { + result.push(*loop_block); + } + Terminal::While { test, .. } => { + result.push(*test); + } + Terminal::For { init, .. } => { + result.push(*init); + } + Terminal::ForOf { init, .. } => { + result.push(*init); + } + Terminal::ForIn { init, .. } => { + result.push(*init); + } + Terminal::Label { block, .. } => { + result.push(*block); + } + Terminal::Sequence { block, .. } => { + result.push(*block); + } + Terminal::MaybeThrow { + continuation, + handler, + .. + } => { + result.push(*continuation); + if let Some(handler) = handler { + result.push(*handler); + } + } + Terminal::Try { block, .. } => { + result.push(*block); + } + Terminal::Scope { block, .. } | Terminal::PrunedScope { block, .. 
} => { + result.push(*block); + } + Terminal::Unreachable { .. } | Terminal::Unsupported { .. } => {} + } + result +} + +/// Yields places used by terminal. +/// Equivalent to TS `eachTerminalOperand`. +pub fn each_terminal_operand(terminal: &Terminal) -> Vec { + let mut result = Vec::new(); + match terminal { + Terminal::If { test, .. } => { + result.push(test.clone()); + } + Terminal::Branch { test, .. } => { + result.push(test.clone()); + } + Terminal::Switch { test, cases, .. } => { + result.push(test.clone()); + for case in cases { + if let Some(test) = &case.test { + result.push(test.clone()); + } + } + } + Terminal::Return { value, .. } | Terminal::Throw { value, .. } => { + result.push(value.clone()); + } + Terminal::Try { + handler_binding, .. + } => { + if let Some(binding) = handler_binding { + result.push(binding.clone()); + } + } + Terminal::MaybeThrow { .. } + | Terminal::Sequence { .. } + | Terminal::Label { .. } + | Terminal::Optional { .. } + | Terminal::Ternary { .. } + | Terminal::Logical { .. } + | Terminal::DoWhile { .. } + | Terminal::While { .. } + | Terminal::For { .. } + | Terminal::ForOf { .. } + | Terminal::ForIn { .. } + | Terminal::Goto { .. } + | Terminal::Unreachable { .. } + | Terminal::Unsupported { .. } + | Terminal::Scope { .. } + | Terminal::PrunedScope { .. } => { + // no-op + } + } + result +} + +// ============================================================================= +// Mapping functions (mutate in place) +// ============================================================================= + +/// Maps the instruction's lvalue and value's lvalues. +/// Equivalent to TS `mapInstructionLValues`. +pub fn map_instruction_lvalues(instr: &mut Instruction, f: &mut impl FnMut(Place) -> Place) { + match &mut instr.value { + InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::StoreLocal { lvalue, .. } + | InstructionValue::DeclareContext { lvalue, .. } + | InstructionValue::StoreContext { lvalue, .. 
} => { + lvalue.place = f(lvalue.place.clone()); + } + InstructionValue::Destructure { lvalue, .. } => { + map_pattern_operands(&mut lvalue.pattern, f); + } + InstructionValue::PostfixUpdate { lvalue, .. } + | InstructionValue::PrefixUpdate { lvalue, .. } => { + *lvalue = f(lvalue.clone()); + } + _ => {} + } + instr.lvalue = f(instr.lvalue.clone()); +} + +/// Maps operands of an instruction. +/// Equivalent to TS `mapInstructionOperands`. +pub fn map_instruction_operands( + instr: &mut Instruction, + env: &mut Environment, + f: &mut impl FnMut(Place) -> Place, +) { + map_instruction_value_operands(&mut instr.value, env, f); +} + +/// Maps operand places in an InstructionValue. +/// Equivalent to TS `mapInstructionValueOperands`. +pub fn map_instruction_value_operands( + value: &mut InstructionValue, + env: &mut Environment, + f: &mut impl FnMut(Place) -> Place, +) { + match value { + InstructionValue::BinaryExpression { left, right, .. } => { + *left = f(left.clone()); + *right = f(right.clone()); + } + InstructionValue::PropertyLoad { object, .. } => { + *object = f(object.clone()); + } + InstructionValue::PropertyDelete { object, .. } => { + *object = f(object.clone()); + } + InstructionValue::PropertyStore { + object, value: val, .. + } => { + *object = f(object.clone()); + *val = f(val.clone()); + } + InstructionValue::ComputedLoad { + object, property, .. + } => { + *object = f(object.clone()); + *property = f(property.clone()); + } + InstructionValue::ComputedDelete { + object, property, .. + } => { + *object = f(object.clone()); + *property = f(property.clone()); + } + InstructionValue::ComputedStore { + object, + property, + value: val, + .. + } => { + *object = f(object.clone()); + *property = f(property.clone()); + *val = f(val.clone()); + } + InstructionValue::DeclareContext { .. } | InstructionValue::DeclareLocal { .. } => { + // no operands + } + InstructionValue::LoadLocal { place, .. } | InstructionValue::LoadContext { place, .. 
} => { + *place = f(place.clone()); + } + InstructionValue::StoreLocal { value: val, .. } => { + *val = f(val.clone()); + } + InstructionValue::StoreContext { + lvalue, value: val, .. + } => { + lvalue.place = f(lvalue.place.clone()); + *val = f(val.clone()); + } + InstructionValue::StoreGlobal { value: val, .. } => { + *val = f(val.clone()); + } + InstructionValue::Destructure { value: val, .. } => { + *val = f(val.clone()); + } + InstructionValue::NewExpression { callee, args, .. } + | InstructionValue::CallExpression { callee, args, .. } => { + *callee = f(callee.clone()); + map_call_arguments(args, f); + } + InstructionValue::MethodCall { + receiver, + property, + args, + .. + } => { + *receiver = f(receiver.clone()); + *property = f(property.clone()); + map_call_arguments(args, f); + } + InstructionValue::UnaryExpression { value: val, .. } => { + *val = f(val.clone()); + } + InstructionValue::JsxExpression { + tag, + props, + children, + .. + } => { + if let JsxTag::Place(place) = tag { + *place = f(place.clone()); + } + for attribute in props.iter_mut() { + match attribute { + JsxAttribute::Attribute { place, .. } => { + *place = f(place.clone()); + } + JsxAttribute::SpreadAttribute { argument, .. } => { + *argument = f(argument.clone()); + } + } + } + if let Some(children) = children { + *children = children.iter().map(|p| f(p.clone())).collect(); + } + } + InstructionValue::ObjectExpression { properties, .. } => { + for property in properties.iter_mut() { + match property { + ObjectPropertyOrSpread::Property(prop) => { + if let ObjectPropertyKey::Computed { name } = &mut prop.key { + *name = f(name.clone()); + } + prop.place = f(prop.place.clone()); + } + ObjectPropertyOrSpread::Spread(spread) => { + spread.place = f(spread.place.clone()); + } + } + } + } + InstructionValue::ArrayExpression { elements, .. 
} => { + *elements = elements + .iter() + .map(|element| match element { + ArrayElement::Place(place) => ArrayElement::Place(f(place.clone())), + ArrayElement::Spread(spread) => { + let mut spread = spread.clone(); + spread.place = f(spread.place.clone()); + ArrayElement::Spread(spread) + } + ArrayElement::Hole => ArrayElement::Hole, + }) + .collect(); + } + InstructionValue::JsxFragment { children, .. } => { + *children = children.iter().map(|e| f(e.clone())).collect(); + } + InstructionValue::ObjectMethod { lowered_func, .. } + | InstructionValue::FunctionExpression { lowered_func, .. } => { + let func = &mut env.functions[lowered_func.func.0 as usize]; + func.context = func.context.iter().map(|d| f(d.clone())).collect(); + } + InstructionValue::TaggedTemplateExpression { tag, .. } => { + *tag = f(tag.clone()); + } + InstructionValue::TypeCastExpression { value: val, .. } => { + *val = f(val.clone()); + } + InstructionValue::TemplateLiteral { subexprs, .. } => { + *subexprs = subexprs.iter().map(|s| f(s.clone())).collect(); + } + InstructionValue::Await { value: val, .. } => { + *val = f(val.clone()); + } + InstructionValue::GetIterator { collection, .. } => { + *collection = f(collection.clone()); + } + InstructionValue::IteratorNext { + iterator, + collection, + .. + } => { + *iterator = f(iterator.clone()); + *collection = f(collection.clone()); + } + InstructionValue::NextPropertyOf { value: val, .. } => { + *val = f(val.clone()); + } + InstructionValue::PostfixUpdate { value: val, .. } + | InstructionValue::PrefixUpdate { value: val, .. } => { + *val = f(val.clone()); + } + InstructionValue::StartMemoize { deps, .. } => { + if let Some(deps) = deps { + for dep in deps.iter_mut() { + if let ManualMemoDependencyRoot::NamedLocal { value, .. } = &mut dep.root { + *value = f(value.clone()); + } + } + } + } + InstructionValue::FinishMemoize { decl, .. } => { + *decl = f(decl.clone()); + } + InstructionValue::Debugger { .. } + | InstructionValue::RegExpLiteral { .. 
} + | InstructionValue::MetaProperty { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::UnsupportedNode { .. } + | InstructionValue::Primitive { .. } + | InstructionValue::JSXText { .. } => { + // no operands + } + } +} + +/// Maps call arguments in place. +/// Equivalent to TS `mapCallArguments`. +pub fn map_call_arguments(args: &mut Vec, f: &mut impl FnMut(Place) -> Place) { + for arg in args.iter_mut() { + match arg { + PlaceOrSpread::Place(place) => { + *place = f(place.clone()); + } + PlaceOrSpread::Spread(spread) => { + spread.place = f(spread.place.clone()); + } + } + } +} + +/// Maps pattern operands in place. +/// Equivalent to TS `mapPatternOperands`. +pub fn map_pattern_operands(pattern: &mut Pattern, f: &mut impl FnMut(Place) -> Place) { + match pattern { + Pattern::Array(arr) => { + arr.items = arr + .items + .iter() + .map(|item| match item { + ArrayPatternElement::Place(place) => { + ArrayPatternElement::Place(f(place.clone())) + } + ArrayPatternElement::Spread(spread) => { + let mut spread = spread.clone(); + spread.place = f(spread.place.clone()); + ArrayPatternElement::Spread(spread) + } + ArrayPatternElement::Hole => ArrayPatternElement::Hole, + }) + .collect(); + } + Pattern::Object(obj) => { + for property in obj.properties.iter_mut() { + match property { + ObjectPropertyOrSpread::Property(prop) => { + prop.place = f(prop.place.clone()); + } + ObjectPropertyOrSpread::Spread(spread) => { + spread.place = f(spread.place.clone()); + } + } + } + } + } +} + +/// Maps a terminal node's block assignments in place. +/// Equivalent to TS `mapTerminalSuccessors` — but mutates in place instead of +/// returning a new terminal. +pub fn map_terminal_successors(terminal: &mut Terminal, f: &mut impl FnMut(BlockId) -> BlockId) { + match terminal { + Terminal::Goto { block, .. } => { + *block = f(*block); + } + Terminal::If { + consequent, + alternate, + fallthrough, + .. 
+ } => { + *consequent = f(*consequent); + *alternate = f(*alternate); + *fallthrough = f(*fallthrough); + } + Terminal::Branch { + consequent, + alternate, + fallthrough, + .. + } => { + *consequent = f(*consequent); + *alternate = f(*alternate); + *fallthrough = f(*fallthrough); + } + Terminal::Switch { + cases, fallthrough, .. + } => { + for case in cases.iter_mut() { + case.block = f(case.block); + } + *fallthrough = f(*fallthrough); + } + Terminal::Logical { + test, fallthrough, .. + } => { + *test = f(*test); + *fallthrough = f(*fallthrough); + } + Terminal::Ternary { + test, fallthrough, .. + } => { + *test = f(*test); + *fallthrough = f(*fallthrough); + } + Terminal::Optional { + test, fallthrough, .. + } => { + *test = f(*test); + *fallthrough = f(*fallthrough); + } + Terminal::Return { .. } => {} + Terminal::Throw { .. } => {} + Terminal::DoWhile { + loop_block, + test, + fallthrough, + .. + } => { + *loop_block = f(*loop_block); + *test = f(*test); + *fallthrough = f(*fallthrough); + } + Terminal::While { + test, + loop_block, + fallthrough, + .. + } => { + *test = f(*test); + *loop_block = f(*loop_block); + *fallthrough = f(*fallthrough); + } + Terminal::For { + init, + test, + update, + loop_block, + fallthrough, + .. + } => { + *init = f(*init); + *test = f(*test); + if let Some(update) = update { + *update = f(*update); + } + *loop_block = f(*loop_block); + *fallthrough = f(*fallthrough); + } + Terminal::ForOf { + init, + test, + loop_block, + fallthrough, + .. + } => { + *init = f(*init); + *test = f(*test); + *loop_block = f(*loop_block); + *fallthrough = f(*fallthrough); + } + Terminal::ForIn { + init, + loop_block, + fallthrough, + .. + } => { + *init = f(*init); + *loop_block = f(*loop_block); + *fallthrough = f(*fallthrough); + } + Terminal::Label { + block, fallthrough, .. + } => { + *block = f(*block); + *fallthrough = f(*fallthrough); + } + Terminal::Sequence { + block, fallthrough, .. 
+ } => { + *block = f(*block); + *fallthrough = f(*fallthrough); + } + Terminal::MaybeThrow { + continuation, + handler, + .. + } => { + *continuation = f(*continuation); + if let Some(handler) = handler { + *handler = f(*handler); + } + } + Terminal::Try { + block, + handler, + fallthrough, + .. + } => { + *block = f(*block); + *handler = f(*handler); + *fallthrough = f(*fallthrough); + } + Terminal::Scope { + block, fallthrough, .. + } + | Terminal::PrunedScope { + block, fallthrough, .. + } => { + *block = f(*block); + *fallthrough = f(*fallthrough); + } + Terminal::Unreachable { .. } | Terminal::Unsupported { .. } => {} + } +} + +/// Maps a terminal node's operand places in place. +/// Equivalent to TS `mapTerminalOperands`. +pub fn map_terminal_operands(terminal: &mut Terminal, f: &mut impl FnMut(Place) -> Place) { + match terminal { + Terminal::If { test, .. } => { + *test = f(test.clone()); + } + Terminal::Branch { test, .. } => { + *test = f(test.clone()); + } + Terminal::Switch { test, cases, .. } => { + *test = f(test.clone()); + for case in cases.iter_mut() { + if let Some(t) = &mut case.test { + *t = f(t.clone()); + } + } + } + Terminal::Return { value, .. } | Terminal::Throw { value, .. } => { + *value = f(value.clone()); + } + Terminal::Try { + handler_binding, .. + } => { + if let Some(binding) = handler_binding { + *binding = f(binding.clone()); + } + } + Terminal::MaybeThrow { .. } + | Terminal::Sequence { .. } + | Terminal::Label { .. } + | Terminal::Optional { .. } + | Terminal::Ternary { .. } + | Terminal::Logical { .. } + | Terminal::DoWhile { .. } + | Terminal::While { .. } + | Terminal::For { .. } + | Terminal::ForOf { .. } + | Terminal::ForIn { .. } + | Terminal::Goto { .. } + | Terminal::Unreachable { .. } + | Terminal::Unsupported { .. } + | Terminal::Scope { .. } + | Terminal::PrunedScope { .. } => { + // no-op + } + } +} + +/// Yields ALL block IDs referenced by a terminal (successors + fallthroughs + +/// internal blocks). 
Unlike `each_terminal_successor` which yields only +/// standard control flow successors, this function yields every block ID that +/// `map_terminal_successors` would visit. +pub fn each_terminal_all_successors(terminal: &Terminal) -> Vec { + let mut result = Vec::new(); + match terminal { + Terminal::Goto { block, .. } => { + result.push(*block); + } + Terminal::If { + consequent, + alternate, + fallthrough, + .. + } => { + result.push(*consequent); + result.push(*alternate); + result.push(*fallthrough); + } + Terminal::Branch { + consequent, + alternate, + fallthrough, + .. + } => { + result.push(*consequent); + result.push(*alternate); + result.push(*fallthrough); + } + Terminal::Switch { + cases, fallthrough, .. + } => { + for case in cases { + result.push(case.block); + } + result.push(*fallthrough); + } + Terminal::Logical { + test, fallthrough, .. + } + | Terminal::Ternary { + test, fallthrough, .. + } + | Terminal::Optional { + test, fallthrough, .. + } => { + result.push(*test); + result.push(*fallthrough); + } + Terminal::Return { .. } | Terminal::Throw { .. } => {} + Terminal::DoWhile { + loop_block, + test, + fallthrough, + .. + } => { + result.push(*loop_block); + result.push(*test); + result.push(*fallthrough); + } + Terminal::While { + test, + loop_block, + fallthrough, + .. + } => { + result.push(*test); + result.push(*loop_block); + result.push(*fallthrough); + } + Terminal::For { + init, + test, + update, + loop_block, + fallthrough, + .. + } => { + result.push(*init); + result.push(*test); + if let Some(update) = update { + result.push(*update); + } + result.push(*loop_block); + result.push(*fallthrough); + } + Terminal::ForOf { + init, + test, + loop_block, + fallthrough, + .. + } => { + result.push(*init); + result.push(*test); + result.push(*loop_block); + result.push(*fallthrough); + } + Terminal::ForIn { + init, + loop_block, + fallthrough, + .. 
+ } => { + result.push(*init); + result.push(*loop_block); + result.push(*fallthrough); + } + Terminal::Label { + block, fallthrough, .. + } + | Terminal::Sequence { + block, fallthrough, .. + } => { + result.push(*block); + result.push(*fallthrough); + } + Terminal::MaybeThrow { + continuation, + handler, + .. + } => { + result.push(*continuation); + if let Some(handler) = handler { + result.push(*handler); + } + } + Terminal::Try { + block, + handler, + fallthrough, + .. + } => { + result.push(*block); + result.push(*handler); + result.push(*fallthrough); + } + Terminal::Scope { + block, fallthrough, .. + } + | Terminal::PrunedScope { + block, fallthrough, .. + } => { + result.push(*block); + result.push(*fallthrough); + } + Terminal::Unreachable { .. } | Terminal::Unsupported { .. } => {} + } + result +} + +// ============================================================================= +// Terminal fallthrough functions +// ============================================================================= + +/// Returns the fallthrough block ID for terminals that have one. +/// Equivalent to TS `terminalFallthrough`. +pub fn terminal_fallthrough(terminal: &Terminal) -> Option { + match terminal { + // These terminals do NOT have a fallthrough + Terminal::MaybeThrow { .. } + | Terminal::Goto { .. } + | Terminal::Return { .. } + | Terminal::Throw { .. } + | Terminal::Unreachable { .. } + | Terminal::Unsupported { .. } => None, + + // These terminals DO have a fallthrough + Terminal::Branch { fallthrough, .. } + | Terminal::Try { fallthrough, .. } + | Terminal::DoWhile { fallthrough, .. } + | Terminal::ForOf { fallthrough, .. } + | Terminal::ForIn { fallthrough, .. } + | Terminal::For { fallthrough, .. } + | Terminal::If { fallthrough, .. } + | Terminal::Label { fallthrough, .. } + | Terminal::Logical { fallthrough, .. } + | Terminal::Optional { fallthrough, .. } + | Terminal::Sequence { fallthrough, .. } + | Terminal::Switch { fallthrough, .. 
} + | Terminal::Ternary { fallthrough, .. } + | Terminal::While { fallthrough, .. } + | Terminal::Scope { fallthrough, .. } + | Terminal::PrunedScope { fallthrough, .. } => Some(*fallthrough), + } +} + +/// Returns true if the terminal has a fallthrough block. +/// Equivalent to TS `terminalHasFallthrough`. +pub fn terminal_has_fallthrough(terminal: &Terminal) -> bool { + terminal_fallthrough(terminal).is_some() +} + +// ============================================================================= +// ScopeBlockTraversal +// ============================================================================= + +/// Block info entry for ScopeBlockTraversal. +#[derive(Debug, Clone)] +pub enum ScopeBlockInfo { + Begin { + scope: ScopeId, + pruned: bool, + fallthrough: BlockId, + }, + End { + scope: ScopeId, + pruned: bool, + }, +} + +/// Helper struct for traversing scope blocks in HIR-form. +/// Equivalent to TS `ScopeBlockTraversal` class. +pub struct ScopeBlockTraversal { + /// Live stack of active scopes + active_scopes: Vec, + /// Map from block ID to scope block info + pub block_infos: HashMap, +} + +impl ScopeBlockTraversal { + pub fn new() -> Self { + ScopeBlockTraversal { + active_scopes: Vec::new(), + block_infos: HashMap::new(), + } + } + + /// Record scope information for a block's terminal. + /// Equivalent to TS `recordScopes`. + pub fn record_scopes(&mut self, block: &BasicBlock) { + if let Some(block_info) = self.block_infos.get(&block.id) { + match block_info { + ScopeBlockInfo::Begin { scope, .. } => { + self.active_scopes.push(*scope); + } + ScopeBlockInfo::End { scope, .. } => { + let top = self.active_scopes.last(); + assert_eq!( + Some(scope), + top, + "Expected traversed block fallthrough to match top-most active scope" + ); + self.active_scopes.pop(); + } + } + } + + match &block.terminal { + Terminal::Scope { + block: scope_block, + fallthrough, + scope, + .. 
+ } => { + assert!( + !self.block_infos.contains_key(scope_block) + && !self.block_infos.contains_key(fallthrough), + "Expected unique scope blocks and fallthroughs" + ); + self.block_infos.insert( + *scope_block, + ScopeBlockInfo::Begin { + scope: *scope, + pruned: false, + fallthrough: *fallthrough, + }, + ); + self.block_infos.insert( + *fallthrough, + ScopeBlockInfo::End { + scope: *scope, + pruned: false, + }, + ); + } + Terminal::PrunedScope { + block: scope_block, + fallthrough, + scope, + .. + } => { + assert!( + !self.block_infos.contains_key(scope_block) + && !self.block_infos.contains_key(fallthrough), + "Expected unique scope blocks and fallthroughs" + ); + self.block_infos.insert( + *scope_block, + ScopeBlockInfo::Begin { + scope: *scope, + pruned: true, + fallthrough: *fallthrough, + }, + ); + self.block_infos.insert( + *fallthrough, + ScopeBlockInfo::End { + scope: *scope, + pruned: true, + }, + ); + } + _ => {} + } + } + + /// Returns true if the given scope is currently 'active', i.e. if the scope + /// start block but not the scope fallthrough has been recorded. + pub fn is_scope_active(&self, scope_id: ScopeId) -> bool { + self.active_scopes.contains(&scope_id) + } + + /// The current, innermost active scope. + pub fn current_scope(&self) -> Option { + self.active_scopes.last().copied() + } +} + +impl Default for ScopeBlockTraversal { + fn default() -> Self { + Self::new() + } +} + +// ============================================================================= +// Convenience wrappers: extract IdentifierIds from Places +// ============================================================================= + +/// Collect all lvalue IdentifierIds from an instruction. +/// Convenience wrapper around `each_instruction_lvalue` that maps to ids. 
+pub fn each_instruction_lvalue_ids(instr: &Instruction) -> Vec { + each_instruction_lvalue(instr) + .into_iter() + .map(|p| p.identifier) + .collect() +} + +/// Collect all operand IdentifierIds from an instruction. +/// Convenience wrapper around `each_instruction_operand` that maps to ids. +pub fn each_instruction_operand_ids(instr: &Instruction, env: &Environment) -> Vec { + each_instruction_operand(instr, env) + .into_iter() + .map(|p| p.identifier) + .collect() +} + +/// Collect all operand IdentifierIds from an instruction value. +/// Convenience wrapper around `each_instruction_value_operand` that maps to +/// ids. +pub fn each_instruction_value_operand_ids( + value: &InstructionValue, + env: &Environment, +) -> Vec { + each_instruction_value_operand(value, env) + .into_iter() + .map(|p| p.identifier) + .collect() +} + +/// Collect all operand IdentifierIds from a terminal. +/// Convenience wrapper around `each_terminal_operand` that maps to ids. +pub fn each_terminal_operand_ids(terminal: &Terminal) -> Vec { + each_terminal_operand(terminal) + .into_iter() + .map(|p| p.identifier) + .collect() +} + +/// Collect all IdentifierIds from a pattern. +/// Convenience wrapper around `each_pattern_operand` that maps to ids. +pub fn each_pattern_operand_ids(pattern: &Pattern) -> Vec { + each_pattern_operand(pattern) + .into_iter() + .map(|p| p.identifier) + .collect() +} + +// ============================================================================= +// In-place mutation variants (f(&mut Place) callbacks) +// ============================================================================= +// +// These variants use `f(&mut Place)` instead of `f(Place) -> Place`, which is +// more natural for Rust in-place mutation patterns. They do NOT handle +// FunctionExpression/ObjectMethod context (since that requires env access). 
+// Callers that need to process inner function context should handle it +// separately, e.g.: +// +// for_each_instruction_value_operand_mut(&mut instr.value, &mut |place| { ... +// }); if let InstructionValue::FunctionExpression { lowered_func, .. } +// | InstructionValue::ObjectMethod { lowered_func, .. } = &mut +// instr.value { let func = &mut env.functions[lowered_func.func.0 as +// usize]; for ctx in func.context.iter_mut() { ... } +// } +// + +/// In-place mutation of all operand places in an InstructionValue. +/// Does NOT handle FunctionExpression/ObjectMethod context — callers handle +/// those separately. +pub fn for_each_instruction_value_operand_mut( + value: &mut InstructionValue, + f: &mut impl FnMut(&mut Place), +) { + match value { + InstructionValue::BinaryExpression { left, right, .. } => { + f(left); + f(right); + } + InstructionValue::PropertyLoad { object, .. } + | InstructionValue::PropertyDelete { object, .. } => { + f(object); + } + InstructionValue::PropertyStore { + object, value: val, .. + } => { + f(object); + f(val); + } + InstructionValue::ComputedLoad { + object, property, .. + } + | InstructionValue::ComputedDelete { + object, property, .. + } => { + f(object); + f(property); + } + InstructionValue::ComputedStore { + object, + property, + value: val, + .. + } => { + f(object); + f(property); + f(val); + } + InstructionValue::DeclareContext { .. } | InstructionValue::DeclareLocal { .. } => {} + InstructionValue::LoadLocal { place, .. } | InstructionValue::LoadContext { place, .. } => { + f(place); + } + InstructionValue::StoreLocal { value: val, .. } => { + f(val); + } + InstructionValue::StoreContext { + lvalue, value: val, .. + } => { + f(&mut lvalue.place); + f(val); + } + InstructionValue::StoreGlobal { value: val, .. } => { + f(val); + } + InstructionValue::Destructure { value: val, .. } => { + f(val); + } + InstructionValue::NewExpression { callee, args, .. } + | InstructionValue::CallExpression { callee, args, .. 
} => { + f(callee); + for_each_call_argument_mut(args, f); + } + InstructionValue::MethodCall { + receiver, + property, + args, + .. + } => { + f(receiver); + f(property); + for_each_call_argument_mut(args, f); + } + InstructionValue::UnaryExpression { value: val, .. } => { + f(val); + } + InstructionValue::JsxExpression { + tag, + props, + children, + .. + } => { + if let JsxTag::Place(place) = tag { + f(place); + } + for attribute in props.iter_mut() { + match attribute { + JsxAttribute::Attribute { place, .. } => f(place), + JsxAttribute::SpreadAttribute { argument, .. } => f(argument), + } + } + if let Some(children) = children { + for child in children.iter_mut() { + f(child); + } + } + } + InstructionValue::ObjectExpression { properties, .. } => { + for property in properties.iter_mut() { + match property { + ObjectPropertyOrSpread::Property(prop) => { + if let ObjectPropertyKey::Computed { name } = &mut prop.key { + f(name); + } + f(&mut prop.place); + } + ObjectPropertyOrSpread::Spread(spread) => { + f(&mut spread.place); + } + } + } + } + InstructionValue::ArrayExpression { elements, .. } => { + for elem in elements.iter_mut() { + match elem { + ArrayElement::Place(p) => f(p), + ArrayElement::Spread(s) => f(&mut s.place), + ArrayElement::Hole => {} + } + } + } + InstructionValue::JsxFragment { children, .. } => { + for child in children.iter_mut() { + f(child); + } + } + InstructionValue::FunctionExpression { .. } | InstructionValue::ObjectMethod { .. } => { + // Context places require env access — callers handle separately. + } + InstructionValue::TaggedTemplateExpression { tag, .. } => { + f(tag); + } + InstructionValue::TypeCastExpression { value: val, .. } => { + f(val); + } + InstructionValue::TemplateLiteral { subexprs, .. } => { + for expr in subexprs.iter_mut() { + f(expr); + } + } + InstructionValue::Await { value: val, .. } => { + f(val); + } + InstructionValue::GetIterator { collection, .. 
} => { + f(collection); + } + InstructionValue::IteratorNext { + iterator, + collection, + .. + } => { + f(iterator); + f(collection); + } + InstructionValue::NextPropertyOf { value: val, .. } => { + f(val); + } + InstructionValue::PostfixUpdate { value: val, .. } + | InstructionValue::PrefixUpdate { value: val, .. } => { + f(val); + } + InstructionValue::StartMemoize { deps, .. } => { + if let Some(deps) = deps { + for dep in deps.iter_mut() { + if let ManualMemoDependencyRoot::NamedLocal { value, .. } = &mut dep.root { + f(value); + } + } + } + } + InstructionValue::FinishMemoize { decl, .. } => { + f(decl); + } + InstructionValue::Debugger { .. } + | InstructionValue::RegExpLiteral { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::UnsupportedNode { .. } + | InstructionValue::Primitive { .. } + | InstructionValue::JSXText { .. } => {} + } +} + +/// In-place mutation of call arguments. +pub fn for_each_call_argument_mut(args: &mut [PlaceOrSpread], f: &mut impl FnMut(&mut Place)) { + for arg in args.iter_mut() { + match arg { + PlaceOrSpread::Place(place) => f(place), + PlaceOrSpread::Spread(spread) => f(&mut spread.place), + } + } +} + +/// In-place mutation of an InstructionValue's lvalues (DeclareLocal, +/// StoreLocal, DeclareContext, StoreContext, Destructure, PostfixUpdate, +/// PrefixUpdate). Does NOT include the instruction's top-level lvalue — use +/// `for_each_instruction_lvalue_mut` for that. +pub fn for_each_instruction_value_lvalue_mut( + value: &mut InstructionValue, + f: &mut impl FnMut(&mut Place), +) { + match value { + InstructionValue::DeclareContext { lvalue, .. } + | InstructionValue::StoreContext { lvalue, .. } + | InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::StoreLocal { lvalue, .. } => { + f(&mut lvalue.place); + } + InstructionValue::Destructure { lvalue, .. 
} => { + for_each_pattern_operand_mut(&mut lvalue.pattern, f); + } + InstructionValue::PostfixUpdate { lvalue, .. } + | InstructionValue::PrefixUpdate { lvalue, .. } => { + f(lvalue); + } + _ => {} + } +} + +/// In-place mutation of the instruction's lvalue and value's lvalues. +/// Matches the same variants as TS `mapInstructionLValues` (skips +/// DeclareContext/StoreContext). +pub fn for_each_instruction_lvalue_mut(instr: &mut Instruction, f: &mut impl FnMut(&mut Place)) { + match &mut instr.value { + InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::StoreLocal { lvalue, .. } => { + f(&mut lvalue.place); + } + InstructionValue::Destructure { lvalue, .. } => { + for_each_pattern_operand_mut(&mut lvalue.pattern, f); + } + InstructionValue::PostfixUpdate { lvalue, .. } + | InstructionValue::PrefixUpdate { lvalue, .. } => { + f(lvalue); + } + _ => {} + } + f(&mut instr.lvalue); +} + +/// In-place mutation of pattern operands. +pub fn for_each_pattern_operand_mut(pattern: &mut Pattern, f: &mut impl FnMut(&mut Place)) { + match pattern { + Pattern::Array(arr) => { + for item in arr.items.iter_mut() { + match item { + ArrayPatternElement::Place(p) => f(p), + ArrayPatternElement::Spread(s) => f(&mut s.place), + ArrayPatternElement::Hole => {} + } + } + } + Pattern::Object(obj) => { + for property in obj.properties.iter_mut() { + match property { + ObjectPropertyOrSpread::Property(prop) => f(&mut prop.place), + ObjectPropertyOrSpread::Spread(spread) => f(&mut spread.place), + } + } + } + } +} + +/// In-place mutation of terminal operand places. +pub fn for_each_terminal_operand_mut(terminal: &mut Terminal, f: &mut impl FnMut(&mut Place)) { + match terminal { + Terminal::If { test, .. } | Terminal::Branch { test, .. } => { + f(test); + } + Terminal::Switch { test, cases, .. } => { + f(test); + for case in cases.iter_mut() { + if let Some(t) = &mut case.test { + f(t); + } + } + } + Terminal::Return { value, .. } | Terminal::Throw { value, .. 
} => { + f(value); + } + Terminal::Try { + handler_binding, .. + } => { + if let Some(binding) = handler_binding { + f(binding); + } + } + Terminal::MaybeThrow { .. } + | Terminal::Sequence { .. } + | Terminal::Label { .. } + | Terminal::Optional { .. } + | Terminal::Ternary { .. } + | Terminal::Logical { .. } + | Terminal::DoWhile { .. } + | Terminal::While { .. } + | Terminal::For { .. } + | Terminal::ForOf { .. } + | Terminal::ForIn { .. } + | Terminal::Goto { .. } + | Terminal::Unreachable { .. } + | Terminal::Unsupported { .. } + | Terminal::Scope { .. } + | Terminal::PrunedScope { .. } => {} + } +} diff --git a/crates/react_compiler_inference/Cargo.toml b/crates/react_compiler_inference/Cargo.toml new file mode 100644 index 000000000000..b72f13cf6188 --- /dev/null +++ b/crates/react_compiler_inference/Cargo.toml @@ -0,0 +1,16 @@ +[package] +description = "Vendored React Compiler inference passes from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_inference" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +react_compiler_hir = { path = "../react_compiler_hir" } +react_compiler_diagnostics = { path = "../react_compiler_diagnostics" } +react_compiler_lowering = { path = "../react_compiler_lowering" } +react_compiler_optimization = { path = "../react_compiler_optimization" } +react_compiler_ssa = { path = "../react_compiler_ssa" } +react_compiler_utils = { path = "../react_compiler_utils" } +indexmap = { workspace = true } diff --git a/crates/react_compiler_inference/src/align_method_call_scopes.rs b/crates/react_compiler_inference/src/align_method_call_scopes.rs new file mode 100644 index 000000000000..764f6dda9083 --- /dev/null +++ b/crates/react_compiler_inference/src/align_method_call_scopes.rs @@ -0,0 +1,118 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Ensures that method call instructions have scopes such that either: +//! - Both the MethodCall and its property have the same scope +//! - OR neither has a scope +//! +//! Ported from TypeScript `src/ReactiveScopes/AlignMethodCallScopes.ts`. + +use std::collections::HashMap; + +use react_compiler_hir::{ + environment::Environment, EvaluationOrder, HirFunction, IdentifierId, InstructionValue, ScopeId, +}; +use react_compiler_utils::DisjointSet; + +// ============================================================================= +// Public API +// ============================================================================= + +/// Aligns method call scopes so that either both the MethodCall result and its +/// property operand share the same scope, or neither has a scope. +/// +/// Corresponds to TS `alignMethodCallScopes(fn: HIRFunction): void`. +pub fn align_method_call_scopes(func: &mut HirFunction, env: &mut Environment) { + // Maps an identifier to the scope it should be assigned to (or None to remove + // scope) + let mut scope_mapping: HashMap> = HashMap::new(); + let mut merged_scopes = DisjointSet::::new(); + + // Phase 1: Walk instructions and collect scope relationships + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::MethodCall { property, .. 
} => { + let lvalue_scope = env.identifiers[instr.lvalue.identifier.0 as usize].scope; + let property_scope = env.identifiers[property.identifier.0 as usize].scope; + + match (lvalue_scope, property_scope) { + (Some(lvalue_sid), Some(property_sid)) => { + // Both have a scope: merge the scopes + merged_scopes.union(&[lvalue_sid, property_sid]); + } + (Some(lvalue_sid), None) => { + // Call has a scope but not the property: + // record that this property should be in this scope + scope_mapping.insert(property.identifier, Some(lvalue_sid)); + } + (None, Some(_)) => { + // Property has a scope but call doesn't: + // this property does not need a scope + scope_mapping.insert(property.identifier, None); + } + (None, None) => { + // Neither has a scope, nothing to do + } + } + } + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + // Recurse into inner functions + let func_id = lowered_func.func; + let mut inner_func = std::mem::replace( + &mut env.functions[func_id.0 as usize], + react_compiler_ssa::enter_ssa::placeholder_function(), + ); + align_method_call_scopes(&mut inner_func, env); + env.functions[func_id.0 as usize] = inner_func; + } + _ => {} + } + } + } + + // Phase 2: Merge scope ranges for unioned scopes. + // Use a HashMap to accumulate min/max across all scopes mapping to the same + // root, matching TS behavior where root.range is updated in-place during + // iteration. 
+ let mut range_updates: HashMap = HashMap::new(); + + merged_scopes.for_each(|scope_id, root_id| { + if scope_id == root_id { + return; + } + let scope_range = env.scopes[scope_id.0 as usize].range.clone(); + let root_range = env.scopes[root_id.0 as usize].range.clone(); + + let entry = range_updates + .entry(root_id) + .or_insert_with(|| (root_range.start, root_range.end)); + entry.0 = EvaluationOrder(std::cmp::min(entry.0 .0, scope_range.start.0)); + entry.1 = EvaluationOrder(std::cmp::max(entry.1 .0, scope_range.end.0)); + }); + + for (root_id, (new_start, new_end)) in range_updates { + env.scopes[root_id.0 as usize].range.start = new_start; + env.scopes[root_id.0 as usize].range.end = new_end; + } + + // Phase 3: Apply scope mappings and merged scope reassignments + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let lvalue_id = func.instructions[instr_id.0 as usize].lvalue.identifier; + + if let Some(mapped_scope) = scope_mapping.get(&lvalue_id) { + env.identifiers[lvalue_id.0 as usize].scope = *mapped_scope; + } else if let Some(current_scope) = env.identifiers[lvalue_id.0 as usize].scope { + // TS: mergedScopes.find() returns null if not in the set + if let Some(merged) = merged_scopes.find_opt(current_scope) { + env.identifiers[lvalue_id.0 as usize].scope = Some(merged); + } + } + } + } +} diff --git a/crates/react_compiler_inference/src/align_object_method_scopes.rs b/crates/react_compiler_inference/src/align_object_method_scopes.rs new file mode 100644 index 000000000000..cfa2b71fb2b3 --- /dev/null +++ b/crates/react_compiler_inference/src/align_object_method_scopes.rs @@ -0,0 +1,153 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Aligns scopes of object method values to that of their enclosing object +//! expressions. 
To produce a well-formed JS program in Codegen, object methods +//! and object expressions must be in the same ReactiveBlock as object method +//! definitions must be inlined. +//! +//! Ported from TypeScript `src/ReactiveScopes/AlignObjectMethodScopes.ts`. + +use std::{ + cmp, + collections::{HashMap, HashSet}, +}; + +use react_compiler_hir::{ + environment::Environment, EvaluationOrder, HirFunction, IdentifierId, InstructionValue, + ObjectPropertyOrSpread, ScopeId, +}; +use react_compiler_utils::DisjointSet; + +// ============================================================================= +// findScopesToMerge +// ============================================================================= + +/// Identifies ObjectMethod lvalue identifiers and then finds ObjectExpression +/// instructions whose operands reference those methods. Returns a disjoint set +/// of scopes that must be merged. +fn find_scopes_to_merge(func: &HirFunction, env: &Environment) -> DisjointSet { + let mut object_method_decls: HashSet = HashSet::new(); + let mut merged_scopes = DisjointSet::::new(); + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::ObjectMethod { .. } => { + object_method_decls.insert(instr.lvalue.identifier); + } + InstructionValue::ObjectExpression { properties, .. } => { + for prop_or_spread in properties { + let operand_place = match prop_or_spread { + ObjectPropertyOrSpread::Property(prop) => &prop.place, + ObjectPropertyOrSpread::Spread(spread) => &spread.place, + }; + if object_method_decls.contains(&operand_place.identifier) { + let operand_scope = + env.identifiers[operand_place.identifier.0 as usize].scope; + let lvalue_scope = + env.identifiers[instr.lvalue.identifier.0 as usize].scope; + + // TS: CompilerError.invariant(operandScope != null && lvalueScope != + // null, ...) 
+ let operand_sid = operand_scope.expect( + "Internal error: Expected all ObjectExpressions and ObjectMethods \ + to have non-null scope.", + ); + let lvalue_sid = lvalue_scope.expect( + "Internal error: Expected all ObjectExpressions and ObjectMethods \ + to have non-null scope.", + ); + merged_scopes.union(&[operand_sid, lvalue_sid]); + } + } + } + _ => {} + } + } + } + + merged_scopes +} + +// ============================================================================= +// Public API +// ============================================================================= + +/// Aligns object method scopes so that ObjectMethod values and their enclosing +/// ObjectExpression share the same scope. +/// +/// Corresponds to TS `alignObjectMethodScopes(fn: HIRFunction): void`. +pub fn align_object_method_scopes(func: &mut HirFunction, env: &mut Environment) { + // Handle inner functions first (TS recurses before processing the outer + // function) + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + let func_id = lowered_func.func; + let mut inner_func = std::mem::replace( + &mut env.functions[func_id.0 as usize], + react_compiler_ssa::enter_ssa::placeholder_function(), + ); + align_object_method_scopes(&mut inner_func, env); + env.functions[func_id.0 as usize] = inner_func; + } + _ => {} + } + } + } + + let mut merged_scopes = find_scopes_to_merge(func, env); + + // Step 1: Merge affected scopes to their canonical root. + // Use a HashMap to accumulate min/max across all scopes mapping to the same + // root, matching TS behavior where root.range is updated in-place during + // iteration. 
+ let mut range_updates: HashMap = HashMap::new(); + + merged_scopes.for_each(|scope_id, root_id| { + if scope_id == root_id { + return; + } + let scope_range = env.scopes[scope_id.0 as usize].range.clone(); + let root_range = env.scopes[root_id.0 as usize].range.clone(); + + let entry = range_updates + .entry(root_id) + .or_insert_with(|| (root_range.start, root_range.end)); + entry.0 = EvaluationOrder(cmp::min(entry.0 .0, scope_range.start.0)); + entry.1 = EvaluationOrder(cmp::max(entry.1 .0, scope_range.end.0)); + }); + + for (root_id, (new_start, new_end)) in range_updates { + env.scopes[root_id.0 as usize].range.start = new_start; + env.scopes[root_id.0 as usize].range.end = new_end; + } + + // Step 2: Repoint identifiers whose scopes were merged + // Build a map from old scope -> root scope for quick lookup + let mut scope_remap: HashMap = HashMap::new(); + merged_scopes.for_each(|scope_id, root_id| { + if scope_id != root_id { + scope_remap.insert(scope_id, root_id); + } + }); + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let lvalue_id = func.instructions[instr_id.0 as usize].lvalue.identifier; + + if let Some(current_scope) = env.identifiers[lvalue_id.0 as usize].scope { + if let Some(&root) = scope_remap.get(¤t_scope) { + env.identifiers[lvalue_id.0 as usize].scope = Some(root); + } + } + } + } +} diff --git a/crates/react_compiler_inference/src/align_reactive_scopes_to_block_scopes_hir.rs b/crates/react_compiler_inference/src/align_reactive_scopes_to_block_scopes_hir.rs new file mode 100644 index 000000000000..af404a808600 --- /dev/null +++ b/crates/react_compiler_inference/src/align_reactive_scopes_to_block_scopes_hir.rs @@ -0,0 +1,321 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Aligns reactive scope boundaries to block scope boundaries in the HIR. +//! +//! 
Ported from TypeScript +//! `src/ReactiveScopes/AlignReactiveScopesToBlockScopesHIR.ts`. +//! +//! This is the 2nd of 4 passes that determine how to break a function into +//! discrete reactive scopes (independently memoizable units of code): +//! 1. InferReactiveScopeVariables (on HIR) determines operands that mutate +//! together and assigns them a unique reactive scope. +//! 2. AlignReactiveScopesToBlockScopes (this pass) aligns reactive scopes to +//! block scopes. +//! 3. MergeOverlappingReactiveScopes ensures scopes do not overlap. +//! 4. BuildReactiveBlocks groups the statements for each scope. +//! +//! Prior inference passes assign a reactive scope to each operand, but the +//! ranges of these scopes are based on specific instructions at arbitrary +//! points in the control-flow graph. However, to codegen blocks around the +//! instructions in each scope, the scopes must be aligned to block-scope +//! boundaries — we can't memoize half of a loop! + +use std::collections::{HashMap, HashSet}; + +use react_compiler_hir::{ + environment::Environment, + visitors, + visitors::{ + each_instruction_lvalue_ids, each_instruction_value_operand_ids, each_terminal_operand_ids, + }, + BlockId, BlockKind, EvaluationOrder, HirFunction, IdentifierId, MutableRange, ScopeId, + Terminal, +}; + +// ============================================================================= +// ValueBlockNode — stores the valueRange for scope alignment in value blocks +// ============================================================================= + +/// Tracks the value range for a value block. The `children` field from the TS +/// implementation is only used for debug output and is omitted here. +#[derive(Clone)] +struct ValueBlockNode { + value_range: MutableRange, +} + +/// Returns all block IDs referenced by a terminal, including both direct +/// successors and fallthrough. 
+fn all_terminal_block_ids(terminal: &Terminal) -> Vec { + visitors::each_terminal_all_successors(terminal) +} + +// ============================================================================= +// Helper: get the first EvaluationOrder in a block +// ============================================================================= + +fn block_first_id(func: &HirFunction, block_id: BlockId) -> EvaluationOrder { + let block = func.body.blocks.get(&block_id).unwrap(); + if !block.instructions.is_empty() { + func.instructions[block.instructions[0].0 as usize].id + } else { + block.terminal.evaluation_order() + } +} + +// ============================================================================= +// BlockFallthroughRange +// ============================================================================= + +#[derive(Clone)] +struct BlockFallthroughRange { + fallthrough: BlockId, + range: MutableRange, +} + +// ============================================================================= +// Public API +// ============================================================================= + +/// Aligns reactive scope boundaries to block scope boundaries in the HIR. +/// +/// This pass updates reactive scope boundaries to align to control flow +/// boundaries. For example, if a scope ends partway through an if consequent, +/// the scope is extended to the end of the consequent block. 
+pub fn align_reactive_scopes_to_block_scopes_hir(func: &mut HirFunction, env: &mut Environment) { + let mut active_block_fallthrough_ranges: Vec = Vec::new(); + let mut active_scopes: HashSet = HashSet::new(); + let mut seen: HashSet = HashSet::new(); + let mut value_block_nodes: HashMap = HashMap::new(); + + let block_ids: Vec = func.body.blocks.keys().copied().collect(); + + for &block_id in &block_ids { + let starting_id = block_first_id(func, block_id); + + // Retain only active scopes whose range.end > startingId + active_scopes.retain(|&scope_id| env.scopes[scope_id.0 as usize].range.end > starting_id); + + // Check if we've reached a fallthrough block + if let Some(top) = active_block_fallthrough_ranges.last().cloned() { + if top.fallthrough == block_id { + active_block_fallthrough_ranges.pop(); + // All active scopes overlap this block-fallthrough range; + // extend their start to include the range start. + for &scope_id in &active_scopes { + let scope = &mut env.scopes[scope_id.0 as usize]; + scope.range.start = std::cmp::min(scope.range.start, top.range.start); + } + } + } + + let node = value_block_nodes.get(&block_id).cloned(); + + // Visit instruction lvalues and operands + let block = func.body.blocks.get(&block_id).unwrap(); + let instr_ids: Vec = + block.instructions.iter().copied().collect(); + for &instr_id in &instr_ids { + let instr = &func.instructions[instr_id.0 as usize]; + let eval_order = instr.id; + + let lvalue_ids = each_instruction_lvalue_ids(instr); + for lvalue_id in lvalue_ids { + record_place_id( + eval_order, + lvalue_id, + &node, + env, + &mut active_scopes, + &mut seen, + ); + } + + let operand_ids = each_instruction_value_operand_ids(&instr.value, env); + for operand_id in operand_ids { + record_place_id( + eval_order, + operand_id, + &node, + env, + &mut active_scopes, + &mut seen, + ); + } + } + + // Visit terminal operands + let block = func.body.blocks.get(&block_id).unwrap(); + let terminal_eval_order = 
block.terminal.evaluation_order(); + let terminal_operand_ids = each_terminal_operand_ids(&block.terminal); + for operand_id in terminal_operand_ids { + record_place_id( + terminal_eval_order, + operand_id, + &node, + env, + &mut active_scopes, + &mut seen, + ); + } + + let block = func.body.blocks.get(&block_id).unwrap(); + let terminal = &block.terminal; + let fallthrough = visitors::terminal_fallthrough(terminal); + let is_branch = matches!(terminal, Terminal::Branch { .. }); + let is_goto = match terminal { + Terminal::Goto { block, .. } => Some(*block), + _ => None, + }; + let is_ternary_logical_optional = matches!( + terminal, + Terminal::Ternary { .. } | Terminal::Logical { .. } | Terminal::Optional { .. } + ); + let all_successors = all_terminal_block_ids(terminal); + + // Handle fallthrough logic + if let Some(ft) = fallthrough { + if !is_branch { + let next_id = block_first_id(func, ft); + + for &scope_id in &active_scopes { + let scope = &mut env.scopes[scope_id.0 as usize]; + if scope.range.end > terminal_eval_order { + scope.range.end = std::cmp::max(scope.range.end, next_id); + } + } + + active_block_fallthrough_ranges.push(BlockFallthroughRange { + fallthrough: ft, + range: MutableRange { + start: terminal_eval_order, + end: next_id, + }, + }); + + assert!( + !value_block_nodes.contains_key(&ft), + "Expect hir blocks to have unique fallthroughs" + ); + if let Some(n) = &node { + value_block_nodes.insert(ft, n.clone()); + } + } + } else if let Some(goto_block) = is_goto { + // Handle goto to label + let start_pos = active_block_fallthrough_ranges + .iter() + .position(|r| r.fallthrough == goto_block); + let top_idx = if active_block_fallthrough_ranges.is_empty() { + None + } else { + Some(active_block_fallthrough_ranges.len() - 1) + }; + if let Some(pos) = start_pos { + if top_idx != Some(pos) { + let start_range = active_block_fallthrough_ranges[pos].clone(); + let first_id = block_first_id(func, start_range.fallthrough); + + for &scope_id in 
&active_scopes { + let scope = &mut env.scopes[scope_id.0 as usize]; + if scope.range.end <= terminal_eval_order { + continue; + } + scope.range.start = + std::cmp::min(start_range.range.start, scope.range.start); + scope.range.end = std::cmp::max(first_id, scope.range.end); + } + } + } + } + + // Visit all successors to set up value block nodes + for successor in all_successors { + if value_block_nodes.contains_key(&successor) { + continue; + } + + let successor_block = func.body.blocks.get(&successor).unwrap(); + if successor_block.kind == BlockKind::Block || successor_block.kind == BlockKind::Catch + { + // Block or catch kind: don't create a value block node + } else if node.is_none() || is_ternary_logical_optional { + // Create a new node when transitioning non-value -> value, + // or for ternary/logical/optional terminals. + let value_range = if node.is_none() { + // Transition from block -> value block + let ft = fallthrough.expect("Expected a fallthrough for value block"); + let next_id = block_first_id(func, ft); + MutableRange { + start: terminal_eval_order, + end: next_id, + } + } else { + // Value -> value transition (ternary/logical/optional): reuse range + node.as_ref().unwrap().value_range.clone() + }; + + value_block_nodes.insert(successor, ValueBlockNode { value_range }); + } else { + // Value -> value block transition: reuse the node + if let Some(n) = &node { + value_block_nodes.insert(successor, n.clone()); + } + } + } + } + + // Sync identifier mutable_range with their scope's range. + // In TS, identifier.mutableRange and scope.range are the same shared object, + // so modifications to scope.range are automatically visible through the + // identifier. In Rust they are separate copies, so we must explicitly sync. 
+ for ident in &mut env.identifiers { + if let Some(scope_id) = ident.scope { + let scope_range = &env.scopes[scope_id.0 as usize].range; + ident.mutable_range.start = scope_range.start; + ident.mutable_range.end = scope_range.end; + } + } +} + +/// Records a place's scope as active and adjusts scope ranges for value blocks. +/// +/// Mirrors TS `recordPlace(id, place, node)`. +fn record_place_id( + id: EvaluationOrder, + identifier_id: IdentifierId, + node: &Option, + env: &mut Environment, + active_scopes: &mut HashSet, + seen: &mut HashSet, +) { + // Get the scope for this identifier, if active at this instruction + let scope_id = match env.identifiers[identifier_id.0 as usize].scope { + Some(scope_id) => { + let scope = &env.scopes[scope_id.0 as usize]; + if id >= scope.range.start && id < scope.range.end { + Some(scope_id) + } else { + None + } + } + None => None, + }; + + if let Some(scope_id) = scope_id { + active_scopes.insert(scope_id); + + if seen.contains(&scope_id) { + return; + } + seen.insert(scope_id); + + if let Some(n) = node { + let scope = &mut env.scopes[scope_id.0 as usize]; + scope.range.start = std::cmp::min(n.value_range.start, scope.range.start); + scope.range.end = std::cmp::max(n.value_range.end, scope.range.end); + } + } +} diff --git a/crates/react_compiler_inference/src/analyse_functions.rs b/crates/react_compiler_inference/src/analyse_functions.rs new file mode 100644 index 000000000000..bfcf3af58957 --- /dev/null +++ b/crates/react_compiler_inference/src/analyse_functions.rs @@ -0,0 +1,221 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Recursively analyzes nested function expressions and object methods to infer +//! their aliasing effect signatures. +//! +//! Ported from TypeScript `src/Inference/AnalyseFunctions.ts`. +//! +//! 
Runs inferMutationAliasingEffects, deadCodeElimination, +//! inferMutationAliasingRanges, rewriteInstructionKindsBasedOnReassignment, +//! and inferReactiveScopeVariables on each inner function. + +use std::collections::HashSet; + +use indexmap::IndexMap; +use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, AliasingEffect, BlockId, Effect, EvaluationOrder, FunctionId, + HirFunction, IdentifierId, InstructionValue, MutableRange, Place, ReactFunctionType, HIR, +}; + +/// Analyse all nested function expressions and object methods in `func`. +/// +/// For each inner function found, runs `lower_with_mutation_aliasing` to infer +/// its aliasing effects, then resets context variable mutable ranges. +/// +/// The optional `debug_logger` callback is invoked after processing each inner +/// function, receiving `(&HirFunction, &Environment)` so the caller can produce +/// debug output. This mirrors the TS `fn.env.logger?.debugLogIRs` call inside +/// `lowerWithMutationAliasing`. +/// +/// Corresponds to TS `analyseFunctions(func: HIRFunction): void`. +pub fn analyse_functions( + func: &mut HirFunction, + env: &mut Environment, + debug_logger: &mut F, +) -> Result<(), CompilerDiagnostic> +where + F: FnMut(&HirFunction, &Environment), +{ + // Collect FunctionIds from FunctionExpression/ObjectMethod instructions. + // We collect first to avoid borrow conflicts with env.functions. + let mut inner_func_ids: Vec = Vec::new(); + for (_block_id, block) in &func.body.blocks { + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. 
} => { + inner_func_ids.push(lowered_func.func); + } + _ => {} + } + } + } + + // Process each inner function + for func_id in inner_func_ids { + // Take the inner function out of the arena to avoid borrow conflicts + let mut inner_func = std::mem::replace( + &mut env.functions[func_id.0 as usize], + placeholder_function(), + ); + + lower_with_mutation_aliasing(&mut inner_func, env, debug_logger)?; + + // If an invariant error was recorded, put the function back and stop processing + if env.has_invariant_errors() { + env.functions[func_id.0 as usize] = inner_func; + return Ok(()); + } + + // Reset mutable range for outer inferMutationAliasingEffects. + // + // NOTE: inferReactiveScopeVariables makes identifiers in the scope + // point to the *same* mutableRange instance (in TS). In Rust, scopes + // are stored in an arena, so we reset both the identifier's range + // and clear its scope. + for operand in &inner_func.context { + let ident = &mut env.identifiers[operand.identifier.0 as usize]; + ident.mutable_range = MutableRange { + start: EvaluationOrder(0), + end: EvaluationOrder(0), + }; + ident.scope = None; + } + + // Put the function back + env.functions[func_id.0 as usize] = inner_func; + } + + Ok(()) +} + +/// Run mutation/aliasing inference on an inner function. +/// +/// Corresponds to TS `lowerWithMutationAliasing(fn: HIRFunction): void`. 
+fn lower_with_mutation_aliasing( + func: &mut HirFunction, + env: &mut Environment, + debug_logger: &mut F, +) -> Result<(), CompilerDiagnostic> +where + F: FnMut(&HirFunction, &Environment), +{ + // Phase 1: Recursively analyse nested functions first (depth-first) + analyse_functions(func, env, debug_logger)?; + + // inferMutationAliasingEffects on the inner function + crate::infer_mutation_aliasing_effects::infer_mutation_aliasing_effects(func, env, true)?; + + // Check for invariant errors (e.g., uninitialized value kind) + // In TS, these throw from within inferMutationAliasingEffects, aborting + // the rest of the function processing. + if env.has_invariant_errors() { + return Ok(()); + } + + // deadCodeElimination for inner functions + react_compiler_optimization::dead_code_elimination(func, env); + + // inferMutationAliasingRanges — returns the externally-visible function effects + let function_effects = + crate::infer_mutation_aliasing_ranges::infer_mutation_aliasing_ranges(func, env, true)?; + + // rewriteInstructionKindsBasedOnReassignment + if let Err(err) = react_compiler_ssa::rewrite_instruction_kinds_based_on_reassignment(func, env) + { + env.errors.merge(err); + return Ok(()); + } + + // inferReactiveScopeVariables on the inner function + crate::infer_reactive_scope_variables::infer_reactive_scope_variables(func, env)?; + + func.aliasing_effects = Some(function_effects.clone()); + + // Phase 2: Populate the Effect of each context variable to use in inferring + // the outer function. Corresponds to TS Phase 2 in lowerWithMutationAliasing. + let mut captured_or_mutated: HashSet = HashSet::new(); + for effect in &function_effects { + match effect { + AliasingEffect::Assign { from, .. } + | AliasingEffect::Alias { from, .. } + | AliasingEffect::Capture { from, .. } + | AliasingEffect::CreateFrom { from, .. } + | AliasingEffect::MaybeAlias { from, .. } => { + captured_or_mutated.insert(from.identifier); + } + AliasingEffect::Mutate { value, .. 
} + | AliasingEffect::MutateConditionally { value } + | AliasingEffect::MutateTransitive { value } + | AliasingEffect::MutateTransitiveConditionally { value } => { + captured_or_mutated.insert(value.identifier); + } + AliasingEffect::Impure { .. } + | AliasingEffect::Render { .. } + | AliasingEffect::MutateFrozen { .. } + | AliasingEffect::MutateGlobal { .. } + | AliasingEffect::CreateFunction { .. } + | AliasingEffect::Create { .. } + | AliasingEffect::Freeze { .. } + | AliasingEffect::ImmutableCapture { .. } => { + // no-op + } + AliasingEffect::Apply { .. } => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "[AnalyzeFunctions] Expected Apply effects to be replaced with more precise \ + effects", + None, + )); + } + } + } + + for operand in &mut func.context { + if captured_or_mutated.contains(&operand.identifier) || operand.effect == Effect::Capture { + operand.effect = Effect::Capture; + } else { + operand.effect = Effect::Read; + } + } + + // Log the inner function's state (mirrors TS: fn.env.logger?.debugLogIRs) + debug_logger(func, env); + + Ok(()) +} + +/// Create a placeholder HirFunction for temporarily swapping an inner function +/// out of `env.functions` via `std::mem::replace`. The placeholder is never +/// read — the real function is swapped back immediately after processing. 
+fn placeholder_function() -> HirFunction { + HirFunction { + loc: None, + id: None, + name_hint: None, + fn_type: ReactFunctionType::Other, + params: Vec::new(), + return_type_annotation: None, + returns: Place { + identifier: IdentifierId(0), + effect: Effect::Unknown, + reactive: false, + loc: None, + }, + context: Vec::new(), + body: HIR { + entry: BlockId(0), + blocks: IndexMap::new(), + }, + instructions: Vec::new(), + generator: false, + is_async: false, + directives: Vec::new(), + aliasing_effects: None, + } +} diff --git a/crates/react_compiler_inference/src/build_reactive_scope_terminals_hir.rs b/crates/react_compiler_inference/src/build_reactive_scope_terminals_hir.rs new file mode 100644 index 000000000000..35331dba7c39 --- /dev/null +++ b/crates/react_compiler_inference/src/build_reactive_scope_terminals_hir.rs @@ -0,0 +1,406 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Builds reactive scope terminals in the HIR. +//! +//! Given a function whose reactive scope ranges have been correctly aligned and +//! merged, this pass rewrites blocks to introduce ReactiveScopeTerminals and +//! their fallthrough blocks. +//! +//! Ported from TypeScript `src/HIR/BuildReactiveScopeTerminalsHIR.ts`. 
+ +use std::collections::{HashMap, HashSet}; + +use indexmap::IndexMap; +use react_compiler_hir::{ + environment::Environment, + visitors::{ + each_instruction_lvalue_ids, each_instruction_operand_ids, each_terminal_operand_ids, + }, + BasicBlock, BlockId, EvaluationOrder, GotoVariant, HirFunction, IdentifierId, ScopeId, + Terminal, +}; +use react_compiler_lowering::{ + get_reverse_postordered_blocks, mark_instruction_ids, mark_predecessors, +}; + +// ============================================================================= +// getScopes +// ============================================================================= + +/// Collect all unique scopes from places in the function that have non-empty +/// ranges. Corresponds to TS `getScopes(fn)`. +fn get_scopes(func: &HirFunction, env: &Environment) -> Vec { + let mut scope_ids: HashSet = HashSet::new(); + + let mut visit_place = |identifier_id: IdentifierId| { + if let Some(scope_id) = env.identifiers[identifier_id.0 as usize].scope { + let range = &env.scopes[scope_id.0 as usize].range; + if range.start != range.end { + scope_ids.insert(scope_id); + } + } + }; + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + // lvalues + for id in each_instruction_lvalue_ids(instr) { + visit_place(id); + } + // operands + for id in each_instruction_operand_ids(instr, env) { + visit_place(id); + } + } + // terminal operands + for id in each_terminal_operand_ids(&block.terminal) { + visit_place(id); + } + } + + scope_ids.into_iter().collect() +} + +// ============================================================================= +// TerminalRewriteInfo +// ============================================================================= + +enum TerminalRewriteInfo { + StartScope { + block_id: BlockId, + fallthrough_id: BlockId, + instr_id: EvaluationOrder, + scope_id: ScopeId, + }, + EndScope { + instr_id: EvaluationOrder, + 
fallthrough_id: BlockId, + }, +} + +impl TerminalRewriteInfo { + fn instr_id(&self) -> EvaluationOrder { + match self { + TerminalRewriteInfo::StartScope { instr_id, .. } => *instr_id, + TerminalRewriteInfo::EndScope { instr_id, .. } => *instr_id, + } + } +} + +// ============================================================================= +// collectScopeRewrites +// ============================================================================= + +/// Collect all scope rewrites by traversing scopes in pre-order. +fn collect_scope_rewrites(func: &HirFunction, env: &mut Environment) -> Vec { + let scope_ids = get_scopes(func, env); + + // Sort: ascending by start, descending by end for ties + let mut items: Vec = scope_ids; + items.sort_by(|a, b| { + let a_range = &env.scopes[a.0 as usize].range; + let b_range = &env.scopes[b.0 as usize].range; + let start_diff = a_range.start.0.cmp(&b_range.start.0); + if start_diff != std::cmp::Ordering::Equal { + return start_diff; + } + b_range.end.0.cmp(&a_range.end.0) + }); + + let mut rewrites: Vec = Vec::new(); + let mut fallthroughs: HashMap = HashMap::new(); + let mut active_items: Vec = Vec::new(); + + for i in 0..items.len() { + let curr = items[i]; + let curr_start = env.scopes[curr.0 as usize].range.start; + let curr_end = env.scopes[curr.0 as usize].range.end; + + // Pop active items that are disjoint with current + let mut j = active_items.len(); + while j > 0 { + j -= 1; + let maybe_parent = active_items[j]; + let parent_end = env.scopes[maybe_parent.0 as usize].range.end; + let disjoint = curr_start >= parent_end; + let nested = curr_end <= parent_end; + assert!( + disjoint || nested, + "Invalid nesting in program blocks or scopes" + ); + if disjoint { + // Exit this scope + let fallthrough_id = *fallthroughs + .get(&maybe_parent) + .expect("Expected scope to exist"); + let end_instr_id = env.scopes[maybe_parent.0 as usize].range.end; + rewrites.push(TerminalRewriteInfo::EndScope { + instr_id: end_instr_id, + 
fallthrough_id, + }); + active_items.truncate(j); + } else { + break; + } + } + + // Enter scope + let block_id = env.next_block_id(); + let fallthrough_id = env.next_block_id(); + let start_instr_id = env.scopes[curr.0 as usize].range.start; + rewrites.push(TerminalRewriteInfo::StartScope { + block_id, + fallthrough_id, + instr_id: start_instr_id, + scope_id: curr, + }); + fallthroughs.insert(curr, fallthrough_id); + active_items.push(curr); + } + + // Exit remaining active items + while let Some(curr) = active_items.pop() { + let fallthrough_id = *fallthroughs.get(&curr).expect("Expected scope to exist"); + let end_instr_id = env.scopes[curr.0 as usize].range.end; + rewrites.push(TerminalRewriteInfo::EndScope { + instr_id: end_instr_id, + fallthrough_id, + }); + } + + rewrites +} + +// ============================================================================= +// handleRewrite +// ============================================================================= + +struct RewriteContext { + next_block_id: BlockId, + next_preds: Vec, + instr_slice_idx: usize, + rewrites: Vec, +} + +fn handle_rewrite( + terminal_info: &TerminalRewriteInfo, + idx: usize, + source_block: &BasicBlock, + context: &mut RewriteContext, +) { + let terminal: Terminal = match terminal_info { + TerminalRewriteInfo::StartScope { + block_id, + fallthrough_id, + instr_id, + scope_id, + } => Terminal::Scope { + fallthrough: *fallthrough_id, + block: *block_id, + scope: *scope_id, + id: *instr_id, + loc: None, + }, + TerminalRewriteInfo::EndScope { + instr_id, + fallthrough_id, + } => Terminal::Goto { + variant: GotoVariant::Break, + block: *fallthrough_id, + id: *instr_id, + loc: None, + }, + }; + + let curr_block_id = context.next_block_id; + let mut preds = indexmap::IndexSet::new(); + for &p in &context.next_preds { + preds.insert(p); + } + + context.rewrites.push(BasicBlock { + kind: source_block.kind, + id: curr_block_id, + instructions: 
source_block.instructions[context.instr_slice_idx..idx].to_vec(), + preds, + // Only the first rewrite should reuse source block phis + phis: if context.rewrites.is_empty() { + source_block.phis.clone() + } else { + Vec::new() + }, + terminal, + }); + + context.next_preds = vec![curr_block_id]; + context.next_block_id = match terminal_info { + TerminalRewriteInfo::StartScope { block_id, .. } => *block_id, + TerminalRewriteInfo::EndScope { fallthrough_id, .. } => *fallthrough_id, + }; + context.instr_slice_idx = idx; +} + +// ============================================================================= +// Public API +// ============================================================================= + +/// Builds reactive scope terminals in the HIR. +/// +/// This pass assumes that all program blocks are properly nested with respect +/// to fallthroughs. Given a function whose reactive scope ranges have been +/// correctly aligned and merged, this pass rewrites blocks to introduce +/// ReactiveScopeTerminals and their fallthrough blocks. 
+pub fn build_reactive_scope_terminals_hir(func: &mut HirFunction, env: &mut Environment) { + // Step 1: Collect rewrites + let mut queued_rewrites = collect_scope_rewrites(func, env); + + // Step 2: Apply rewrites by splitting blocks + let mut rewritten_final_blocks: HashMap = HashMap::new(); + let mut next_blocks: IndexMap = IndexMap::new(); + + // Reverse so we can pop from the end while traversing in ascending order + queued_rewrites.reverse(); + + for (_block_id, block) in &func.body.blocks { + let preds_vec: Vec = block.preds.iter().copied().collect(); + let mut context = RewriteContext { + next_block_id: block.id, + rewrites: Vec::new(), + next_preds: preds_vec, + instr_slice_idx: 0, + }; + + // Handle queued terminal rewrites at their nearest instruction ID + for i in 0..block.instructions.len() + 1 { + let instr_id = if i < block.instructions.len() { + let instr_idx = block.instructions[i]; + func.instructions[instr_idx.0 as usize].id + } else { + block.terminal.evaluation_order() + }; + + while let Some(rewrite) = queued_rewrites.last() { + if rewrite.instr_id() <= instr_id { + // Need to pop before calling handle_rewrite + let rewrite = queued_rewrites.pop().unwrap(); + handle_rewrite(&rewrite, i, block, &mut context); + } else { + break; + } + } + } + + if !context.rewrites.is_empty() { + let mut final_preds = indexmap::IndexSet::new(); + for &p in &context.next_preds { + final_preds.insert(p); + } + let final_block = BasicBlock { + id: context.next_block_id, + kind: block.kind, + preds: final_preds, + terminal: block.terminal.clone(), + instructions: block.instructions[context.instr_slice_idx..].to_vec(), + phis: Vec::new(), + }; + let final_block_id = final_block.id; + context.rewrites.push(final_block); + for b in context.rewrites { + next_blocks.insert(b.id, b); + } + rewritten_final_blocks.insert(block.id, final_block_id); + } else { + next_blocks.insert(block.id, block.clone()); + } + } + + func.body.blocks = next_blocks; + + // Step 3: Repoint 
phis when they refer to a rewritten block + for block in func.body.blocks.values_mut() { + for phi in &mut block.phis { + let updates: Vec<(BlockId, BlockId)> = phi + .operands + .keys() + .filter_map(|original_id| { + rewritten_final_blocks + .get(original_id) + .map(|new_id| (*original_id, *new_id)) + }) + .collect(); + for (old_id, new_id) in updates { + if let Some(value) = phi.operands.shift_remove(&old_id) { + phi.operands.insert(new_id, value); + } + } + } + } + + // Step 4: Fixup HIR to restore RPO, correct predecessors, renumber instructions + func.body.blocks = get_reverse_postordered_blocks(&func.body, &func.instructions); + mark_predecessors(&mut func.body); + mark_instruction_ids(&mut func.body, &mut func.instructions); + + // Step 5: Fix scope and identifier ranges to account for renumbered + // instructions + fix_scope_and_identifier_ranges(func, env); +} + +/// Fix scope ranges after instruction renumbering. +/// Scope ranges should always align to start at the 'scope' terminal +/// and end at the first instruction of the fallthrough block. +/// +/// In TS, `identifier.mutableRange` and `scope.range` are the same object +/// reference (after InferReactiveScopeVariables). When scope.range is updated, +/// all identifiers with that scope automatically see the new range. +/// BUT: after MergeOverlappingReactiveScopesHIR, repointed identifiers have +/// mutableRange pointing to the OLD scope's range, NOT the root scope's range. +/// So only identifiers whose mutableRange matches their scope's pre-renumbering +/// range should be updated. +/// +/// Corresponds to TS `fixScopeAndIdentifierRanges`. +fn fix_scope_and_identifier_ranges(func: &HirFunction, env: &mut Environment) { + for (_block_id, block) in &func.body.blocks { + match &block.terminal { + Terminal::Scope { + fallthrough, + scope, + id, + .. + } + | Terminal::PrunedScope { + fallthrough, + scope, + id, + .. 
+ } => { + let fallthrough_block = func.body.blocks.get(fallthrough).unwrap(); + let first_id = if !fallthrough_block.instructions.is_empty() { + func.instructions[fallthrough_block.instructions[0].0 as usize].id + } else { + fallthrough_block.terminal.evaluation_order() + }; + env.scopes[scope.0 as usize].range.start = *id; + env.scopes[scope.0 as usize].range.end = first_id; + } + _ => {} + } + } + + // Sync identifier mutable ranges with their scope ranges. + // In TS, identifier.mutableRange IS scope.range (shared object reference). + // When fixScopeAndIdentifierRanges updates scope.range, all identifiers + // whose mutableRange points to that scope automatically see the update. + // In Rust, we must explicitly copy scope range to identifier mutable_range. + for ident in &mut env.identifiers { + if let Some(scope_id) = ident.scope { + let scope_range = &env.scopes[scope_id.0 as usize].range; + ident.mutable_range.start = scope_range.start; + ident.mutable_range.end = scope_range.end; + } + } +} diff --git a/crates/react_compiler_inference/src/flatten_reactive_loops_hir.rs b/crates/react_compiler_inference/src/flatten_reactive_loops_hir.rs new file mode 100644 index 000000000000..0ee8246635f0 --- /dev/null +++ b/crates/react_compiler_inference/src/flatten_reactive_loops_hir.rs @@ -0,0 +1,66 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Prunes any reactive scopes that are within a loop (for, while, etc). We +//! don't yet support memoization within loops because this would require an +//! extra layer of reconciliation (plus a way to identify values across runs, +//! similar to how we use `key` in JSX for lists). Eventually we may integrate +//! more deeply into the runtime so that we can do a single level +//! of reconciliation, but for now we've found it's sufficient to memoize +//! *around* the loop. +//! +//! 
Analogous to TS `ReactiveScopes/FlattenReactiveLoopsHIR.ts`. + +use react_compiler_hir::{BlockId, HirFunction, Terminal}; + +/// Flattens reactive scopes that are inside loops by converting `Scope` +/// terminals to `PrunedScope` terminals. +pub fn flatten_reactive_loops_hir(func: &mut HirFunction) { + let mut active_loops: Vec = Vec::new(); + + // Collect block ids in iteration order so we can iterate while mutating + // terminals + let block_ids: Vec = func.body.blocks.keys().copied().collect(); + + for block_id in block_ids { + // Remove this block from active loops (matching TS retainWhere) + active_loops.retain(|id| *id != block_id); + + let block = &func.body.blocks[&block_id]; + let terminal = &block.terminal; + + match terminal { + Terminal::DoWhile { fallthrough, .. } + | Terminal::For { fallthrough, .. } + | Terminal::ForIn { fallthrough, .. } + | Terminal::ForOf { fallthrough, .. } + | Terminal::While { fallthrough, .. } => { + active_loops.push(*fallthrough); + } + Terminal::Scope { + block, + fallthrough, + scope, + id, + loc, + } => { + if !active_loops.is_empty() { + let new_terminal = Terminal::PrunedScope { + block: *block, + fallthrough: *fallthrough, + scope: *scope, + id: *id, + loc: *loc, + }; + // We need to drop the borrow and reborrow mutably + let block_mut = func.body.blocks.get_mut(&block_id).unwrap(); + block_mut.terminal = new_terminal; + } + } + // All other terminal kinds: no action needed + _ => {} + } + } +} diff --git a/crates/react_compiler_inference/src/flatten_scopes_with_hooks_or_use_hir.rs b/crates/react_compiler_inference/src/flatten_scopes_with_hooks_or_use_hir.rs new file mode 100644 index 000000000000..12b149e95aa0 --- /dev/null +++ b/crates/react_compiler_inference/src/flatten_scopes_with_hooks_or_use_hir.rs @@ -0,0 +1,151 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! 
For simplicity the majority of compiler passes do not treat hooks specially. +//! However, hooks are different from regular functions in two key ways: +//! - They can introduce reactivity even when their arguments are non-reactive +//! (accounted for in InferReactivePlaces) +//! - They cannot be called conditionally +//! +//! The `use` operator is similar: +//! - It can access context, and therefore introduce reactivity +//! - It can be called conditionally, but _it must be called if the component +//! needs the return value_. This is because React uses the fact that use was +//! called to remember that the component needs the value, and that changes to +//! the input should invalidate the component itself. +//! +//! This pass accounts for the "can't call conditionally" aspect of both hooks +//! and use. Though the reasoning is slightly different for each, the result is +//! that we can't memoize scopes that call hooks or use since this would make +//! them called conditionally in the output. +//! +//! The pass finds and removes any scopes that transitively contain a hook or +//! use call. By running all the reactive scope inference first, agnostic of +//! hooks, we know that the reactive scopes accurately describe the set of +//! values which "construct together", and remove _all_ that memoization in +//! order to ensure the hook call does not inadvertently become conditional. +//! +//! Analogous to TS `ReactiveScopes/FlattenScopesWithHooksOrUseHIR.ts`. + +use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, BlockId, HirFunction, InstructionValue, Terminal, Type, +}; + +/// Flattens reactive scopes that contain hook calls or `use()` calls. +/// +/// Hooks and `use` must be called unconditionally, so any reactive scope +/// containing such a call must be flattened to avoid making the call +/// conditional. 
+pub fn flatten_scopes_with_hooks_or_use_hir( + func: &mut HirFunction, + env: &Environment, +) -> Result<(), CompilerDiagnostic> { + let mut active_scopes: Vec = Vec::new(); + let mut prune: Vec = Vec::new(); + + // Collect block ids to allow mutation during iteration + let block_ids: Vec = func.body.blocks.keys().copied().collect(); + + for block_id in &block_ids { + // Remove scopes whose fallthrough matches this block + active_scopes.retain(|scope| scope.fallthrough != *block_id); + + let block = &func.body.blocks[block_id]; + + // Check instructions for hook or use calls + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::CallExpression { callee, .. } => { + let callee_ty = + &env.types[env.identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if is_hook_or_use(env, callee_ty)? { + // All active scopes must be pruned + prune.extend(active_scopes.iter().map(|s| s.block)); + active_scopes.clear(); + } + } + InstructionValue::MethodCall { property, .. } => { + let property_ty = &env.types + [env.identifiers[property.identifier.0 as usize].type_.0 as usize]; + if is_hook_or_use(env, property_ty)? { + prune.extend(active_scopes.iter().map(|s| s.block)); + active_scopes.clear(); + } + } + _ => {} + } + } + + // Track scope terminals + if let Terminal::Scope { fallthrough, .. 
} = &block.terminal { + active_scopes.push(ActiveScope { + block: *block_id, + fallthrough: *fallthrough, + }); + } + } + + // Apply pruning: convert Scope terminals to Label or PrunedScope + for id in prune { + let block = &func.body.blocks[&id]; + let terminal = &block.terminal; + + let (scope_block, fallthrough, eval_id, loc, scope) = match terminal { + Terminal::Scope { + block, + fallthrough, + id, + loc, + scope, + } => (*block, *fallthrough, *id, *loc, *scope), + _ => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!("Expected block bb{} to end in a scope terminal", id.0), + None, + )); + } + }; + + // Check if the scope body is a single-instruction block that goes directly + // to fallthrough — if so, use Label instead of PrunedScope + let body = &func.body.blocks[&scope_block]; + let new_terminal = if body.instructions.len() == 1 + && matches!(&body.terminal, Terminal::Goto { block, .. } if *block == fallthrough) + { + // This was a scope just for a hook call, which doesn't need memoization. + // Flatten it away. We rely on PruneUnusedLabels to do the actual flattening. 
+ Terminal::Label { + block: scope_block, + fallthrough, + id: eval_id, + loc, + } + } else { + Terminal::PrunedScope { + block: scope_block, + fallthrough, + scope, + id: eval_id, + loc, + } + }; + + let block_mut = func.body.blocks.get_mut(&id).unwrap(); + block_mut.terminal = new_terminal; + } + Ok(()) +} + +struct ActiveScope { + block: BlockId, + fallthrough: BlockId, +} + +fn is_hook_or_use(env: &Environment, ty: &Type) -> Result { + Ok(env.get_hook_kind_for_type(ty)?.is_some() || react_compiler_hir::is_use_operator_type(ty)) +} diff --git a/crates/react_compiler_inference/src/infer_mutation_aliasing_effects.rs b/crates/react_compiler_inference/src/infer_mutation_aliasing_effects.rs new file mode 100644 index 000000000000..8f1cc020fa1b --- /dev/null +++ b/crates/react_compiler_inference/src/infer_mutation_aliasing_effects.rs @@ -0,0 +1,3690 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Infers the mutation/aliasing effects for instructions and terminals. +//! +//! Ported from TypeScript `src/Inference/InferMutationAliasingEffects.ts`. +//! +//! This pass uses abstract interpretation to compute effects describing +//! creation, aliasing, mutation, freezing, and error conditions for each +//! instruction and terminal in the HIR. 
+ +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, + object_shape::{ + FunctionSignature, HookKind, BUILT_IN_ARRAY_ID, BUILT_IN_MAP_ID, BUILT_IN_SET_ID, + }, + type_config::{ValueKind, ValueReason}, + visitors, AliasingEffect, AliasingSignature, BlockId, DeclarationId, Effect, FunctionId, + HirFunction, IdentifierId, InstructionKind, InstructionValue, MutationReason, ParamPattern, + Place, PlaceOrSpread, PlaceOrSpreadOrHole, ReactFunctionType, SourceLocation, Type, +}; + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Infers mutation/aliasing effects for all instructions and terminals in +/// `func`. +/// +/// Corresponds to TS `inferMutationAliasingEffects(fn, +/// {isFunctionExpression})`. +pub fn infer_mutation_aliasing_effects( + func: &mut HirFunction, + env: &mut Environment, + is_function_expression: bool, +) -> Result<(), CompilerDiagnostic> { + let mut initial_state = InferenceState::empty(env, is_function_expression); + + // Map of blocks to the last (merged) incoming state that was processed + let mut states_by_block: HashMap = HashMap::new(); + + // Initialize context variables + for ctx_place in &func.context { + let value_id = ValueId::new(); + initial_state.initialize( + value_id, + AbstractValue { + kind: ValueKind::Context, + reason: hashset_of(ValueReason::Other), + }, + ); + initial_state.define(ctx_place.identifier, value_id); + } + + let param_kind: AbstractValue = if is_function_expression { + AbstractValue { + kind: ValueKind::Mutable, + reason: hashset_of(ValueReason::Other), + } + } else { + AbstractValue { + kind: ValueKind::Frozen, + reason: hashset_of(ValueReason::ReactiveFunctionArgument), + } + }; + + if func.fn_type == ReactFunctionType::Component 
{ + // Component: at most 2 params (props, ref) + let params_len = func.params.len(); + if params_len > 0 { + infer_param(&func.params[0], &mut initial_state, ¶m_kind); + } + if params_len > 1 { + let ref_place = match &func.params[1] { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + let value_id = ValueId::new(); + initial_state.initialize( + value_id, + AbstractValue { + kind: ValueKind::Mutable, + reason: hashset_of(ValueReason::Other), + }, + ); + initial_state.define(ref_place.identifier, value_id); + } + } else { + for param in &func.params { + infer_param(param, &mut initial_state, ¶m_kind); + } + } + + let mut queued_states: indexmap::IndexMap = indexmap::IndexMap::new(); + + // Queue helper + fn queue( + queued_states: &mut indexmap::IndexMap, + states_by_block: &HashMap, + block_id: BlockId, + state: InferenceState, + ) { + if let Some(queued_state) = queued_states.get(&block_id) { + let merged = queued_state.merge(&state); + let new_state = merged.unwrap_or_else(|| queued_state.clone()); + queued_states.insert(block_id, new_state); + } else { + let prev_state = states_by_block.get(&block_id); + if let Some(prev) = prev_state { + let next_state = prev.merge(&state); + if let Some(next) = next_state { + queued_states.insert(block_id, next); + } + } else { + queued_states.insert(block_id, state); + } + } + } + + queue( + &mut queued_states, + &states_by_block, + func.body.entry, + initial_state, + ); + + let hoisted_context_declarations = find_hoisted_context_declarations(func, env); + let non_mutating_spreads = find_non_mutated_destructure_spreads(func, env); + + let mut context = Context { + interned_effects: HashMap::new(), + instruction_signature_cache: HashMap::new(), + catch_handlers: HashMap::new(), + is_function_expression, + hoisted_context_declarations, + non_mutating_spreads, + effect_value_id_cache: HashMap::new(), + function_values: HashMap::new(), + function_signature_cache: HashMap::new(), + 
aliasing_config_temp_cache: HashMap::new(), + }; + + let mut iteration_count = 0; + + while !queued_states.is_empty() { + iteration_count += 1; + if iteration_count > 100 { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "[InferMutationAliasingEffects] Potential infinite loop: A value, temporary \ + place, or effect was not cached properly", + None, + )); + } + + // Collect block IDs to process in order + let block_ids: Vec = func.body.blocks.keys().copied().collect(); + for block_id in block_ids { + let incoming_state = match queued_states.swap_remove(&block_id) { + Some(s) => s, + None => continue, + }; + + states_by_block.insert(block_id, incoming_state.clone()); + let mut state = incoming_state.clone(); + + infer_block(&mut context, &mut state, block_id, func, env)?; + + // Check for uninitialized identifier access (matches TS invariant: + // "Expected value kind to be initialized") + if let Some((uninitialized_id, usage_loc)) = state.uninitialized_access.get() { + let ident_info = env.identifiers.get(uninitialized_id.0 as usize); + let name = ident_info + .and_then(|ident| ident.name.as_ref()) + .map(|n| n.value().to_string()) + .unwrap_or_else(|| "".to_string()); + // Use usage_loc if available, otherwise fall back to identifier's own loc + let error_loc = usage_loc.or_else(|| ident_info.and_then(|i| i.loc)); + // Match TS printPlace format: " name$id:type" + let type_str = ident_info + .map(|ident| { + let ty = &env.types[ident.type_.0 as usize]; + format_type_for_print(ty) + }) + .unwrap_or_default(); + let description = format!(" {}${}{}", name, uninitialized_id.0, type_str); + let diag = CompilerDiagnostic::new( + ErrorCategory::Invariant, + "[InferMutationAliasingEffects] Expected value kind to be initialized", + Some(description), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: error_loc, + message: Some("this is uninitialized".to_string()), + identifier_name: None, + }); + return Err(diag); + } + + // Queue successors + 
let successors = terminal_successors(&func.body.blocks[&block_id].terminal); + for next_block_id in successors { + queue( + &mut queued_states, + &states_by_block, + next_block_id, + state.clone(), + ); + } + } + } + + Ok(()) +} + +// ============================================================================= +// ValueId: replaces InstructionValue identity as allocation-site key +// ============================================================================= + +/// Unique allocation-site identifier, replacing TS's object-identity on +/// InstructionValue. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +struct ValueId(u32); + +use std::sync::atomic::{AtomicU32, Ordering}; +static NEXT_VALUE_ID: AtomicU32 = AtomicU32::new(1); + +impl ValueId { + fn new() -> Self { + ValueId(NEXT_VALUE_ID.fetch_add(1, Ordering::Relaxed)) + } +} + +// ============================================================================= +// AbstractValue +// ============================================================================= + +#[derive(Debug, Clone)] +struct AbstractValue { + kind: ValueKind, + reason: HashSet, +} + +fn hashset_of(r: ValueReason) -> HashSet { + let mut s = HashSet::new(); + s.insert(r); + s +} + +// ============================================================================= +// InferenceState +// ============================================================================= + +/// The abstract state tracked during inference. +/// Uses interior mutability via a struct with direct fields (no Rc needed since +/// we always have exclusive access in the pass). +#[derive(Debug, Clone)] +struct InferenceState { + is_function_expression: bool, + /// The kind of each value, based on its allocation site + values: HashMap, + /// The set of values pointed to by each identifier + variables: HashMap>, + /// Tracks uninitialized identifier access errors (matches TS invariant). + /// Uses Cell so it can be set from `&self` methods like `kind()`. 
+ /// Stores (IdentifierId, usage_loc) where usage_loc is the source location + /// of the Place that triggered the uninitialized access. + uninitialized_access: std::cell::Cell)>>, +} + +impl InferenceState { + fn empty(_env: &Environment, is_function_expression: bool) -> Self { + InferenceState { + is_function_expression, + values: HashMap::new(), + variables: HashMap::new(), + uninitialized_access: std::cell::Cell::new(None), + } + } + + /// Check the kind of a place, recording the usage location for error + /// reporting. + fn kind_with_loc( + &self, + place_id: IdentifierId, + usage_loc: Option, + ) -> AbstractValue { + let values = match self.variables.get(&place_id) { + Some(v) => v, + None => { + if self.uninitialized_access.get().is_none() { + self.uninitialized_access.set(Some((place_id, usage_loc))); + } + return AbstractValue { + kind: ValueKind::Mutable, + reason: hashset_of(ValueReason::Other), + }; + } + }; + let mut merged_kind: Option = None; + for value_id in values { + let kind = match self.values.get(value_id) { + Some(k) => k, + None => continue, + }; + merged_kind = Some(match merged_kind { + Some(prev) => merge_abstract_values(&prev, kind), + None => kind.clone(), + }); + } + merged_kind.unwrap_or_else(|| AbstractValue { + kind: ValueKind::Mutable, + reason: hashset_of(ValueReason::Other), + }) + } + + fn initialize(&mut self, value_id: ValueId, kind: AbstractValue) { + self.values.insert(value_id, kind); + } + + fn define(&mut self, place_id: IdentifierId, value_id: ValueId) { + let mut set = HashSet::new(); + set.insert(value_id); + self.variables.insert(place_id, set); + } + + fn assign(&mut self, into: IdentifierId, from: IdentifierId) { + let values = match self.variables.get(&from) { + Some(v) => v.clone(), + None => { + // Create a stable value for uninitialized identifiers + // Use a deterministic ID based on the from identifier + let vid = ValueId(from.0 | 0x80000000); + let mut set = HashSet::new(); + set.insert(vid); + if 
!self.values.contains_key(&vid) { + self.values.insert( + vid, + AbstractValue { + kind: ValueKind::Mutable, + reason: hashset_of(ValueReason::Other), + }, + ); + } + set + } + }; + self.variables.insert(into, values); + } + + fn append_alias(&mut self, place: IdentifierId, value: IdentifierId) { + let new_values = match self.variables.get(&value) { + Some(v) => v.clone(), + None => return, + }; + let prev_values = match self.variables.get(&place) { + Some(v) => v.clone(), + None => return, + }; + let merged: HashSet = prev_values.union(&new_values).copied().collect(); + self.variables.insert(place, merged); + } + + fn is_defined(&self, place_id: IdentifierId) -> bool { + self.variables.contains_key(&place_id) + } + + fn values_for(&self, place_id: IdentifierId) -> Vec { + match self.variables.get(&place_id) { + Some(values) => values.iter().copied().collect(), + None => Vec::new(), + } + } + + #[allow(dead_code)] + fn kind_opt(&self, place_id: IdentifierId) -> Option { + let values = self.variables.get(&place_id)?; + let mut merged_kind: Option = None; + for value_id in values { + let kind = self.values.get(value_id)?; + merged_kind = Some(match merged_kind { + Some(prev) => merge_abstract_values(&prev, kind), + None => kind.clone(), + }); + } + merged_kind + } + + fn kind(&self, place_id: IdentifierId) -> AbstractValue { + self.kind_with_loc(place_id, None) + } + + fn freeze(&mut self, place_id: IdentifierId, reason: ValueReason) -> bool { + // Check if defined first to avoid recording uninitialized access error. + // Freeze on undefined identifiers is a no-op — this matches the TS + // behavior where freeze() is never called on undefined identifiers + // (the invariant in kind() catches this before freeze is reached). 
+ if !self.variables.contains_key(&place_id) { + return false; + } + let value = self.kind(place_id); + match value.kind { + ValueKind::Context | ValueKind::Mutable | ValueKind::MaybeFrozen => { + let value_ids: Vec = self.values_for(place_id); + for vid in value_ids { + self.freeze_value(vid, reason); + } + true + } + ValueKind::Frozen | ValueKind::Global | ValueKind::Primitive => false, + } + } + + fn freeze_value(&mut self, value_id: ValueId, reason: ValueReason) { + self.values.insert( + value_id, + AbstractValue { + kind: ValueKind::Frozen, + reason: hashset_of(reason), + }, + ); + // Note: In TS, this also transitively freezes FunctionExpression + // captures if enableTransitivelyFreezeFunctionExpressions is + // set. We skip that here since we don't have access to the + // function arena from within state. + } + + #[allow(dead_code)] + fn mutate( + &self, + variant: MutateVariant, + place_id: IdentifierId, + env: &Environment, + ) -> MutationResult { + self.mutate_with_loc(variant, place_id, env, None) + } + + fn mutate_with_loc( + &self, + variant: MutateVariant, + place_id: IdentifierId, + env: &Environment, + usage_loc: Option, + ) -> MutationResult { + let ty = &env.types[env.identifiers[place_id.0 as usize].type_.0 as usize]; + if react_compiler_hir::is_ref_or_ref_value(ty) { + return MutationResult::MutateRef; + } + let kind = self.kind_with_loc(place_id, usage_loc).kind; + match variant { + MutateVariant::MutateConditionally | MutateVariant::MutateTransitiveConditionally => { + match kind { + ValueKind::Mutable | ValueKind::Context => MutationResult::Mutate, + _ => MutationResult::None, + } + } + MutateVariant::Mutate | MutateVariant::MutateTransitive => match kind { + ValueKind::Mutable | ValueKind::Context => MutationResult::Mutate, + ValueKind::Primitive => MutationResult::None, + ValueKind::Frozen | ValueKind::MaybeFrozen => MutationResult::MutateFrozen, + ValueKind::Global => MutationResult::MutateGlobal, + }, + } + } + + fn merge(&self, other: 
&InferenceState) -> Option { + let mut next_values: Option> = None; + let mut next_variables: Option>> = None; + + // Merge values present in both + for (id, this_value) in &self.values { + if let Some(other_value) = other.values.get(id) { + let merged = merge_abstract_values(this_value, other_value); + if merged.kind != this_value.kind + || !is_superset(&this_value.reason, &merged.reason) + { + let nv = next_values.get_or_insert_with(|| self.values.clone()); + nv.insert(*id, merged); + } + } + } + // Add values only in other + for (id, other_value) in &other.values { + if !self.values.contains_key(id) { + let nv = next_values.get_or_insert_with(|| self.values.clone()); + nv.insert(*id, other_value.clone()); + } + } + + // Merge variables present in both + for (id, this_values) in &self.variables { + if let Some(other_values) = other.variables.get(id) { + let mut has_new = false; + for ov in other_values { + if !this_values.contains(ov) { + has_new = true; + break; + } + } + if has_new { + let nvars = next_variables.get_or_insert_with(|| self.variables.clone()); + let merged: HashSet = + this_values.union(other_values).copied().collect(); + nvars.insert(*id, merged); + } + } + } + // Add variables only in other + for (id, other_values) in &other.variables { + if !self.variables.contains_key(id) { + let nvars = next_variables.get_or_insert_with(|| self.variables.clone()); + nvars.insert(*id, other_values.clone()); + } + } + + if next_variables.is_none() && next_values.is_none() { + None + } else { + Some(InferenceState { + is_function_expression: self.is_function_expression, + values: next_values.unwrap_or_else(|| self.values.clone()), + variables: next_variables.unwrap_or_else(|| self.variables.clone()), + uninitialized_access: std::cell::Cell::new(None), + }) + } + } + + fn infer_phi( + &mut self, + phi_place_id: IdentifierId, + phi_operands: &indexmap::IndexMap, + ) { + let mut values: HashSet = HashSet::new(); + for (_, operand) in phi_operands { + if let 
Some(operand_values) = self.variables.get(&operand.identifier) { + for v in operand_values { + values.insert(*v); + } + } + // If not found, it's a backedge that will be handled later by merge + } + if !values.is_empty() { + self.variables.insert(phi_place_id, values); + } + } +} + +fn is_superset(a: &HashSet, b: &HashSet) -> bool { + b.iter().all(|x| a.contains(x)) +} + +#[derive(Debug, Clone, Copy)] +enum MutateVariant { + Mutate, + MutateConditionally, + MutateTransitive, + MutateTransitiveConditionally, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MutationResult { + None, + Mutate, + MutateFrozen, + MutateGlobal, + MutateRef, +} + +// ============================================================================= +// Context +// ============================================================================= + +struct Context { + interned_effects: HashMap, + instruction_signature_cache: HashMap, + catch_handlers: HashMap, + is_function_expression: bool, + hoisted_context_declarations: HashMap>, + non_mutating_spreads: HashSet, + /// Cache of ValueIds keyed by effect hash, ensuring stable allocation-site + /// identity across fixpoint iterations. Mirrors TS + /// `effectInstructionValueCache`. + effect_value_id_cache: HashMap, + /// Maps ValueId to FunctionId for function expressions, so we can look up + /// locally-declared functions when processing Apply effects. + function_values: HashMap, + /// Cache of function expression signatures, keyed by FunctionId + function_signature_cache: HashMap, + /// Cache of temporary places created for aliasing signature config + /// temporaries. Keyed by (lvalue_identifier_id, temp_name) to ensure + /// stable allocation across fixpoint iterations. 
+ aliasing_config_temp_cache: HashMap<(IdentifierId, String), Place>, +} + +impl Context { + fn intern_effect(&mut self, effect: AliasingEffect) -> AliasingEffect { + let hash = hash_effect(&effect); + self.interned_effects.entry(hash).or_insert(effect).clone() + } + + /// Get or create a stable ValueId for a given effect, ensuring fixpoint + /// convergence. + fn get_or_create_value_id(&mut self, effect: &AliasingEffect) -> ValueId { + let hash = hash_effect(effect); + *self + .effect_value_id_cache + .entry(hash) + .or_insert_with(ValueId::new) + } +} + +struct InstructionSignature { + effects: Vec, +} + +// ============================================================================= +// Helper: hash_effect +// ============================================================================= + +fn hash_effect(effect: &AliasingEffect) -> String { + match effect { + AliasingEffect::Apply { + receiver, + function, + mutates_function, + args, + into, + .. + } => { + let args_str: Vec = args + .iter() + .map(|a| match a { + PlaceOrSpreadOrHole::Hole => String::new(), + PlaceOrSpreadOrHole::Place(p) => format!("{}", p.identifier.0), + PlaceOrSpreadOrHole::Spread(s) => format!("...{}", s.place.identifier.0), + }) + .collect(); + format!( + "Apply:{}:{}:{}:{}:{}", + receiver.identifier.0, + function.identifier.0, + mutates_function, + args_str.join(","), + into.identifier.0 + ) + } + AliasingEffect::CreateFrom { from, into } => { + format!("CreateFrom:{}:{}", from.identifier.0, into.identifier.0) + } + AliasingEffect::ImmutableCapture { from, into } => format!( + "ImmutableCapture:{}:{}", + from.identifier.0, into.identifier.0 + ), + AliasingEffect::Assign { from, into } => { + format!("Assign:{}:{}", from.identifier.0, into.identifier.0) + } + AliasingEffect::Alias { from, into } => { + format!("Alias:{}:{}", from.identifier.0, into.identifier.0) + } + AliasingEffect::Capture { from, into } => { + format!("Capture:{}:{}", from.identifier.0, into.identifier.0) + } + 
AliasingEffect::MaybeAlias { from, into } => { + format!("MaybeAlias:{}:{}", from.identifier.0, into.identifier.0) + } + AliasingEffect::Create { + into, + value, + reason, + } => format!("Create:{}:{:?}:{:?}", into.identifier.0, value, reason), + AliasingEffect::Freeze { value, reason } => { + format!("Freeze:{}:{:?}", value.identifier.0, reason) + } + AliasingEffect::Impure { place, .. } => format!("Impure:{}", place.identifier.0), + AliasingEffect::Render { place } => format!("Render:{}", place.identifier.0), + AliasingEffect::MutateFrozen { place, error } => format!( + "MutateFrozen:{}:{}:{:?}", + place.identifier.0, error.reason, error.description + ), + AliasingEffect::MutateGlobal { place, error } => format!( + "MutateGlobal:{}:{}:{:?}", + place.identifier.0, error.reason, error.description + ), + AliasingEffect::Mutate { value, .. } => format!("Mutate:{}", value.identifier.0), + AliasingEffect::MutateConditionally { value } => { + format!("MutateConditionally:{}", value.identifier.0) + } + AliasingEffect::MutateTransitive { value } => { + format!("MutateTransitive:{}", value.identifier.0) + } + AliasingEffect::MutateTransitiveConditionally { value } => { + format!("MutateTransitiveConditionally:{}", value.identifier.0) + } + AliasingEffect::CreateFunction { + into, + function_id, + captures, + } => { + let cap_str: Vec = captures + .iter() + .map(|p| format!("{}", p.identifier.0)) + .collect(); + format!( + "CreateFunction:{}:{}:{}", + into.identifier.0, + function_id.0, + cap_str.join(",") + ) + } + } +} + +// ============================================================================= +// merge helpers +// ============================================================================= + +fn merge_abstract_values(a: &AbstractValue, b: &AbstractValue) -> AbstractValue { + let kind = merge_value_kinds(a.kind, b.kind); + if kind == a.kind && kind == b.kind && is_superset(&a.reason, &b.reason) { + return a.clone(); + } + let mut reason = a.reason.clone(); + for 
r in &b.reason { + reason.insert(*r); + } + AbstractValue { kind, reason } +} + +fn merge_value_kinds(a: ValueKind, b: ValueKind) -> ValueKind { + if a == b { + return a; + } + if a == ValueKind::MaybeFrozen || b == ValueKind::MaybeFrozen { + return ValueKind::MaybeFrozen; + } + if a == ValueKind::Mutable || b == ValueKind::Mutable { + if a == ValueKind::Frozen || b == ValueKind::Frozen { + return ValueKind::MaybeFrozen; + } else if a == ValueKind::Context || b == ValueKind::Context { + return ValueKind::Context; + } else { + return ValueKind::Mutable; + } + } + if a == ValueKind::Context || b == ValueKind::Context { + if a == ValueKind::Frozen || b == ValueKind::Frozen { + return ValueKind::MaybeFrozen; + } else { + return ValueKind::Context; + } + } + if a == ValueKind::Frozen || b == ValueKind::Frozen { + return ValueKind::Frozen; + } + if a == ValueKind::Global || b == ValueKind::Global { + return ValueKind::Global; + } + ValueKind::Primitive +} + +// ============================================================================= +// Pre-passes +// ============================================================================= + +fn find_hoisted_context_declarations( + func: &HirFunction, + env: &Environment, +) -> HashMap> { + let mut hoisted: HashMap> = HashMap::new(); + + fn visit( + hoisted: &mut HashMap>, + place: &Place, + env: &Environment, + ) { + let decl_id = env.identifiers[place.identifier.0 as usize].declaration_id; + if hoisted.contains_key(&decl_id) && hoisted.get(&decl_id).unwrap().is_none() { + hoisted.insert(decl_id, Some(place.clone())); + } + } + + for (_block_id, block) in &func.body.blocks { + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::DeclareContext { lvalue, .. 
} => { + let kind = lvalue.kind; + if kind == InstructionKind::HoistedConst + || kind == InstructionKind::HoistedFunction + || kind == InstructionKind::HoistedLet + { + let decl_id = + env.identifiers[lvalue.place.identifier.0 as usize].declaration_id; + hoisted.insert(decl_id, None); + } + } + _ => { + for operand in visitors::each_instruction_value_operand(&instr.value, env) { + visit(&mut hoisted, &operand, env); + } + } + } + } + for operand in visitors::each_terminal_operand(&block.terminal) { + visit(&mut hoisted, &operand, env); + } + } + hoisted +} + +fn find_non_mutated_destructure_spreads( + func: &HirFunction, + env: &Environment, +) -> HashSet { + let mut known_frozen: HashSet = HashSet::new(); + if func.fn_type == ReactFunctionType::Component { + if let Some(param) = func.params.first() { + if let ParamPattern::Place(p) = param { + known_frozen.insert(p.identifier); + } + } + } else { + for param in &func.params { + if let ParamPattern::Place(p) = param { + known_frozen.insert(p.identifier); + } + } + } + + let mut candidate_non_mutating_spreads: HashMap = HashMap::new(); + for (_block_id, block) in &func.body.blocks { + if !candidate_non_mutating_spreads.is_empty() { + for phi in &block.phis { + for (_, operand) in &phi.operands { + if let Some(spread) = candidate_non_mutating_spreads + .get(&operand.identifier) + .copied() + { + candidate_non_mutating_spreads.remove(&spread); + } + } + } + } + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + match &instr.value { + InstructionValue::Destructure { lvalue, value, .. 
} => { + if !known_frozen.contains(&value.identifier) { + continue; + } + if !(lvalue.kind == InstructionKind::Let + || lvalue.kind == InstructionKind::Const) + { + continue; + } + match &lvalue.pattern { + react_compiler_hir::Pattern::Object(obj_pat) => { + for prop in &obj_pat.properties { + if let react_compiler_hir::ObjectPropertyOrSpread::Spread(s) = prop + { + candidate_non_mutating_spreads + .insert(s.place.identifier, s.place.identifier); + } + } + } + _ => continue, + } + } + InstructionValue::LoadLocal { place, .. } => { + if let Some(spread) = candidate_non_mutating_spreads + .get(&place.identifier) + .copied() + { + candidate_non_mutating_spreads.insert(lvalue_id, spread); + } + } + InstructionValue::StoreLocal { + lvalue: sl, + value: sv, + .. + } => { + if let Some(spread) = + candidate_non_mutating_spreads.get(&sv.identifier).copied() + { + candidate_non_mutating_spreads.insert(lvalue_id, spread); + candidate_non_mutating_spreads.insert(sl.place.identifier, spread); + } + } + InstructionValue::JsxFragment { .. } | InstructionValue::JsxExpression { .. } => { + // Passing objects created with spread to jsx can't mutate + // them + } + InstructionValue::PropertyLoad { .. } => { + // Properties must be frozen since the original value was + // frozen + } + InstructionValue::CallExpression { callee, .. } + | InstructionValue::MethodCall { + property: callee, .. 
+ } => { + let callee_ty = + &env.types[env.identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if get_hook_kind_for_type(env, callee_ty) + .ok() + .flatten() + .is_some() + { + if !is_ref_or_ref_value_for_id(env, lvalue_id) { + known_frozen.insert(lvalue_id); + } + } else if !candidate_non_mutating_spreads.is_empty() { + for operand in visitors::each_instruction_value_operand(&instr.value, env) { + if let Some(spread) = candidate_non_mutating_spreads + .get(&operand.identifier) + .copied() + { + candidate_non_mutating_spreads.remove(&spread); + } + } + } + } + _ => { + if !candidate_non_mutating_spreads.is_empty() { + for operand in visitors::each_instruction_value_operand(&instr.value, env) { + if let Some(spread) = candidate_non_mutating_spreads + .get(&operand.identifier) + .copied() + { + candidate_non_mutating_spreads.remove(&spread); + } + } + } + } + } + } + } + + let mut non_mutating: HashSet = HashSet::new(); + for (key, value) in &candidate_non_mutating_spreads { + if key == value { + non_mutating.insert(*key); + } + } + non_mutating +} + +// ============================================================================= +// inferParam +// ============================================================================= + +fn infer_param(param: &ParamPattern, state: &mut InferenceState, param_kind: &AbstractValue) { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + let value_id = ValueId::new(); + state.initialize(value_id, param_kind.clone()); + state.define(place.identifier, value_id); +} + +// ============================================================================= +// inferBlock +// ============================================================================= + +fn infer_block( + context: &mut Context, + state: &mut InferenceState, + block_id: BlockId, + func: &mut HirFunction, + env: &mut Environment, +) -> Result<(), CompilerDiagnostic> { + let block = &func.body.blocks[&block_id]; + 
+ // Process phis + let phis: Vec<(IdentifierId, indexmap::IndexMap)> = block + .phis + .iter() + .map(|phi| (phi.place.identifier, phi.operands.clone())) + .collect(); + for (place_id, operands) in &phis { + state.infer_phi(*place_id, operands); + } + + // Process instructions + let instr_ids: Vec = block.instructions.iter().map(|id| id.0).collect(); + for instr_idx in &instr_ids { + let instr_index = *instr_idx as usize; + + // Compute signature if not cached + if !context.instruction_signature_cache.contains_key(instr_idx) { + let sig = compute_signature_for_instruction( + context, + env, + &func.instructions[instr_index], + func, + ); + context.instruction_signature_cache.insert(*instr_idx, sig); + } + + // Apply signature + let effects = apply_signature( + context, + state, + *instr_idx, + &func.instructions[instr_index], + env, + func, + )?; + func.instructions[instr_index].effects = effects; + } + + // Process terminal + // Determine what terminal action to take without holding borrows + enum TerminalAction { + Try { handler: BlockId, binding: Place }, + MaybeThrow { handler_id: BlockId }, + Return, + None, + } + let action = { + let block = &func.body.blocks[&block_id]; + match &block.terminal { + react_compiler_hir::Terminal::Try { + handler, + handler_binding: Some(binding), + .. + } => TerminalAction::Try { + handler: *handler, + binding: binding.clone(), + }, + react_compiler_hir::Terminal::MaybeThrow { + handler: Some(handler_id), + .. + } => TerminalAction::MaybeThrow { + handler_id: *handler_id, + }, + react_compiler_hir::Terminal::Return { .. 
} => TerminalAction::Return, + _ => TerminalAction::None, + } + }; + + match action { + TerminalAction::Try { handler, binding } => { + context.catch_handlers.insert(handler, binding); + } + TerminalAction::MaybeThrow { handler_id } => { + if let Some(handler_param) = context.catch_handlers.get(&handler_id).cloned() { + if state.is_defined(handler_param.identifier) { + let mut terminal_effects: Vec = Vec::new(); + for instr_idx in &instr_ids { + let instr = &func.instructions[*instr_idx as usize]; + match &instr.value { + InstructionValue::CallExpression { .. } + | InstructionValue::MethodCall { .. } => { + state.append_alias( + handler_param.identifier, + instr.lvalue.identifier, + ); + let kind = state.kind(instr.lvalue.identifier).kind; + if kind == ValueKind::Mutable || kind == ValueKind::Context { + terminal_effects.push(context.intern_effect( + AliasingEffect::Alias { + from: instr.lvalue.clone(), + into: handler_param.clone(), + }, + )); + } + } + _ => {} + } + } + let block_mut = func.body.blocks.get_mut(&block_id).unwrap(); + if let react_compiler_hir::Terminal::MaybeThrow { + effects: ref mut term_effects, + .. + } = block_mut.terminal + { + *term_effects = if terminal_effects.is_empty() { + None + } else { + Some(terminal_effects) + }; + } + } + } + } + TerminalAction::Return => { + if !context.is_function_expression { + let block_mut = func.body.blocks.get_mut(&block_id).unwrap(); + if let react_compiler_hir::Terminal::Return { + ref value, + effects: ref mut term_effects, + .. 
+ } = block_mut.terminal + { + *term_effects = Some(vec![context.intern_effect(AliasingEffect::Freeze { + value: value.clone(), + reason: ValueReason::JsxCaptured, + })]); + } + } + } + TerminalAction::None => {} + } + Ok(()) +} + +// ============================================================================= +// applySignature +// ============================================================================= + +fn apply_signature( + context: &mut Context, + state: &mut InferenceState, + instr_idx: u32, + instr: &react_compiler_hir::Instruction, + env: &mut Environment, + func: &HirFunction, +) -> Result>, CompilerDiagnostic> { + let mut effects: Vec = Vec::new(); + + // For function instructions, validate frozen mutation + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + let inner_func = &env.functions[lowered_func.func.0 as usize]; + if let Some(ref aliasing_effects) = inner_func.aliasing_effects { + let context_ids: HashSet = + inner_func.context.iter().map(|p| p.identifier).collect(); + for effect in aliasing_effects { + let (mutate_value, is_mutate) = match effect { + AliasingEffect::Mutate { value, .. 
} => (value, true), + AliasingEffect::MutateTransitive { value } => (value, false), + _ => continue, + }; + if !context_ids.contains(&mutate_value.identifier) { + continue; + } + if !state.is_defined(mutate_value.identifier) { + continue; + } + let value_abstract = state.kind(mutate_value.identifier); + if value_abstract.kind == ValueKind::Frozen { + let reason_str = get_write_error_reason(&value_abstract); + let ident = &env.identifiers[mutate_value.identifier.0 as usize]; + let variable = match &ident.name { + Some(react_compiler_hir::IdentifierName::Named(n)) => { + format!("`{}`", n) + } + _ => "value".to_string(), + }; + let mut diagnostic = CompilerDiagnostic::new( + ErrorCategory::Immutability, + "This value cannot be modified", + Some(reason_str), + ); + diagnostic.details.push( + react_compiler_diagnostics::CompilerDiagnosticDetail::Error { + loc: mutate_value.loc, + message: Some(format!("{} cannot be modified", variable)), + identifier_name: None, + }, + ); + if is_mutate { + if let AliasingEffect::Mutate { + reason: Some(MutationReason::AssignCurrentProperty), + .. + } = effect + { + diagnostic.details.push( + react_compiler_diagnostics::CompilerDiagnosticDetail::Hint { + message: "Hint: If this value is a Ref (value returned by \ + `useRef()`), rename the variable to end in \ + \"Ref\"." + .to_string(), + }, + ); + } + } + effects.push(AliasingEffect::MutateFrozen { + place: mutate_value.clone(), + error: diagnostic, + }); + } + } + } + } + _ => {} + } + + // Track which values we've already initialized + let mut initialized: HashSet = HashSet::new(); + + // Get the cached signature effects + let sig = context.instruction_signature_cache.get(&instr_idx).unwrap(); + let sig_effects: Vec = sig.effects.clone(); + + for effect in &sig_effects { + apply_effect( + context, + state, + effect.clone(), + &mut initialized, + &mut effects, + env, + func, + )?; + } + + // If lvalue is not yet defined, initialize it with a default value. 
+ // The TS version asserts this as an invariant, but the Rust port may have + // edge cases where effects don't cover the lvalue (e.g. missing signature + // entries). + if !state.is_defined(instr.lvalue.identifier) { + let vid = ValueId(instr.lvalue.identifier.0 | 0x80000000); + state.initialize( + vid, + AbstractValue { + kind: ValueKind::Mutable, + reason: hashset_of(ValueReason::Other), + }, + ); + state.define(instr.lvalue.identifier, vid); + } + + Ok(if effects.is_empty() { + None + } else { + Some(effects) + }) +} + +// ============================================================================= +// Transitive freeze helper +// ============================================================================= + +/// Recursively freeze through FunctionExpression captures. If `value_id` +/// corresponds to a FunctionExpression, freeze each of its context captures +/// and recurse into any that are themselves FunctionExpressions. This matches +/// the TS `freezeValue` → `freeze` → `freezeValue` recursion chain. +fn freeze_function_captures_transitive( + state: &mut InferenceState, + context: &Context, + env: &Environment, + value_id: ValueId, + reason: ValueReason, +) { + if let Some(&func_id) = context.function_values.get(&value_id) { + let ctx_ids: Vec = env.functions[func_id.0 as usize] + .context + .iter() + .map(|p| p.identifier) + .collect(); + for ctx_id in ctx_ids { + // Replicate InferenceState::freeze() logic inline — + // we need to recurse with context/env which freeze() doesn't have. 
+ if !state.variables.contains_key(&ctx_id) { + continue; + } + let kind = state.kind(ctx_id).kind; + match kind { + ValueKind::Context | ValueKind::Mutable | ValueKind::MaybeFrozen => { + let vids: Vec = state.values_for(ctx_id); + for vid in vids { + state.freeze_value(vid, reason); + // Recurse into nested function captures + freeze_function_captures_transitive(state, context, env, vid, reason); + } + } + ValueKind::Frozen | ValueKind::Global | ValueKind::Primitive => { + // Already frozen or immutable — no-op + } + } + } + } +} + +// ============================================================================= +// applyEffect +// ============================================================================= + +fn apply_effect( + context: &mut Context, + state: &mut InferenceState, + effect: AliasingEffect, + initialized: &mut HashSet, + effects: &mut Vec, + env: &mut Environment, + func: &HirFunction, +) -> Result<(), CompilerDiagnostic> { + let effect = context.intern_effect(effect); + match effect { + AliasingEffect::Freeze { ref value, reason } => { + let did_freeze = state.freeze(value.identifier, reason); + if did_freeze { + effects.push(effect.clone()); + // Transitively freeze FunctionExpression captures if enabled + // (matches TS freezeValue which recurses into func.context) + let enable_transitive = env.config.enable_preserve_existing_memoization_guarantees + || env.config.enable_transitively_freeze_function_expressions; + if enable_transitive { + // Recursively freeze through function captures. The TS + // freezeValue() calls freeze() on each capture, which + // calls freezeValue() again — creating a transitive + // closure through arbitrarily nested function captures. 
+ let value_ids: Vec = state.values_for(value.identifier); + for vid in &value_ids { + freeze_function_captures_transitive(state, context, env, *vid, reason); + } + } + } + } + AliasingEffect::Create { + ref into, + value: kind, + reason, + } => { + assert!( + !initialized.contains(&into.identifier), + "[InferMutationAliasingEffects] Cannot re-initialize variable within an \ + instruction" + ); + initialized.insert(into.identifier); + let value_id = context.get_or_create_value_id(&effect); + state.initialize( + value_id, + AbstractValue { + kind, + reason: hashset_of(reason), + }, + ); + state.define(into.identifier, value_id); + effects.push(effect.clone()); + } + AliasingEffect::ImmutableCapture { ref from, .. } => { + let kind = state.kind(from.identifier).kind; + match kind { + ValueKind::Global | ValueKind::Primitive => { + // no-op: don't track data flow for copy types + } + _ => { + effects.push(effect.clone()); + } + } + } + AliasingEffect::CreateFrom { ref from, ref into } => { + assert!( + !initialized.contains(&into.identifier), + "[InferMutationAliasingEffects] Cannot re-initialize variable within an \ + instruction" + ); + initialized.insert(into.identifier); + let from_value = state.kind(from.identifier); + let value_id = context.get_or_create_value_id(&effect); + state.initialize( + value_id, + AbstractValue { + kind: from_value.kind, + reason: from_value.reason.clone(), + }, + ); + state.define(into.identifier, value_id); + match from_value.kind { + ValueKind::Primitive | ValueKind::Global => { + let first_reason = primary_reason(&from_value.reason); + effects.push(AliasingEffect::Create { + value: from_value.kind, + into: into.clone(), + reason: first_reason, + }); + } + ValueKind::Frozen => { + let first_reason = primary_reason(&from_value.reason); + effects.push(AliasingEffect::Create { + value: from_value.kind, + into: into.clone(), + reason: first_reason, + }); + apply_effect( + context, + state, + AliasingEffect::ImmutableCapture { + from: 
from.clone(), + into: into.clone(), + }, + initialized, + effects, + env, + func, + )?; + } + _ => { + effects.push(effect.clone()); + } + } + } + AliasingEffect::CreateFunction { + ref captures, + function_id, + ref into, + } => { + assert!( + !initialized.contains(&into.identifier), + "[InferMutationAliasingEffects] Cannot re-initialize variable within an \ + instruction" + ); + initialized.insert(into.identifier); + effects.push(effect.clone()); + + // Check if function is mutable + let has_captures = captures.iter().any(|capture| { + if !state.is_defined(capture.identifier) { + return false; + } + let k = state.kind(capture.identifier).kind; + k == ValueKind::Context || k == ValueKind::Mutable + }); + + let inner_func = &env.functions[function_id.0 as usize]; + let has_tracked_side_effects = inner_func + .aliasing_effects + .as_ref() + .map(|effs| { + effs.iter().any(|e| { + matches!( + e, + AliasingEffect::MutateFrozen { .. } + | AliasingEffect::MutateGlobal { .. } + | AliasingEffect::Impure { .. 
} + ) + }) + }) + .unwrap_or(false); + + let captures_ref = inner_func + .context + .iter() + .any(|operand| is_ref_or_ref_value_for_id(env, operand.identifier)); + + let is_mutable = has_captures || has_tracked_side_effects || captures_ref; + + // Update context variable effects + let context_places: Vec = inner_func.context.clone(); + for operand in &context_places { + if operand.effect != Effect::Capture { + continue; + } + if !state.is_defined(operand.identifier) { + continue; + } + let kind = state.kind(operand.identifier).kind; + if kind == ValueKind::Primitive + || kind == ValueKind::Frozen + || kind == ValueKind::Global + { + // Downgrade to Read - we need to mutate the inner function + let inner_func_mut = &mut env.functions[function_id.0 as usize]; + for ctx in &mut inner_func_mut.context { + if ctx.identifier == operand.identifier && ctx.effect == Effect::Capture { + ctx.effect = Effect::Read; + } + } + } + } + + let value_id = context.get_or_create_value_id(&effect); + // Track this value as a function expression so Apply can look it up + context.function_values.insert(value_id, function_id); + state.initialize( + value_id, + AbstractValue { + kind: if is_mutable { + ValueKind::Mutable + } else { + ValueKind::Frozen + }, + reason: HashSet::new(), + }, + ); + state.define(into.identifier, value_id); + + for capture in captures { + apply_effect( + context, + state, + AliasingEffect::Capture { + from: capture.clone(), + into: into.clone(), + }, + initialized, + effects, + env, + func, + )?; + } + } + AliasingEffect::MaybeAlias { ref from, ref into } + | AliasingEffect::Alias { ref from, ref into } + | AliasingEffect::Capture { ref from, ref into } => { + let is_capture = matches!(effect, AliasingEffect::Capture { .. }); + let is_maybe_alias = matches!(effect, AliasingEffect::MaybeAlias { .. 
}); + // For Alias, destination must already be initialized (Capture/MaybeAlias are + // exempt) + assert!( + is_capture || is_maybe_alias || initialized.contains(&into.identifier), + "[InferMutationAliasingEffects] Expected destination to already be initialized \ + within this instruction" + ); + + // Check destination kind + let into_kind = state.kind_with_loc(into.identifier, into.loc).kind; + let destination_type = match into_kind { + ValueKind::Context => Some("context"), + ValueKind::Mutable | ValueKind::MaybeFrozen => Some("mutable"), + _ => None, + }; + + let from_kind = state.kind_with_loc(from.identifier, from.loc).kind; + let source_type = match from_kind { + ValueKind::Context => Some("context"), + ValueKind::Global | ValueKind::Primitive => None, + ValueKind::MaybeFrozen | ValueKind::Frozen => Some("frozen"), + ValueKind::Mutable => Some("mutable"), + }; + + if source_type == Some("frozen") { + apply_effect( + context, + state, + AliasingEffect::ImmutableCapture { + from: from.clone(), + into: into.clone(), + }, + initialized, + effects, + env, + func, + )?; + } else if (source_type == Some("mutable") && destination_type == Some("mutable")) + || is_maybe_alias + { + effects.push(effect.clone()); + } else if (source_type == Some("context") && destination_type.is_some()) + || (source_type == Some("mutable") && destination_type == Some("context")) + { + apply_effect( + context, + state, + AliasingEffect::MaybeAlias { + from: from.clone(), + into: into.clone(), + }, + initialized, + effects, + env, + func, + )?; + } + } + AliasingEffect::Assign { ref from, ref into } => { + assert!( + !initialized.contains(&into.identifier), + "[InferMutationAliasingEffects] Cannot re-initialize variable within an \ + instruction" + ); + initialized.insert(into.identifier); + let from_value = state.kind_with_loc(from.identifier, from.loc); + match from_value.kind { + ValueKind::Frozen => { + apply_effect( + context, + state, + AliasingEffect::ImmutableCapture { + from: 
from.clone(), + into: into.clone(), + }, + initialized, + effects, + env, + func, + )?; + let cache_key = + format!("Assign_frozen:{}:{}", from.identifier.0, into.identifier.0); + let value_id = *context + .effect_value_id_cache + .entry(cache_key) + .or_insert_with(ValueId::new); + state.initialize( + value_id, + AbstractValue { + kind: from_value.kind, + reason: from_value.reason.clone(), + }, + ); + state.define(into.identifier, value_id); + } + ValueKind::Global | ValueKind::Primitive => { + let cache_key = + format!("Assign_copy:{}:{}", from.identifier.0, into.identifier.0); + let value_id = *context + .effect_value_id_cache + .entry(cache_key) + .or_insert_with(ValueId::new); + state.initialize( + value_id, + AbstractValue { + kind: from_value.kind, + reason: from_value.reason.clone(), + }, + ); + state.define(into.identifier, value_id); + } + _ => { + state.assign(into.identifier, from.identifier); + effects.push(effect.clone()); + } + } + } + AliasingEffect::Apply { + ref receiver, + ref function, + mutates_function, + ref args, + ref into, + ref signature, + ref loc, + } => { + // First, check if the callee is a locally-declared function expression + // whose aliasing effects we already know (TS lines 1016-1068) + if state.is_defined(function.identifier) { + let function_values = state.values_for(function.identifier); + if function_values.len() == 1 { + let value_id = function_values[0]; + if let Some(func_id) = context.function_values.get(&value_id).copied() { + let inner_func = &env.functions[func_id.0 as usize]; + if inner_func.aliasing_effects.is_some() { + // Build or retrieve the signature from the function expression + if !context.function_signature_cache.contains_key(&func_id) { + let sig = build_signature_from_function_expression(env, func_id); + context.function_signature_cache.insert(func_id, sig); + } + let sig = context + .function_signature_cache + .get(&func_id) + .unwrap() + .clone(); + let inner_func = &env.functions[func_id.0 as usize]; + 
let context_places: Vec = inner_func.context.clone(); + let sig_effects = compute_effects_for_aliasing_signature( + env, + &sig, + into, + receiver, + args, + &context_places, + loc.as_ref(), + )?; + if let Some(sig_effs) = sig_effects { + // Conditionally mutate the function itself first + apply_effect( + context, + state, + AliasingEffect::MutateTransitiveConditionally { + value: function.clone(), + }, + initialized, + effects, + env, + func, + )?; + for se in sig_effs { + apply_effect( + context, + state, + se, + initialized, + effects, + env, + func, + )?; + } + return Ok(()); + } + } + } + } + } + if let Some(sig) = signature { + // Check known_incompatible (TS line 2351-2370) + if let Some(ref incompatible_msg) = sig.known_incompatible { + if env.enable_validations() { + let mut diagnostic = CompilerDiagnostic::new( + ErrorCategory::IncompatibleLibrary, + "Use of incompatible library", + Some( + "This API returns functions which cannot be memoized without \ + leading to stale UI. To prevent this, by default React Compiler \ + will skip memoizing this component/hook. 
However, you may see \ + issues if values from this API are passed to other \ + components/hooks that are memoized" + .to_string(), + ), + ); + diagnostic.details.push(CompilerDiagnosticDetail::Error { + loc: receiver.loc, + message: Some(incompatible_msg.clone()), + identifier_name: None, + }); + // TS throws here, aborting compilation for this function + return Err(diagnostic); + } + } + + if let Some(ref aliasing) = sig.aliasing { + let sig_effects = compute_effects_for_aliasing_signature_config( + env, + aliasing, + into, + receiver, + args, + &[], + loc.as_ref(), + &mut context.aliasing_config_temp_cache, + )?; + if let Some(sig_effs) = sig_effects { + for se in sig_effs { + apply_effect(context, state, se, initialized, effects, env, func)?; + } + return Ok(()); + } + } + + // Legacy signature + let mut todo_errors: Vec = + Vec::new(); + let legacy_effects = compute_effects_for_legacy_signature( + state, + sig, + into, + receiver, + args, + loc.as_ref(), + env, + &context.function_values, + &mut todo_errors, + ); + // Todo errors should short-circuit (TS throws throwTodo) + if let Some(err_detail) = todo_errors.into_iter().next() { + return Err(CompilerDiagnostic::from_detail(err_detail)); + } + for le in legacy_effects { + apply_effect(context, state, le, initialized, effects, env, func)?; + } + } else { + // No signature: default behavior + apply_effect( + context, + state, + AliasingEffect::Create { + into: into.clone(), + value: ValueKind::Mutable, + reason: ValueReason::Other, + }, + initialized, + effects, + env, + func, + )?; + + let all_operands = build_apply_operands(receiver, function, args); + for (operand, _is_function_operand, is_spread) in &all_operands { + // In TS, the check is `operand !== effect.function || effect.mutatesFunction`. + // This compares by reference identity, so for CallExpression/NewExpression + // where receiver === function, BOTH are skipped when !mutatesFunction. 
+ if operand.identifier == function.identifier && !mutates_function { + // Don't mutate callee for non-mutating calls + } else { + apply_effect( + context, + state, + AliasingEffect::MutateTransitiveConditionally { + value: operand.clone(), + }, + initialized, + effects, + env, + func, + )?; + } + + if *is_spread { + let ty = &env.types + [env.identifiers[operand.identifier.0 as usize].type_.0 as usize]; + if let Some(mutate_iter) = conditionally_mutate_iterator(operand, ty) { + apply_effect( + context, + state, + mutate_iter, + initialized, + effects, + env, + func, + )?; + } + } + + apply_effect( + context, + state, + AliasingEffect::MaybeAlias { + from: operand.clone(), + into: into.clone(), + }, + initialized, + effects, + env, + func, + )?; + + // In TS, `other === arg` compares the Place extracted from + // `otherArg` with the original `arg` element. For Identifier + // args, the extracted Place IS the arg, so this is a reference + // identity check. For Spread args, the extracted Place is + // `.place` which is never `===` the Spread wrapper object, + // so NO pairs are skipped when the outer arg is a Spread + // (including self-pairs, producing self-captures). + for (other, _other_is_func, _other_is_spread) in &all_operands { + if !is_spread && other.identifier == operand.identifier { + continue; + } + apply_effect( + context, + state, + AliasingEffect::Capture { + from: operand.clone(), + into: other.clone(), + }, + initialized, + effects, + env, + func, + )?; + } + } + } + } + ref eff @ (AliasingEffect::Mutate { .. } + | AliasingEffect::MutateConditionally { .. } + | AliasingEffect::MutateTransitive { .. } + | AliasingEffect::MutateTransitiveConditionally { .. }) => { + let (mutate_place, variant) = match eff { + AliasingEffect::Mutate { value, .. 
} => (value, MutateVariant::Mutate), + AliasingEffect::MutateConditionally { value } => { + (value, MutateVariant::MutateConditionally) + } + AliasingEffect::MutateTransitive { value } => { + (value, MutateVariant::MutateTransitive) + } + AliasingEffect::MutateTransitiveConditionally { value } => { + (value, MutateVariant::MutateTransitiveConditionally) + } + _ => unreachable!(), + }; + let value = mutate_place; + let mutation_kind = state.mutate_with_loc(variant, value.identifier, env, value.loc); + if mutation_kind == MutationResult::Mutate { + effects.push(effect.clone()); + } else if mutation_kind == MutationResult::MutateRef { + // no-op + } else if mutation_kind != MutationResult::None + && matches!( + variant, + MutateVariant::Mutate | MutateVariant::MutateTransitive + ) + { + let abstract_value = state.kind(value.identifier); + + let ident = &env.identifiers[value.identifier.0 as usize]; + let decl_id = ident.declaration_id; + + if mutation_kind == MutationResult::MutateFrozen + && context.hoisted_context_declarations.contains_key(&decl_id) + { + let variable = match &ident.name { + Some(react_compiler_hir::IdentifierName::Named(n)) => { + Some(format!("`{}`", n)) + } + _ => None, + }; + let hoisted_access = context + .hoisted_context_declarations + .get(&decl_id) + .cloned() + .flatten(); + let mut diagnostic = CompilerDiagnostic::new( + ErrorCategory::Immutability, + "Cannot access variable before it is declared", + Some(format!( + "{} is accessed before it is declared, which prevents the earlier \ + access from updating when this value changes over time", + variable.as_deref().unwrap_or("This variable") + )), + ); + if let Some(ref access) = hoisted_access { + if access.loc != value.loc { + diagnostic.details.push( + react_compiler_diagnostics::CompilerDiagnosticDetail::Error { + loc: access.loc, + message: Some(format!( + "{} accessed before it is declared", + variable.as_deref().unwrap_or("variable") + )), + identifier_name: None, + }, + ); + } + } + 
diagnostic.details.push( + react_compiler_diagnostics::CompilerDiagnosticDetail::Error { + loc: value.loc, + message: Some(format!( + "{} is declared here", + variable.as_deref().unwrap_or("variable") + )), + identifier_name: None, + }, + ); + apply_effect( + context, + state, + AliasingEffect::MutateFrozen { + place: value.clone(), + error: diagnostic, + }, + initialized, + effects, + env, + func, + )?; + } else { + let reason_str = get_write_error_reason(&abstract_value); + let variable = match &ident.name { + Some(react_compiler_hir::IdentifierName::Named(n)) => format!("`{}`", n), + _ => "value".to_string(), + }; + let mut diagnostic = CompilerDiagnostic::new( + ErrorCategory::Immutability, + "This value cannot be modified", + Some(reason_str), + ); + diagnostic.details.push( + react_compiler_diagnostics::CompilerDiagnosticDetail::Error { + loc: value.loc, + message: Some(format!("{} cannot be modified", variable)), + identifier_name: None, + }, + ); + + if let AliasingEffect::Mutate { + reason: Some(MutationReason::AssignCurrentProperty), + .. + } = &effect + { + diagnostic.details.push( + react_compiler_diagnostics::CompilerDiagnosticDetail::Hint { + message: "Hint: If this value is a Ref (value returned by \ + `useRef()`), rename the variable to end in \"Ref\"." + .to_string(), + }, + ); + } + + let error_kind = if abstract_value.kind == ValueKind::Frozen { + AliasingEffect::MutateFrozen { + place: value.clone(), + error: diagnostic, + } + } else { + AliasingEffect::MutateGlobal { + place: value.clone(), + error: diagnostic, + } + }; + apply_effect(context, state, error_kind, initialized, effects, env, func)?; + } + } + } + AliasingEffect::Impure { .. } + | AliasingEffect::Render { .. } + | AliasingEffect::MutateFrozen { .. } + | AliasingEffect::MutateGlobal { .. 
} => { + effects.push(effect.clone()); + } + } + Ok(()) +} + +// ============================================================================= +// computeSignatureForInstruction +// ============================================================================= + +fn compute_signature_for_instruction( + context: &mut Context, + env: &Environment, + instr: &react_compiler_hir::Instruction, + _func: &HirFunction, +) -> InstructionSignature { + let lvalue = &instr.lvalue; + let value = &instr.value; + let mut effects: Vec = Vec::new(); + + match value { + InstructionValue::ArrayExpression { elements, .. } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Mutable, + reason: ValueReason::Other, + }); + for element in elements { + match element { + react_compiler_hir::ArrayElement::Place(p) => { + effects.push(AliasingEffect::Capture { + from: p.clone(), + into: lvalue.clone(), + }); + } + react_compiler_hir::ArrayElement::Spread(s) => { + let ty = &env.types + [env.identifiers[s.place.identifier.0 as usize].type_.0 as usize]; + if let Some(mutate_iter) = conditionally_mutate_iterator(&s.place, ty) { + effects.push(mutate_iter); + } + effects.push(AliasingEffect::Capture { + from: s.place.clone(), + into: lvalue.clone(), + }); + } + react_compiler_hir::ArrayElement::Hole => {} + } + } + } + InstructionValue::ObjectExpression { properties, .. } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Mutable, + reason: ValueReason::Other, + }); + for property in properties { + match property { + react_compiler_hir::ObjectPropertyOrSpread::Property(p) => { + effects.push(AliasingEffect::Capture { + from: p.place.clone(), + into: lvalue.clone(), + }); + } + react_compiler_hir::ObjectPropertyOrSpread::Spread(s) => { + effects.push(AliasingEffect::Capture { + from: s.place.clone(), + into: lvalue.clone(), + }); + } + } + } + } + InstructionValue::Await { + value: await_value, .. 
+ } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Mutable, + reason: ValueReason::Other, + }); + effects.push(AliasingEffect::MutateTransitiveConditionally { + value: await_value.clone(), + }); + effects.push(AliasingEffect::Capture { + from: await_value.clone(), + into: lvalue.clone(), + }); + } + InstructionValue::NewExpression { callee, args, loc } => { + let sig = get_function_call_signature(env, callee.identifier) + .ok() + .flatten(); + effects.push(AliasingEffect::Apply { + receiver: callee.clone(), + function: callee.clone(), + mutates_function: false, + args: args.iter().map(place_or_spread_to_hole).collect(), + into: lvalue.clone(), + signature: sig, + loc: *loc, + }); + } + InstructionValue::CallExpression { callee, args, loc } => { + let sig = get_function_call_signature(env, callee.identifier) + .ok() + .flatten(); + effects.push(AliasingEffect::Apply { + receiver: callee.clone(), + function: callee.clone(), + mutates_function: true, + args: args.iter().map(place_or_spread_to_hole).collect(), + into: lvalue.clone(), + signature: sig, + loc: *loc, + }); + } + InstructionValue::MethodCall { + receiver, + property, + args, + loc, + } => { + let sig = get_function_call_signature(env, property.identifier) + .ok() + .flatten(); + effects.push(AliasingEffect::Apply { + receiver: receiver.clone(), + function: property.clone(), + mutates_function: false, + args: args.iter().map(place_or_spread_to_hole).collect(), + into: lvalue.clone(), + signature: sig, + loc: *loc, + }); + } + InstructionValue::PropertyDelete { object, .. } + | InstructionValue::ComputedDelete { object, .. } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + effects.push(AliasingEffect::Mutate { + value: object.clone(), + reason: None, + }); + } + InstructionValue::PropertyLoad { object, .. } + | InstructionValue::ComputedLoad { object, .. 
} => { + let ty = &env.types[env.identifiers[lvalue.identifier.0 as usize].type_.0 as usize]; + if react_compiler_hir::is_primitive_type(ty) { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } else { + effects.push(AliasingEffect::CreateFrom { + from: object.clone(), + into: lvalue.clone(), + }); + } + } + InstructionValue::PropertyStore { + object, + property, + value: store_value, + .. + } => { + let mutation_reason: Option = { + let obj_ty = + &env.types[env.identifiers[object.identifier.0 as usize].type_.0 as usize]; + if let react_compiler_hir::PropertyLiteral::String(prop_name) = property { + if prop_name == "current" && matches!(obj_ty, Type::TypeVar { .. }) { + Some(MutationReason::AssignCurrentProperty) + } else { + None + } + } else { + None + } + }; + effects.push(AliasingEffect::Mutate { + value: object.clone(), + reason: mutation_reason, + }); + effects.push(AliasingEffect::Capture { + from: store_value.clone(), + into: object.clone(), + }); + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } + InstructionValue::ComputedStore { + object, + value: store_value, + .. + } => { + effects.push(AliasingEffect::Mutate { + value: object.clone(), + reason: None, + }); + effects.push(AliasingEffect::Capture { + from: store_value.clone(), + into: object.clone(), + }); + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. 
} => { + let inner_func = &env.functions[lowered_func.func.0 as usize]; + let captures: Vec = inner_func + .context + .iter() + .filter(|operand| operand.effect == Effect::Capture) + .cloned() + .collect(); + effects.push(AliasingEffect::CreateFunction { + into: lvalue.clone(), + function_id: lowered_func.func, + captures, + }); + } + InstructionValue::GetIterator { collection, .. } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Mutable, + reason: ValueReason::Other, + }); + let ty = &env.types[env.identifiers[collection.identifier.0 as usize].type_.0 as usize]; + if is_builtin_collection_type(ty) { + effects.push(AliasingEffect::Capture { + from: collection.clone(), + into: lvalue.clone(), + }); + } else { + effects.push(AliasingEffect::Alias { + from: collection.clone(), + into: lvalue.clone(), + }); + effects.push(AliasingEffect::MutateTransitiveConditionally { + value: collection.clone(), + }); + } + } + InstructionValue::IteratorNext { + iterator, + collection, + .. + } => { + effects.push(AliasingEffect::MutateConditionally { + value: iterator.clone(), + }); + effects.push(AliasingEffect::CreateFrom { + from: collection.clone(), + into: lvalue.clone(), + }); + } + InstructionValue::NextPropertyOf { .. } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } + InstructionValue::JsxExpression { + tag, + props, + children, + .. 
+ } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Frozen, + reason: ValueReason::JsxCaptured, + }); + for operand in visitors::each_instruction_value_operand(value, env) { + effects.push(AliasingEffect::Freeze { + value: operand.clone(), + reason: ValueReason::JsxCaptured, + }); + effects.push(AliasingEffect::Capture { + from: operand.clone(), + into: lvalue.clone(), + }); + } + if let JsxTag::Place(tag_place) = tag { + effects.push(AliasingEffect::Render { + place: tag_place.clone(), + }); + } + if let Some(ch) = children { + for child in ch { + effects.push(AliasingEffect::Render { + place: child.clone(), + }); + } + } + for prop in props { + if let react_compiler_hir::JsxAttribute::Attribute { + place: prop_place, .. + } = prop + { + let prop_ty = &env.types + [env.identifiers[prop_place.identifier.0 as usize].type_.0 as usize]; + if let Type::Function { return_type, .. } = prop_ty { + if react_compiler_hir::is_jsx_type(return_type) + || is_phi_with_jsx(return_type) + { + effects.push(AliasingEffect::Render { + place: prop_place.clone(), + }); + } + } + } + } + } + InstructionValue::JsxFragment { children: _, .. } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Frozen, + reason: ValueReason::JsxCaptured, + }); + for operand in visitors::each_instruction_value_operand(value, env) { + effects.push(AliasingEffect::Freeze { + value: operand.clone(), + reason: ValueReason::JsxCaptured, + }); + effects.push(AliasingEffect::Capture { + from: operand.clone(), + into: lvalue.clone(), + }); + } + } + InstructionValue::DeclareLocal { lvalue: dl, .. } => { + effects.push(AliasingEffect::Create { + into: dl.place.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } + InstructionValue::Destructure { + lvalue: dl, + value: dest_value, + .. 
+ } => { + for pat_item in each_pattern_items(&dl.pattern) { + match pat_item { + PatternItem::Place(place) => { + let ty = &env.types + [env.identifiers[place.identifier.0 as usize].type_.0 as usize]; + if react_compiler_hir::is_primitive_type(ty) { + effects.push(AliasingEffect::Create { + into: place.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } else { + effects.push(AliasingEffect::CreateFrom { + from: dest_value.clone(), + into: place.clone(), + }); + } + } + PatternItem::Spread(place) => { + let value_kind = if context.non_mutating_spreads.contains(&place.identifier) + { + ValueKind::Frozen + } else { + ValueKind::Mutable + }; + effects.push(AliasingEffect::Create { + into: place.clone(), + reason: ValueReason::Other, + value: value_kind, + }); + effects.push(AliasingEffect::Capture { + from: dest_value.clone(), + into: place.clone(), + }); + } + } + } + effects.push(AliasingEffect::Assign { + from: dest_value.clone(), + into: lvalue.clone(), + }); + } + InstructionValue::LoadContext { place, .. } => { + effects.push(AliasingEffect::CreateFrom { + from: place.clone(), + into: lvalue.clone(), + }); + } + InstructionValue::DeclareContext { lvalue: dcl, .. } => { + let decl_id = env.identifiers[dcl.place.identifier.0 as usize].declaration_id; + let kind = dcl.kind; + if !context.hoisted_context_declarations.contains_key(&decl_id) + || kind == InstructionKind::HoistedConst + || kind == InstructionKind::HoistedFunction + || kind == InstructionKind::HoistedLet + { + effects.push(AliasingEffect::Create { + into: dcl.place.clone(), + value: ValueKind::Mutable, + reason: ValueReason::Other, + }); + } else { + effects.push(AliasingEffect::Mutate { + value: dcl.place.clone(), + reason: None, + }); + } + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } + InstructionValue::StoreContext { + lvalue: scl, + value: sc_value, + .. 
+ } => { + let decl_id = env.identifiers[scl.place.identifier.0 as usize].declaration_id; + if scl.kind == InstructionKind::Reassign + || context.hoisted_context_declarations.contains_key(&decl_id) + { + effects.push(AliasingEffect::Mutate { + value: scl.place.clone(), + reason: None, + }); + } else { + effects.push(AliasingEffect::Create { + into: scl.place.clone(), + value: ValueKind::Mutable, + reason: ValueReason::Other, + }); + } + effects.push(AliasingEffect::Capture { + from: sc_value.clone(), + into: scl.place.clone(), + }); + effects.push(AliasingEffect::Assign { + from: sc_value.clone(), + into: lvalue.clone(), + }); + } + InstructionValue::LoadLocal { place, .. } => { + effects.push(AliasingEffect::Assign { + from: place.clone(), + into: lvalue.clone(), + }); + } + InstructionValue::StoreLocal { + lvalue: sl, + value: sl_value, + .. + } => { + effects.push(AliasingEffect::Assign { + from: sl_value.clone(), + into: sl.place.clone(), + }); + effects.push(AliasingEffect::Assign { + from: sl_value.clone(), + into: lvalue.clone(), + }); + } + InstructionValue::PostfixUpdate { + lvalue: pf_lvalue, .. + } + | InstructionValue::PrefixUpdate { + lvalue: pf_lvalue, .. + } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + effects.push(AliasingEffect::Create { + into: pf_lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } + InstructionValue::StoreGlobal { + name, + value: sg_value, + loc: _, + .. + } => { + let variable = format!("`{}`", name); + let mut diagnostic = CompilerDiagnostic::new( + ErrorCategory::Globals, + "Cannot reassign variables declared outside of the component/hook", + Some(format!( + "Variable {} is declared outside of the component/hook. Reassigning this value during render is a form of side effect, which can cause unpredictable behavior depending on when the component happens to re-render. 
If this variable is used in rendering, use useState instead. Otherwise, consider updating it in an effect. (https://react.dev/reference/rules/components-and-hooks-must-be-pure#side-effects-must-run-outside-of-render)", + variable + )), + ); + diagnostic.details.push( + react_compiler_diagnostics::CompilerDiagnosticDetail::Error { + loc: instr.loc, + message: Some(format!("{} cannot be reassigned", variable)), + identifier_name: None, + }, + ); + effects.push(AliasingEffect::MutateGlobal { + place: sg_value.clone(), + error: diagnostic, + }); + effects.push(AliasingEffect::Assign { + from: sg_value.clone(), + into: lvalue.clone(), + }); + } + InstructionValue::TypeCastExpression { + value: tc_value, .. + } => { + effects.push(AliasingEffect::Assign { + from: tc_value.clone(), + into: lvalue.clone(), + }); + } + InstructionValue::LoadGlobal { .. } => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Global, + reason: ValueReason::Global, + }); + } + InstructionValue::StartMemoize { .. } | InstructionValue::FinishMemoize { .. } => { + if env.config.enable_preserve_existing_memoization_guarantees { + for operand in visitors::each_instruction_value_operand(value, env) { + effects.push(AliasingEffect::Freeze { + value: operand.clone(), + reason: ValueReason::HookCaptured, + }); + } + } + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } + // All primitive-creating instructions + InstructionValue::TaggedTemplateExpression { .. } + | InstructionValue::BinaryExpression { .. } + | InstructionValue::Debugger { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::Primitive { .. } + | InstructionValue::RegExpLiteral { .. } + | InstructionValue::TemplateLiteral { .. } + | InstructionValue::UnaryExpression { .. } + | InstructionValue::UnsupportedNode { .. 
} => { + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: ValueKind::Primitive, + reason: ValueReason::Other, + }); + } + } + + InstructionSignature { effects } +} + +// ============================================================================= +// Legacy signature support +// ============================================================================= + +fn compute_effects_for_legacy_signature( + state: &InferenceState, + signature: &FunctionSignature, + lvalue: &Place, + receiver: &Place, + args: &[PlaceOrSpreadOrHole], + _loc: Option<&SourceLocation>, + env: &Environment, + function_values: &HashMap, + todo_errors: &mut Vec, +) -> Vec { + let return_value_reason = signature.return_value_reason.unwrap_or(ValueReason::Other); + let mut effects: Vec = Vec::new(); + + effects.push(AliasingEffect::Create { + into: lvalue.clone(), + value: signature.return_value_kind, + reason: return_value_reason, + }); + + if signature.impure && env.config.validate_no_impure_functions_in_render { + let mut diagnostic = CompilerDiagnostic::new( + ErrorCategory::Purity, + "Cannot call impure function during render", + Some(format!( + "{}Calling an impure function can produce unstable results that update unpredictably when the component happens to re-render. (https://react.dev/reference/rules/components-and-hooks-must-be-pure#components-and-hooks-must-be-idempotent)", + if let Some(ref name) = signature.canonical_name { + format!("`{}` is an impure function. ", name) + } else { + String::new() + } + )), + ); + diagnostic.details.push( + react_compiler_diagnostics::CompilerDiagnosticDetail::Error { + loc: _loc.copied(), + message: Some("Cannot call impure function".to_string()), + identifier_name: None, + }, + ); + effects.push(AliasingEffect::Impure { + place: receiver.clone(), + error: diagnostic, + }); + } + + // TODO: check signature.known_incompatible and throw (TS line 2351-2370) + // This requires threading Result through apply_effect/apply_signature. 
+ + // If the function is mutable only if operands are mutable, and all + // arguments are immutable/non-mutating, short-circuit with simple aliasing. + if signature.mutable_only_if_operands_are_mutable + && are_arguments_immutable_and_non_mutating(state, args, env, function_values) + { + effects.push(AliasingEffect::Alias { + from: receiver.clone(), + into: lvalue.clone(), + }); + for arg in args { + match arg { + PlaceOrSpreadOrHole::Hole => continue, + PlaceOrSpreadOrHole::Place(place) + | PlaceOrSpreadOrHole::Spread(react_compiler_hir::SpreadPattern { place }) => { + effects.push(AliasingEffect::ImmutableCapture { + from: place.clone(), + into: lvalue.clone(), + }); + } + } + } + return effects; + } + + let mut stores: Vec = Vec::new(); + let mut captures: Vec = Vec::new(); + + let mut visit = |place: &Place, effect: Effect, effects: &mut Vec| match effect + { + Effect::Store => { + effects.push(AliasingEffect::Mutate { + value: place.clone(), + reason: None, + }); + stores.push(place.clone()); + } + Effect::Capture => { + captures.push(place.clone()); + } + Effect::ConditionallyMutate => { + effects.push(AliasingEffect::MutateTransitiveConditionally { + value: place.clone(), + }); + } + Effect::ConditionallyMutateIterator => { + let ty = &env.types[env.identifiers[place.identifier.0 as usize].type_.0 as usize]; + if let Some(mutate_iter) = conditionally_mutate_iterator(place, ty) { + effects.push(mutate_iter); + } + effects.push(AliasingEffect::Capture { + from: place.clone(), + into: lvalue.clone(), + }); + } + Effect::Freeze => { + effects.push(AliasingEffect::Freeze { + value: place.clone(), + reason: return_value_reason, + }); + } + Effect::Mutate => { + effects.push(AliasingEffect::MutateTransitive { + value: place.clone(), + }); + } + Effect::Read => { + effects.push(AliasingEffect::ImmutableCapture { + from: place.clone(), + into: lvalue.clone(), + }); + } + _ => {} + }; + + if signature.callee_effect != Effect::Capture { + 
effects.push(AliasingEffect::Alias { + from: receiver.clone(), + into: lvalue.clone(), + }); + } + + visit(receiver, signature.callee_effect, &mut effects); + for (i, arg) in args.iter().enumerate() { + match arg { + PlaceOrSpreadOrHole::Hole => continue, + PlaceOrSpreadOrHole::Place(place) + | PlaceOrSpreadOrHole::Spread(react_compiler_hir::SpreadPattern { place }) => { + let is_spread = matches!(arg, PlaceOrSpreadOrHole::Spread(_)); + let sig_effect = if !is_spread && i < signature.positional_params.len() { + signature.positional_params[i] + } else { + signature.rest_param.unwrap_or(Effect::ConditionallyMutate) + }; + let (effect, err_detail) = get_argument_effect(sig_effect, is_spread, place.loc); + if let Some(d) = err_detail { + todo_errors.push(d); + } + visit(place, effect, &mut effects); + } + } + } + + if !captures.is_empty() { + if stores.is_empty() { + for capture in &captures { + effects.push(AliasingEffect::Alias { + from: capture.clone(), + into: lvalue.clone(), + }); + } + } else { + for capture in &captures { + for store in &stores { + effects.push(AliasingEffect::Capture { + from: capture.clone(), + into: store.clone(), + }); + } + } + } + } + + effects +} + +fn get_argument_effect( + sig_effect: Effect, + is_spread: bool, + spread_loc: Option, +) -> ( + Effect, + Option, +) { + if !is_spread { + (sig_effect, None) + } else if sig_effect == Effect::Mutate || sig_effect == Effect::ConditionallyMutate { + (sig_effect, None) + } else { + // Spread with Freeze effect is unsupported for hook arguments + // (matches TS CompilerError.throwTodo) + let detail = if sig_effect == Effect::Freeze { + Some(react_compiler_diagnostics::CompilerErrorDetail { + reason: "Support spread syntax for hook arguments".to_string(), + description: None, + category: ErrorCategory::Todo, + loc: spread_loc, + suggestions: None, + }) + } else { + None + }; + (Effect::ConditionallyMutateIterator, detail) + } +} + +/// Returns true if all of the arguments are both non-mutable 
(immutable or +/// frozen) _and_ are not functions which might mutate their arguments. +/// +/// Corresponds to TS `areArgumentsImmutableAndNonMutating`. +fn are_arguments_immutable_and_non_mutating( + state: &InferenceState, + args: &[PlaceOrSpreadOrHole], + env: &Environment, + function_values: &HashMap, +) -> bool { + for arg in args { + match arg { + PlaceOrSpreadOrHole::Hole => continue, + PlaceOrSpreadOrHole::Place(place) + | PlaceOrSpreadOrHole::Spread(react_compiler_hir::SpreadPattern { place }) => { + // Check if it's a function type with a known signature + let is_place = matches!(arg, PlaceOrSpreadOrHole::Place(_)); + if is_place { + let ty = + &env.types[env.identifiers[place.identifier.0 as usize].type_.0 as usize]; + if let Type::Function { .. } = ty { + let fn_shape = env.get_function_signature(ty).ok().flatten(); + if let Some(fn_sig) = fn_shape { + let has_mutable_param = fn_sig + .positional_params + .iter() + .any(|e| is_known_mutable_effect(*e)); + let has_mutable_rest = fn_sig + .rest_param + .map_or(false, |e| is_known_mutable_effect(e)); + return !has_mutable_param && !has_mutable_rest; + } + } + } + + let kind = state.kind(place.identifier); + match kind.kind { + ValueKind::Primitive | ValueKind::Frozen => { + // Immutable values are ok, continue checking + } + _ => { + return false; + } + } + + // Check if any value for this place is a function expression + // that mutates its parameters (TS lines 2545-2557) + let value_ids = state.values_for(place.identifier); + for vid in &value_ids { + if let Some(&func_id) = function_values.get(vid) { + let inner_func = &env.functions[func_id.0 as usize]; + let mutates_params = inner_func.params.iter().any(|param| { + let param_id = match param { + ParamPattern::Place(p) => p.identifier, + ParamPattern::Spread(s) => s.place.identifier, + }; + let ident = &env.identifiers[param_id.0 as usize]; + ident.mutable_range.end.0 > ident.mutable_range.start.0 + 1 + }); + if mutates_params { + return false; + } + 
} + } + } + } + } + true +} + +fn is_known_mutable_effect(effect: Effect) -> bool { + matches!( + effect, + Effect::Store + | Effect::Mutate + | Effect::ConditionallyMutate + | Effect::ConditionallyMutateIterator + ) +} + +// ============================================================================= +// Aliasing signature config support (new-style signatures) +// ============================================================================= + +fn compute_effects_for_aliasing_signature_config( + env: &mut Environment, + config: &react_compiler_hir::type_config::AliasingSignatureConfig, + lvalue: &Place, + receiver: &Place, + args: &[PlaceOrSpreadOrHole], + context: &[Place], + _loc: Option<&SourceLocation>, + temp_cache: &mut HashMap<(IdentifierId, String), Place>, +) -> Result>, CompilerDiagnostic> { + // Build substitutions from config strings to places + let mut substitutions: HashMap> = HashMap::new(); + substitutions.insert(config.receiver.clone(), vec![receiver.clone()]); + substitutions.insert(config.returns.clone(), vec![lvalue.clone()]); + + let mut mutable_spreads: HashSet = HashSet::new(); + + for (i, arg) in args.iter().enumerate() { + match arg { + PlaceOrSpreadOrHole::Hole => continue, + PlaceOrSpreadOrHole::Place(place) + | PlaceOrSpreadOrHole::Spread(react_compiler_hir::SpreadPattern { place }) => { + if i < config.params.len() && !matches!(arg, PlaceOrSpreadOrHole::Spread(_)) { + substitutions.insert(config.params[i].clone(), vec![place.clone()]); + } else if let Some(ref rest) = config.rest { + substitutions + .entry(rest.clone()) + .or_default() + .push(place.clone()); + } else { + return Ok(None); + } + + if matches!(arg, PlaceOrSpreadOrHole::Spread(_)) { + let ty = + &env.types[env.identifiers[place.identifier.0 as usize].type_.0 as usize]; + let mutate_iterator = conditionally_mutate_iterator(place, ty); + if mutate_iterator.is_some() { + mutable_spreads.insert(place.identifier); + } + } + } + } + } + + for operand in context { + let ident = 
&env.identifiers[operand.identifier.0 as usize]; + if let Some(ref name) = ident.name { + substitutions.insert(format!("@{}", name.value()), vec![operand.clone()]); + } + } + + // Create temporaries (cached by lvalue + temp_name to be stable across fixpoint + // iterations) + for temp_name in &config.temporaries { + let cache_key = (lvalue.identifier, temp_name.clone()); + let temp_place = temp_cache + .entry(cache_key) + .or_insert_with(|| create_temp_place(env, receiver.loc)) + .clone(); + substitutions.insert(temp_name.clone(), vec![temp_place]); + } + + let mut effects: Vec = Vec::new(); + + for eff_config in &config.effects { + match eff_config { + react_compiler_hir::type_config::AliasingEffectConfig::Freeze { value, reason } => { + let values = substitutions.get(value).cloned().unwrap_or_default(); + for v in values { + if mutable_spreads.contains(&v.identifier) { + return Err(CompilerDiagnostic::todo( + "Support spread syntax for hook arguments", + v.loc, + )); + } + effects.push(AliasingEffect::Freeze { value: v, reason: *reason }); + } + } + react_compiler_hir::type_config::AliasingEffectConfig::Create { into, value, reason } => { + let intos = substitutions.get(into).cloned().unwrap_or_default(); + for v in intos { + effects.push(AliasingEffect::Create { into: v, value: *value, reason: *reason }); + } + } + react_compiler_hir::type_config::AliasingEffectConfig::CreateFrom { from, into } => { + let froms = substitutions.get(from).cloned().unwrap_or_default(); + let intos = substitutions.get(into).cloned().unwrap_or_default(); + for f in &froms { + for t in &intos { + effects.push(AliasingEffect::CreateFrom { from: f.clone(), into: t.clone() }); + } + } + } + react_compiler_hir::type_config::AliasingEffectConfig::Assign { from, into } => { + let froms = substitutions.get(from).cloned().unwrap_or_default(); + let intos = substitutions.get(into).cloned().unwrap_or_default(); + for f in &froms { + for t in &intos { + effects.push(AliasingEffect::Assign { 
from: f.clone(), into: t.clone() }); + } + } + } + react_compiler_hir::type_config::AliasingEffectConfig::Alias { from, into } => { + let froms = substitutions.get(from).cloned().unwrap_or_default(); + let intos = substitutions.get(into).cloned().unwrap_or_default(); + for f in &froms { + for t in &intos { + effects.push(AliasingEffect::Alias { from: f.clone(), into: t.clone() }); + } + } + } + react_compiler_hir::type_config::AliasingEffectConfig::Capture { from, into } => { + let froms = substitutions.get(from).cloned().unwrap_or_default(); + let intos = substitutions.get(into).cloned().unwrap_or_default(); + for f in &froms { + for t in &intos { + effects.push(AliasingEffect::Capture { from: f.clone(), into: t.clone() }); + } + } + } + react_compiler_hir::type_config::AliasingEffectConfig::ImmutableCapture { from, into } => { + let froms = substitutions.get(from).cloned().unwrap_or_default(); + let intos = substitutions.get(into).cloned().unwrap_or_default(); + for f in &froms { + for t in &intos { + effects.push(AliasingEffect::ImmutableCapture { from: f.clone(), into: t.clone() }); + } + } + } + react_compiler_hir::type_config::AliasingEffectConfig::Impure { place } => { + let values = substitutions.get(place).cloned().unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::Impure { + place: v, + error: CompilerDiagnostic::new(ErrorCategory::Purity, "Impure function call", None), + }); + } + } + react_compiler_hir::type_config::AliasingEffectConfig::Mutate { value } => { + let values = substitutions.get(value).cloned().unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::Mutate { value: v, reason: None }); + } + } + react_compiler_hir::type_config::AliasingEffectConfig::MutateTransitiveConditionally { value } => { + let values = substitutions.get(value).cloned().unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::MutateTransitiveConditionally { value: v }); + } + } + 
react_compiler_hir::type_config::AliasingEffectConfig::Apply { receiver: r, function: f, mutates_function, args: a, into: i } => { + let recv = substitutions.get(r).and_then(|v| v.first()).cloned(); + let func = substitutions.get(f).and_then(|v| v.first()).cloned(); + let into = substitutions.get(i).and_then(|v| v.first()).cloned(); + if let (Some(recv), Some(func), Some(into)) = (recv, func, into) { + let mut apply_args: Vec = Vec::new(); + for arg in a { + match arg { + react_compiler_hir::type_config::ApplyArgConfig::Hole { .. } => { + apply_args.push(PlaceOrSpreadOrHole::Hole); + } + react_compiler_hir::type_config::ApplyArgConfig::Place(name) => { + if let Some(places) = substitutions.get(name) { + if let Some(p) = places.first() { + apply_args.push(PlaceOrSpreadOrHole::Place(p.clone())); + } + } + } + react_compiler_hir::type_config::ApplyArgConfig::Spread { place: name, .. } => { + if let Some(places) = substitutions.get(name) { + if let Some(p) = places.first() { + apply_args.push(PlaceOrSpreadOrHole::Spread(react_compiler_hir::SpreadPattern { place: p.clone() })); + } + } + } + } + } + effects.push(AliasingEffect::Apply { + receiver: recv, + function: func, + mutates_function: *mutates_function, + args: apply_args, + into, + signature: None, + loc: _loc.copied(), + }); + } else { + return Ok(None); + } + } + } + } + + Ok(Some(effects)) +} + +// ============================================================================= +// Function expression signature building +// ============================================================================= + +/// Build an AliasingSignature from a function expression's +/// params/returns/aliasing effects. Corresponds to TS +/// `buildSignatureFromFunctionExpression`. 
+fn build_signature_from_function_expression( + env: &mut Environment, + func_id: FunctionId, +) -> AliasingSignature { + let inner_func = &env.functions[func_id.0 as usize]; + let mut params: Vec = Vec::new(); + let mut rest: Option = None; + for param in &inner_func.params { + match param { + ParamPattern::Place(p) => params.push(p.identifier), + ParamPattern::Spread(s) => rest = Some(s.place.identifier), + } + } + let returns = inner_func.returns.identifier; + let aliasing_effects = inner_func.aliasing_effects.clone().unwrap_or_default(); + let loc = inner_func.loc; + + if rest.is_none() { + let temp = create_temp_place(env, loc); + rest = Some(temp.identifier); + } + + AliasingSignature { + receiver: IdentifierId(0), + params, + rest, + returns, + effects: aliasing_effects, + temporaries: Vec::new(), + } +} + +/// Compute effects by substituting an AliasingSignature (IdentifierId-based) +/// with actual arguments. Corresponds to TS `computeEffectsForSignature`. +fn compute_effects_for_aliasing_signature( + env: &mut Environment, + signature: &AliasingSignature, + lvalue: &Place, + receiver: &Place, + args: &[PlaceOrSpreadOrHole], + context: &[Place], + _loc: Option<&SourceLocation>, +) -> Result>, CompilerDiagnostic> { + if signature.params.len() > args.len() + || (args.len() > signature.params.len() && signature.rest.is_none()) + { + return Ok(None); + } + + let mut mutable_spreads: HashSet = HashSet::new(); + let mut substitutions: HashMap> = HashMap::new(); + substitutions.insert(signature.receiver, vec![receiver.clone()]); + substitutions.insert(signature.returns, vec![lvalue.clone()]); + + for (i, arg) in args.iter().enumerate() { + match arg { + PlaceOrSpreadOrHole::Hole => continue, + PlaceOrSpreadOrHole::Place(place) + | PlaceOrSpreadOrHole::Spread(react_compiler_hir::SpreadPattern { place }) => { + let is_spread = matches!(arg, PlaceOrSpreadOrHole::Spread(_)); + if !is_spread && i < signature.params.len() { + substitutions.insert(signature.params[i], 
vec![place.clone()]); + } else if let Some(rest_id) = signature.rest { + substitutions + .entry(rest_id) + .or_default() + .push(place.clone()); + } else { + return Ok(None); + } + + if is_spread { + let ty = + &env.types[env.identifiers[place.identifier.0 as usize].type_.0 as usize]; + let mutate_iterator = conditionally_mutate_iterator(place, ty); + if mutate_iterator.is_some() { + mutable_spreads.insert(place.identifier); + } + } + } + } + } + + // Add context variable substitutions (identity mapping) + for operand in context { + substitutions.insert(operand.identifier, vec![operand.clone()]); + } + + // Create temporaries + for temp in &signature.temporaries { + let temp_place = create_temp_place(env, receiver.loc); + substitutions.insert(temp.identifier, vec![temp_place]); + } + + let mut effects: Vec = Vec::new(); + + for eff in &signature.effects { + match eff { + AliasingEffect::MaybeAlias { from, into } + | AliasingEffect::Assign { from, into } + | AliasingEffect::ImmutableCapture { from, into } + | AliasingEffect::Alias { from, into } + | AliasingEffect::CreateFrom { from, into } + | AliasingEffect::Capture { from, into } => { + let from_places = substitutions + .get(&from.identifier) + .cloned() + .unwrap_or_default(); + let to_places = substitutions + .get(&into.identifier) + .cloned() + .unwrap_or_default(); + for f in &from_places { + for t in &to_places { + effects.push(match eff { + AliasingEffect::MaybeAlias { .. } => AliasingEffect::MaybeAlias { + from: f.clone(), + into: t.clone(), + }, + AliasingEffect::Assign { .. } => AliasingEffect::Assign { + from: f.clone(), + into: t.clone(), + }, + AliasingEffect::ImmutableCapture { .. } => { + AliasingEffect::ImmutableCapture { + from: f.clone(), + into: t.clone(), + } + } + AliasingEffect::Alias { .. } => AliasingEffect::Alias { + from: f.clone(), + into: t.clone(), + }, + AliasingEffect::CreateFrom { .. 
} => AliasingEffect::CreateFrom { + from: f.clone(), + into: t.clone(), + }, + AliasingEffect::Capture { .. } => AliasingEffect::Capture { + from: f.clone(), + into: t.clone(), + }, + _ => unreachable!(), + }); + } + } + } + AliasingEffect::Impure { place, error } => { + let values = substitutions + .get(&place.identifier) + .cloned() + .unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::Impure { + place: v, + error: error.clone(), + }); + } + } + AliasingEffect::MutateFrozen { place, error } => { + let values = substitutions + .get(&place.identifier) + .cloned() + .unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::MutateFrozen { + place: v, + error: error.clone(), + }); + } + } + AliasingEffect::MutateGlobal { place, error } => { + let values = substitutions + .get(&place.identifier) + .cloned() + .unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::MutateGlobal { + place: v, + error: error.clone(), + }); + } + } + AliasingEffect::Render { place } => { + let values = substitutions + .get(&place.identifier) + .cloned() + .unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::Render { place: v }); + } + } + AliasingEffect::Mutate { value, reason } => { + let values = substitutions + .get(&value.identifier) + .cloned() + .unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::Mutate { + value: v, + reason: reason.clone(), + }); + } + } + AliasingEffect::MutateConditionally { value } => { + let values = substitutions + .get(&value.identifier) + .cloned() + .unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::MutateConditionally { value: v }); + } + } + AliasingEffect::MutateTransitive { value } => { + let values = substitutions + .get(&value.identifier) + .cloned() + .unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::MutateTransitive { value: v }); + } + } + AliasingEffect::MutateTransitiveConditionally { value } => { + let values 
= substitutions + .get(&value.identifier) + .cloned() + .unwrap_or_default(); + for v in values { + effects.push(AliasingEffect::MutateTransitiveConditionally { value: v }); + } + } + AliasingEffect::Freeze { value, reason } => { + let values = substitutions + .get(&value.identifier) + .cloned() + .unwrap_or_default(); + for v in values { + if mutable_spreads.contains(&v.identifier) { + return Err(CompilerDiagnostic::todo( + "Support spread syntax for hook arguments", + v.loc, + )); + } + effects.push(AliasingEffect::Freeze { + value: v, + reason: *reason, + }); + } + } + AliasingEffect::Create { + into, + value, + reason, + } => { + let intos = substitutions + .get(&into.identifier) + .cloned() + .unwrap_or_default(); + for v in intos { + effects.push(AliasingEffect::Create { + into: v, + value: *value, + reason: *reason, + }); + } + } + AliasingEffect::Apply { + receiver: r, + function: f, + mutates_function: mf, + args: a, + into: i, + signature: s, + loc: _l, + } => { + let recv = substitutions + .get(&r.identifier) + .and_then(|v| v.first()) + .cloned(); + let func = substitutions + .get(&f.identifier) + .and_then(|v| v.first()) + .cloned(); + let apply_into = substitutions + .get(&i.identifier) + .and_then(|v| v.first()) + .cloned(); + if let (Some(recv), Some(func), Some(apply_into)) = (recv, func, apply_into) { + let mut apply_args: Vec = Vec::new(); + for arg in a { + match arg { + PlaceOrSpreadOrHole::Hole => apply_args.push(PlaceOrSpreadOrHole::Hole), + PlaceOrSpreadOrHole::Place(p) => { + if let Some(places) = substitutions.get(&p.identifier) { + if let Some(place) = places.first() { + apply_args.push(PlaceOrSpreadOrHole::Place(place.clone())); + } + } + } + PlaceOrSpreadOrHole::Spread(sp) => { + if let Some(places) = substitutions.get(&sp.place.identifier) { + if let Some(place) = places.first() { + apply_args.push(PlaceOrSpreadOrHole::Spread( + react_compiler_hir::SpreadPattern { + place: place.clone(), + }, + )); + } + } + } + } + } + 
effects.push(AliasingEffect::Apply { + receiver: recv, + function: func, + mutates_function: *mf, + args: apply_args, + into: apply_into, + signature: s.clone(), + loc: _loc.copied(), + }); + } else { + return Ok(None); + } + } + AliasingEffect::CreateFunction { .. } => { + // Not supported in signature substitution + return Ok(None); + } + } + } + + Ok(Some(effects)) +} + +// ============================================================================= +// Helpers +// ============================================================================= + +/// Select the primary (most specific) reason from a set of reasons. +/// TS uses `[...set][0]` which returns the first-inserted element; +/// since the primary reason is always inserted first, this effectively +/// picks the most specific non-Other reason. We replicate this by +/// preferring any non-Other reason over Other. +fn primary_reason(reasons: &HashSet) -> ValueReason { + for &r in reasons { + if r != ValueReason::Other { + return r; + } + } + ValueReason::Other +} + +fn get_write_error_reason(abstract_value: &AbstractValue) -> String { + if abstract_value.reason.contains(&ValueReason::Global) { + "Modifying a variable defined outside a component or hook is not allowed. Consider using \ + an effect" + .to_string() + } else if abstract_value.reason.contains(&ValueReason::JsxCaptured) { + "Modifying a value used previously in JSX is not allowed. Consider moving the modification \ + before the JSX" + .to_string() + } else if abstract_value.reason.contains(&ValueReason::Context) { + "Modifying a value returned from 'useContext()' is not allowed.".to_string() + } else if abstract_value + .reason + .contains(&ValueReason::KnownReturnSignature) + { + "Modifying a value returned from a function whose return value should not be mutated" + .to_string() + } else if abstract_value + .reason + .contains(&ValueReason::ReactiveFunctionArgument) + { + "Modifying component props or hook arguments is not allowed. 
Consider using a local \ + variable instead" + .to_string() + } else if abstract_value.reason.contains(&ValueReason::State) { + "Modifying a value returned from 'useState()', which should not be modified directly. Use \ + the setter function to update instead" + .to_string() + } else if abstract_value.reason.contains(&ValueReason::ReducerState) { + "Modifying a value returned from 'useReducer()', which should not be modified directly. \ + Use the dispatch function to update instead" + .to_string() + } else if abstract_value.reason.contains(&ValueReason::Effect) { + "Modifying a value used previously in an effect function or as an effect dependency is not \ + allowed. Consider moving the modification before calling useEffect()" + .to_string() + } else if abstract_value.reason.contains(&ValueReason::HookCaptured) { + "Modifying a value previously passed as an argument to a hook is not allowed. Consider \ + moving the modification before calling the hook" + .to_string() + } else if abstract_value.reason.contains(&ValueReason::HookReturn) { + "Modifying a value returned from a hook is not allowed. 
Consider moving the modification \ + into the hook where the value is constructed" + .to_string() + } else { + "This modifies a variable that React considers immutable".to_string() + } +} + +fn conditionally_mutate_iterator(place: &Place, ty: &Type) -> Option { + if !is_builtin_collection_type(ty) { + Some(AliasingEffect::MutateTransitiveConditionally { + value: place.clone(), + }) + } else { + None + } +} + +fn is_builtin_collection_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } + if id == BUILT_IN_ARRAY_ID || id == BUILT_IN_SET_ID || id == BUILT_IN_MAP_ID + ) +} + +fn get_function_call_signature( + env: &Environment, + callee_id: IdentifierId, +) -> Result, CompilerDiagnostic> { + let ty = &env.types[env.identifiers[callee_id.0 as usize].type_.0 as usize]; + Ok(env.get_function_signature(ty)?.cloned()) +} + +fn is_ref_or_ref_value_for_id(env: &Environment, id: IdentifierId) -> bool { + let ty = &env.types[env.identifiers[id.0 as usize].type_.0 as usize]; + react_compiler_hir::is_ref_or_ref_value(ty) +} + +fn get_hook_kind_for_type<'a>( + env: &'a Environment, + ty: &Type, +) -> Result, CompilerDiagnostic> { + env.get_hook_kind_for_type(ty) +} + +/// Format a Type for printPlace-style output, matching TS's `printType()`. +fn format_type_for_print(ty: &Type) -> String { + match ty { + Type::Primitive => String::new(), + Type::Function { + shape_id, + return_type, + .. + } => { + if let Some(sid) = shape_id { + let ret = format_type_for_print(return_type); + if ret.is_empty() { + format!(":TFunction<{}>()", sid) + } else { + format!(":TFunction<{}>(): {}", sid, ret) + } + } else { + ":TFunction".to_string() + } + } + Type::Object { shape_id } => { + if let Some(sid) = shape_id { + format!(":TObject<{}>", sid) + } else { + ":TObject".to_string() + } + } + Type::Poly => ":TPoly".to_string(), + Type::Phi { .. } => ":TPhi".to_string(), + Type::Property { .. } => ":TProperty".to_string(), + Type::TypeVar { .. 
} => String::new(), + Type::ObjectMethod => ":TObjectMethod".to_string(), + } +} + +fn is_phi_with_jsx(ty: &Type) -> bool { + if let Type::Phi { operands } = ty { + operands + .iter() + .any(|op| react_compiler_hir::is_jsx_type(op)) + } else { + false + } +} + +fn place_or_spread_to_hole(pos: &PlaceOrSpread) -> PlaceOrSpreadOrHole { + match pos { + PlaceOrSpread::Place(p) => PlaceOrSpreadOrHole::Place(p.clone()), + PlaceOrSpread::Spread(s) => PlaceOrSpreadOrHole::Spread(s.clone()), + } +} + +use react_compiler_hir::JsxTag; + +fn build_apply_operands( + receiver: &Place, + function: &Place, + args: &[PlaceOrSpreadOrHole], +) -> Vec<(Place, bool, bool)> { + let mut result = vec![ + (receiver.clone(), false, false), + (function.clone(), true, false), + ]; + for arg in args { + match arg { + PlaceOrSpreadOrHole::Hole => continue, + PlaceOrSpreadOrHole::Place(p) => result.push((p.clone(), false, false)), + PlaceOrSpreadOrHole::Spread(s) => result.push((s.place.clone(), false, true)), + } + } + result +} + +fn create_temp_place(env: &mut Environment, loc: Option) -> Place { + let id = env.next_identifier_id(); + env.identifiers[id.0 as usize].loc = loc; + Place { + identifier: id, + effect: Effect::Unknown, + reactive: false, + loc, + } +} + +// ============================================================================= +// Terminal successor helper +// ============================================================================= + +/// Returns the successor blocks used for traversal in mutation/aliasing +/// inference. +/// +/// Matches the TS `eachTerminalSuccessor` which yields standard control-flow +/// successors but NOT pseudo-successors (fallthroughs). Fallthroughs for +/// Logical/Ternary/Optional and Try/Scope/PrunedScope are reached naturally +/// via the block iteration order (blocks are stored in topological order). 
+fn terminal_successors(terminal: &react_compiler_hir::Terminal) -> Vec { + use react_compiler_hir::Terminal; + match terminal { + Terminal::Goto { block, .. } => vec![*block], + Terminal::If { + consequent, + alternate, + .. + } => vec![*consequent, *alternate], + Terminal::Branch { + consequent, + alternate, + .. + } => vec![*consequent, *alternate], + Terminal::Switch { cases, .. } => cases.iter().map(|c| c.block).collect(), + Terminal::For { init, .. } => vec![*init], + Terminal::ForOf { init, .. } | Terminal::ForIn { init, .. } => vec![*init], + Terminal::DoWhile { loop_block, .. } => vec![*loop_block], + Terminal::While { test, .. } => vec![*test], + Terminal::Return { .. } + | Terminal::Throw { .. } + | Terminal::Unreachable { .. } + | Terminal::Unsupported { .. } => vec![], + Terminal::Try { block, .. } => vec![*block], + Terminal::MaybeThrow { + continuation, + handler, + .. + } => { + let mut v = vec![*continuation]; + if let Some(h) = handler { + v.push(*h); + } + v + } + Terminal::Label { block, .. } | Terminal::Sequence { block, .. } => vec![*block], + Terminal::Logical { test, .. } | Terminal::Ternary { test, .. } => vec![*test], + Terminal::Optional { test, .. } => vec![*test], + Terminal::Scope { block, .. } | Terminal::PrunedScope { block, .. } => vec![*block], + } +} + +/// Pattern item helper for Destructure. +/// +/// NOTE: This cannot use `visitors::each_pattern_operand` because callers need +/// to distinguish Place from Spread elements — Spread elements get different +/// aliasing effects (Create + Capture) vs Place elements (Create or +/// CreateFrom). 
+enum PatternItem<'a> { + Place(&'a Place), + Spread(&'a Place), +} + +fn each_pattern_items(pattern: &react_compiler_hir::Pattern) -> Vec> { + let mut items = Vec::new(); + match pattern { + react_compiler_hir::Pattern::Array(arr) => { + for el in &arr.items { + match el { + react_compiler_hir::ArrayPatternElement::Place(p) => { + items.push(PatternItem::Place(p)) + } + react_compiler_hir::ArrayPatternElement::Spread(s) => { + items.push(PatternItem::Spread(&s.place)) + } + react_compiler_hir::ArrayPatternElement::Hole => {} + } + } + } + react_compiler_hir::Pattern::Object(obj) => { + for prop in &obj.properties { + match prop { + react_compiler_hir::ObjectPropertyOrSpread::Property(p) => { + items.push(PatternItem::Place(&p.place)) + } + react_compiler_hir::ObjectPropertyOrSpread::Spread(s) => { + items.push(PatternItem::Spread(&s.place)) + } + } + } + } + } + items +} diff --git a/crates/react_compiler_inference/src/infer_mutation_aliasing_ranges.rs b/crates/react_compiler_inference/src/infer_mutation_aliasing_ranges.rs new file mode 100644 index 000000000000..3fb9fc315d15 --- /dev/null +++ b/crates/react_compiler_inference/src/infer_mutation_aliasing_ranges.rs @@ -0,0 +1,1184 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Infers mutable ranges for identifiers and populates Place effects. +//! +//! Ported from TypeScript `src/Inference/InferMutationAliasingRanges.ts`. +//! +//! This pass builds an abstract model of the heap and interprets the effects of +//! the given function in order to determine: +//! - The mutable ranges of all identifiers in the function +//! - The externally-visible effects of the function (mutations of +//! params/context vars, aliasing between params/context-vars/return-value) +//! 
- The legacy `Effect` to store on each Place + +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, + is_jsx_type, is_primitive_type, + type_config::{ValueKind, ValueReason}, + visitors::{ + each_instruction_value_lvalue, for_each_instruction_value_lvalue_mut, + for_each_instruction_value_operand_mut, for_each_terminal_operand_mut, + }, + AliasingEffect, BlockId, Effect, EvaluationOrder, FunctionId, HirFunction, IdentifierId, + InstructionValue, MutationReason, Place, SourceLocation, +}; + +// ============================================================================= +// MutationKind +// ============================================================================= + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[allow(dead_code)] +enum MutationKind { + None = 0, + Conditional = 1, + Definite = 2, +} + +// ============================================================================= +// Node and AliasingState +// ============================================================================= + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum EdgeKind { + Capture, + Alias, + MaybeAlias, +} + +#[derive(Debug, Clone)] +struct Edge { + index: usize, + node: IdentifierId, + kind: EdgeKind, +} + +#[derive(Debug, Clone)] +struct MutationInfo { + kind: MutationKind, + loc: Option, +} + +#[derive(Debug, Clone)] +enum NodeValue { + Object, + Phi, + Function { function_id: FunctionId }, +} + +#[derive(Debug, Clone)] +struct Node { + id: IdentifierId, + created_from: HashMap, + captures: HashMap, + aliases: HashMap, + maybe_aliases: HashMap, + edges: Vec, + transitive: Option, + local: Option, + last_mutated: usize, + mutation_reason: Option, + value: NodeValue, +} + +impl Node { + fn new(id: IdentifierId, value: NodeValue) -> Self { + Node { + id, + created_from: HashMap::new(), + captures: HashMap::new(), + aliases: HashMap::new(), + 
maybe_aliases: HashMap::new(), + edges: Vec::new(), + transitive: None, + local: None, + last_mutated: 0, + mutation_reason: None, + value, + } + } +} + +struct AliasingState { + nodes: HashMap, +} + +impl AliasingState { + fn new() -> Self { + AliasingState { + nodes: HashMap::new(), + } + } + + fn create(&mut self, place: &Place, value: NodeValue) { + self.nodes + .insert(place.identifier, Node::new(place.identifier, value)); + } + + fn create_from(&mut self, index: usize, from: &Place, into: &Place) { + self.create(into, NodeValue::Object); + let from_id = from.identifier; + let into_id = into.identifier; + // Add forward edge from -> into on the from node + if let Some(from_node) = self.nodes.get_mut(&from_id) { + from_node.edges.push(Edge { + index, + node: into_id, + kind: EdgeKind::Alias, + }); + } + // Add created_from on the into node + if let Some(to_node) = self.nodes.get_mut(&into_id) { + to_node.created_from.entry(from_id).or_insert(index); + } + } + + fn capture(&mut self, index: usize, from: &Place, into: &Place) { + let from_id = from.identifier; + let into_id = into.identifier; + if !self.nodes.contains_key(&from_id) || !self.nodes.contains_key(&into_id) { + return; + } + self.nodes.get_mut(&from_id).unwrap().edges.push(Edge { + index, + node: into_id, + kind: EdgeKind::Capture, + }); + self.nodes + .get_mut(&into_id) + .unwrap() + .captures + .entry(from_id) + .or_insert(index); + } + + fn assign(&mut self, index: usize, from: &Place, into: &Place) { + let from_id = from.identifier; + let into_id = into.identifier; + if !self.nodes.contains_key(&from_id) || !self.nodes.contains_key(&into_id) { + return; + } + self.nodes.get_mut(&from_id).unwrap().edges.push(Edge { + index, + node: into_id, + kind: EdgeKind::Alias, + }); + self.nodes + .get_mut(&into_id) + .unwrap() + .aliases + .entry(from_id) + .or_insert(index); + } + + fn maybe_alias(&mut self, index: usize, from: &Place, into: &Place) { + let from_id = from.identifier; + let into_id = 
into.identifier; + if !self.nodes.contains_key(&from_id) || !self.nodes.contains_key(&into_id) { + return; + } + self.nodes.get_mut(&from_id).unwrap().edges.push(Edge { + index, + node: into_id, + kind: EdgeKind::MaybeAlias, + }); + self.nodes + .get_mut(&into_id) + .unwrap() + .maybe_aliases + .entry(from_id) + .or_insert(index); + } + + fn render(&self, index: usize, start: IdentifierId, env: &mut Environment) { + let mut seen = HashSet::new(); + let mut queue: Vec = vec![start]; + while let Some(current) = queue.pop() { + if !seen.insert(current) { + continue; + } + let node = match self.nodes.get(¤t) { + Some(n) => n, + None => continue, + }; + if node.transitive.is_some() || node.local.is_some() { + continue; + } + if let NodeValue::Function { function_id } = &node.value { + append_function_errors(env, *function_id); + } + for (&alias, &when) in &node.created_from { + if when >= index { + continue; + } + queue.push(alias); + } + for (&alias, &when) in &node.aliases { + if when >= index { + continue; + } + queue.push(alias); + } + for (&capture, &when) in &node.captures { + if when >= index { + continue; + } + queue.push(capture); + } + } + } + + fn mutate( + &mut self, + index: usize, + start: IdentifierId, + end: Option, // None for simulated mutations + transitive: bool, + start_kind: MutationKind, + loc: Option, + reason: Option, + env: &mut Environment, + should_record_errors: bool, + ) { + #[derive(Clone)] + struct QueueEntry { + place: IdentifierId, + transitive: bool, + direction: Direction, + kind: MutationKind, + } + #[derive(Clone, Copy, PartialEq)] + enum Direction { + Backwards, + Forwards, + } + + let mut seen: HashMap = HashMap::new(); + let mut queue: Vec = vec![QueueEntry { + place: start, + transitive, + direction: Direction::Backwards, + kind: start_kind, + }]; + + while let Some(entry) = queue.pop() { + let current = entry.place; + let previous_kind = seen.get(¤t).copied(); + if let Some(prev) = previous_kind { + if prev >= entry.kind { + 
continue; + } + } + seen.insert(current, entry.kind); + + let node = match self.nodes.get_mut(¤t) { + Some(n) => n, + None => continue, + }; + + if node.mutation_reason.is_none() { + node.mutation_reason = reason.clone(); + } + node.last_mutated = node.last_mutated.max(index); + + if let Some(end_val) = end { + let ident = &mut env.identifiers[node.id.0 as usize]; + ident.mutable_range.end = EvaluationOrder(ident.mutable_range.end.0.max(end_val.0)); + } + + if let NodeValue::Function { function_id } = &node.value { + if node.transitive.is_none() && node.local.is_none() { + if should_record_errors { + append_function_errors(env, *function_id); + } + } + } + + if entry.transitive { + match &node.transitive { + None => { + node.transitive = Some(MutationInfo { + kind: entry.kind, + loc, + }); + } + Some(existing) if existing.kind < entry.kind => { + node.transitive = Some(MutationInfo { + kind: entry.kind, + loc, + }); + } + _ => {} + } + } else { + match &node.local { + None => { + node.local = Some(MutationInfo { + kind: entry.kind, + loc, + }); + } + Some(existing) if existing.kind < entry.kind => { + node.local = Some(MutationInfo { + kind: entry.kind, + loc, + }); + } + _ => {} + } + } + + // Forward edges: Capture a -> b, Alias a -> b: mutate(a) => mutate(b) + // Collect edges to avoid borrow conflict + let edges: Vec = node.edges.clone(); + let node_value_kind = match &node.value { + NodeValue::Phi => "Phi", + _ => "Other", + }; + let node_aliases: Vec<(IdentifierId, usize)> = + node.aliases.iter().map(|(&k, &v)| (k, v)).collect(); + let node_maybe_aliases: Vec<(IdentifierId, usize)> = + node.maybe_aliases.iter().map(|(&k, &v)| (k, v)).collect(); + let node_captures: Vec<(IdentifierId, usize)> = + node.captures.iter().map(|(&k, &v)| (k, v)).collect(); + let node_created_from: Vec<(IdentifierId, usize)> = + node.created_from.iter().map(|(&k, &v)| (k, v)).collect(); + + for edge in &edges { + if edge.index >= index { + break; + } + queue.push(QueueEntry { + 
place: edge.node, + transitive: entry.transitive, + direction: Direction::Forwards, + // MaybeAlias edges downgrade to conditional mutation + kind: if edge.kind == EdgeKind::MaybeAlias { + MutationKind::Conditional + } else { + entry.kind + }, + }); + } + + for (alias, when) in &node_created_from { + if *when >= index { + continue; + } + queue.push(QueueEntry { + place: *alias, + transitive: true, + direction: Direction::Backwards, + kind: entry.kind, + }); + } + + if entry.direction == Direction::Backwards || node_value_kind != "Phi" { + // Backward alias edges + for (alias, when) in &node_aliases { + if *when >= index { + continue; + } + queue.push(QueueEntry { + place: *alias, + transitive: entry.transitive, + direction: Direction::Backwards, + kind: entry.kind, + }); + } + // MaybeAlias backward edges (downgrade to conditional) + for (alias, when) in &node_maybe_aliases { + if *when >= index { + continue; + } + queue.push(QueueEntry { + place: *alias, + transitive: entry.transitive, + direction: Direction::Backwards, + kind: MutationKind::Conditional, + }); + } + } + + // Only transitive mutations affect captures backward + if entry.transitive { + for (capture, when) in &node_captures { + if *when >= index { + continue; + } + queue.push(QueueEntry { + place: *capture, + transitive: entry.transitive, + direction: Direction::Backwards, + kind: entry.kind, + }); + } + } + } + } +} + +// ============================================================================= +// Helper: append function errors +// ============================================================================= + +fn append_function_errors(env: &mut Environment, function_id: FunctionId) { + let func = &env.functions[function_id.0 as usize]; + if let Some(ref effects) = func.aliasing_effects { + // Collect errors first to avoid borrow conflict + let errors: Vec<_> = effects + .iter() + .filter_map(|effect| match effect { + AliasingEffect::Impure { error, .. 
} + | AliasingEffect::MutateFrozen { error, .. } + | AliasingEffect::MutateGlobal { error, .. } => Some(error.clone()), + _ => None, + }) + .collect(); + for error in errors { + env.record_diagnostic(error); + } + } +} + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Infers mutable ranges for identifiers and populates Place effects. +/// +/// Returns the externally-visible effects of the function (mutations of +/// params/context-vars, aliasing between params/context-vars/return). +/// +/// Corresponds to TS `inferMutationAliasingRanges(fn, {isFunctionExpression})`. +pub fn infer_mutation_aliasing_ranges( + func: &mut HirFunction, + env: &mut Environment, + is_function_expression: bool, +) -> Result, CompilerDiagnostic> { + let mut function_effects: Vec = Vec::new(); + + // ========================================================================= + // Part 1: Build data flow graph and infer mutable ranges + // ========================================================================= + let mut state = AliasingState::new(); + + struct PendingPhiOperand { + from: Place, + into: Place, + index: usize, + } + let mut pending_phis: HashMap> = HashMap::new(); + + struct PendingMutation { + index: usize, + id: EvaluationOrder, + transitive: bool, + kind: MutationKind, + place: Place, + reason: Option, + } + let mut mutations: Vec = Vec::new(); + + struct PendingRender { + index: usize, + place: Place, + } + let mut renders: Vec = Vec::new(); + + let mut index: usize = 0; + + let should_record_errors = !is_function_expression && env.enable_validations(); + + // Create nodes for params, context vars, and return + for param in &func.params { + let place = match param { + react_compiler_hir::ParamPattern::Place(p) => p, + react_compiler_hir::ParamPattern::Spread(s) => &s.place, + }; + state.create(place, NodeValue::Object); + } + for 
ctx in &func.context { + state.create(ctx, NodeValue::Object); + } + state.create(&func.returns, NodeValue::Object); + + let mut seen_blocks: HashSet = HashSet::new(); + + // Collect block iteration data to avoid borrow conflicts + let block_order: Vec = func.body.blocks.keys().cloned().collect(); + + for &block_id in &block_order { + let block = &func.body.blocks[&block_id]; + + // Process phis + for phi in &block.phis { + state.create(&phi.place, NodeValue::Phi); + for (&pred, operand) in &phi.operands { + if !seen_blocks.contains(&pred) { + pending_phis + .entry(pred) + .or_insert_with(Vec::new) + .push(PendingPhiOperand { + from: operand.clone(), + into: phi.place.clone(), + index, + }); + index += 1; + } else { + state.assign(index, operand, &phi.place); + index += 1; + } + } + } + seen_blocks.insert(block_id); + + // Process instruction effects + let instr_ids: Vec<_> = block.instructions.clone(); + for instr_id in &instr_ids { + let instr = &func.instructions[instr_id.0 as usize]; + let instr_eval_order = instr.id; + let effects = match &instr.effects { + Some(e) => e.clone(), + None => continue, + }; + for effect in &effects { + match effect { + AliasingEffect::Create { into, .. } => { + state.create(into, NodeValue::Object); + } + AliasingEffect::CreateFunction { + into, function_id, .. 
+ } => { + state.create( + into, + NodeValue::Function { + function_id: *function_id, + }, + ); + } + AliasingEffect::CreateFrom { from, into } => { + state.create_from(index, from, into); + index += 1; + } + AliasingEffect::Assign { from, into } => { + if !state.nodes.contains_key(&into.identifier) { + state.create(into, NodeValue::Object); + } + state.assign(index, from, into); + index += 1; + } + AliasingEffect::Alias { from, into } => { + state.assign(index, from, into); + index += 1; + } + AliasingEffect::MaybeAlias { from, into } => { + state.maybe_alias(index, from, into); + index += 1; + } + AliasingEffect::Capture { from, into } => { + state.capture(index, from, into); + index += 1; + } + AliasingEffect::MutateTransitive { value } + | AliasingEffect::MutateTransitiveConditionally { value } => { + let is_transitive_conditional = + matches!(effect, AliasingEffect::MutateTransitiveConditionally { .. }); + mutations.push(PendingMutation { + index, + id: instr_eval_order, + transitive: true, + kind: if is_transitive_conditional { + MutationKind::Conditional + } else { + MutationKind::Definite + }, + reason: None, + place: value.clone(), + }); + index += 1; + } + AliasingEffect::Mutate { value, reason } => { + mutations.push(PendingMutation { + index, + id: instr_eval_order, + transitive: false, + kind: MutationKind::Definite, + reason: reason.clone(), + place: value.clone(), + }); + index += 1; + } + AliasingEffect::MutateConditionally { value } => { + mutations.push(PendingMutation { + index, + id: instr_eval_order, + transitive: false, + kind: MutationKind::Conditional, + reason: None, + place: value.clone(), + }); + index += 1; + } + AliasingEffect::MutateFrozen { .. } + | AliasingEffect::MutateGlobal { .. } + | AliasingEffect::Impure { .. } => { + if should_record_errors { + match effect { + AliasingEffect::MutateFrozen { error, .. } + | AliasingEffect::MutateGlobal { error, .. } + | AliasingEffect::Impure { error, .. 
} => { + env.record_diagnostic(error.clone()); + } + _ => unreachable!(), + } + } + function_effects.push(effect.clone()); + } + AliasingEffect::Render { place } => { + renders.push(PendingRender { + index, + place: place.clone(), + }); + index += 1; + function_effects.push(effect.clone()); + } + // Other effects (Freeze, ImmutableCapture, Apply) are no-ops here + _ => {} + } + } + } + + // Process pending phis for this block + let block = &func.body.blocks[&block_id]; + if let Some(block_phis) = pending_phis.remove(&block_id) { + for pending in block_phis { + state.assign(pending.index, &pending.from, &pending.into); + } + } + + // Handle return terminal + let terminal = &block.terminal; + if let react_compiler_hir::Terminal::Return { value, .. } = terminal { + state.assign(index, value, &func.returns); + index += 1; + } + + // Handle terminal effects (MaybeThrow and Return) + let terminal_effects = match terminal { + react_compiler_hir::Terminal::MaybeThrow { effects, .. } + | react_compiler_hir::Terminal::Return { effects, .. } => effects.clone(), + _ => None, + }; + if let Some(effects) = terminal_effects { + for effect in &effects { + match effect { + AliasingEffect::Alias { from, into } => { + state.assign(index, from, into); + index += 1; + } + AliasingEffect::Freeze { .. } => { + // Expected for MaybeThrow terminals, skip + } + _ => { + // TS: CompilerError.invariant(effect.kind === 'Freeze', + // ...) 
We skip non-Alias, + // non-Freeze effects + } + } + } + } + } + + // Process mutations + for mutation in &mutations { + state.mutate( + mutation.index, + mutation.place.identifier, + Some(EvaluationOrder(mutation.id.0 + 1)), + mutation.transitive, + mutation.kind, + mutation.place.loc, + mutation.reason.clone(), + env, + should_record_errors, + ); + } + + // Process renders + for render in &renders { + if should_record_errors { + state.render(render.index, render.place.identifier, env); + } + } + + // Collect function effects for context vars and params + // NOTE: TS iterates [...fn.context, ...fn.params] — context first, then params + for ctx in &func.context { + collect_param_effects(&state, ctx, &mut function_effects); + } + for param in &func.params { + let place = match param { + react_compiler_hir::ParamPattern::Place(p) => p, + react_compiler_hir::ParamPattern::Spread(s) => &s.place, + }; + collect_param_effects(&state, place, &mut function_effects); + } + + // Set effect on mutated params/context vars + // We need to do this in a separate pass because we need to know which params + // were mutated before setting effects + let mut captured_params: HashSet = HashSet::new(); + for param in &func.params { + let place = match param { + react_compiler_hir::ParamPattern::Place(p) => p, + react_compiler_hir::ParamPattern::Spread(s) => &s.place, + }; + if let Some(node) = state.nodes.get(&place.identifier) { + if node.local.is_some() || node.transitive.is_some() { + captured_params.insert(place.identifier); + } + } + } + for ctx in &func.context { + if let Some(node) = state.nodes.get(&ctx.identifier) { + if node.local.is_some() || node.transitive.is_some() { + captured_params.insert(ctx.identifier); + } + } + } + + // Now mutate the effects on params/context in place + for param in &mut func.params { + let place = match param { + react_compiler_hir::ParamPattern::Place(p) => p, + react_compiler_hir::ParamPattern::Spread(s) => &mut s.place, + }; + if 
captured_params.contains(&place.identifier) { + place.effect = Effect::Capture; + } + } + for ctx in &mut func.context { + if captured_params.contains(&ctx.identifier) { + ctx.effect = Effect::Capture; + } + } + + // ========================================================================= + // Part 2: Add legacy operand-specific effects based on instruction effects + // and mutable ranges. Also fix up mutable range start values. + // ========================================================================= + // Part 2 loop + for &block_id in &block_order { + let block = &func.body.blocks[&block_id]; + + // Process phis + let phi_data: Vec<_> = block + .phis + .iter() + .map(|phi| { + let first_instr_id = block + .instructions + .first() + .map(|id| func.instructions[id.0 as usize].id) + .unwrap_or_else(|| block.terminal.evaluation_order()); + + let is_mutated_after_creation = env.identifiers[phi.place.identifier.0 as usize] + .mutable_range + .end + > first_instr_id; + + ( + phi.place.identifier, + phi.operands + .values() + .map(|o| o.identifier) + .collect::>(), + is_mutated_after_creation, + first_instr_id, + ) + }) + .collect(); + + for (phi_id, _operand_ids, is_mutated_after_creation, first_instr_id) in &phi_data { + // Set phi place effect to Store + // We need to find this phi in the block and set it + let block = func.body.blocks.get_mut(&block_id).unwrap(); + for phi in &mut block.phis { + if phi.place.identifier == *phi_id { + phi.place.effect = Effect::Store; + for operand in phi.operands.values_mut() { + operand.effect = if *is_mutated_after_creation { + Effect::Capture + } else { + Effect::Read + }; + } + break; + } + } + + if *is_mutated_after_creation { + let ident = &mut env.identifiers[phi_id.0 as usize]; + if ident.mutable_range.start == EvaluationOrder(0) { + ident.mutable_range.start = EvaluationOrder(first_instr_id.0.saturating_sub(1)); + } + } + } + + let block = &func.body.blocks[&block_id]; + let instr_ids: Vec<_> = 
block.instructions.clone(); + + for instr_id in &instr_ids { + let instr = &func.instructions[instr_id.0 as usize]; + let eval_order = instr.id; + + // Set lvalue effect to ConditionallyMutate and fix up mutable range + // This covers the top-level lvalue + let lvalue_id = instr.lvalue.identifier; + { + let ident = &mut env.identifiers[lvalue_id.0 as usize]; + if ident.mutable_range.start == EvaluationOrder(0) { + ident.mutable_range.start = eval_order; + } + if ident.mutable_range.end == EvaluationOrder(0) { + ident.mutable_range.end = + EvaluationOrder((eval_order.0 + 1).max(ident.mutable_range.end.0)); + } + } + func.instructions[instr_id.0 as usize].lvalue.effect = Effect::ConditionallyMutate; + + // Also handle value-level lvalues (DeclareLocal, StoreLocal, etc.) + let value_lvalue_ids: Vec = + each_instruction_value_lvalue(&func.instructions[instr_id.0 as usize].value) + .into_iter() + .map(|p| p.identifier) + .collect(); + for vlid in &value_lvalue_ids { + let ident = &mut env.identifiers[vlid.0 as usize]; + if ident.mutable_range.start == EvaluationOrder(0) { + ident.mutable_range.start = eval_order; + } + if ident.mutable_range.end == EvaluationOrder(0) { + ident.mutable_range.end = + EvaluationOrder((eval_order.0 + 1).max(ident.mutable_range.end.0)); + } + } + for_each_instruction_value_lvalue_mut( + &mut func.instructions[instr_id.0 as usize].value, + &mut |place| { + place.effect = Effect::ConditionallyMutate; + }, + ); + + // Set operand effects to Read + for_each_instruction_value_operand_mut( + &mut func.instructions[instr_id.0 as usize].value, + &mut |place| { + place.effect = Effect::Read; + }, + ); + + let instr = &func.instructions[instr_id.0 as usize]; + if instr.effects.is_none() { + continue; + } + + // Compute operand effects from instruction effects + let effects = instr.effects.as_ref().unwrap().clone(); + let mut operand_effects: HashMap = HashMap::new(); + + for effect in &effects { + match effect { + AliasingEffect::Assign { from, into, 
.. } + | AliasingEffect::Alias { from, into } + | AliasingEffect::Capture { from, into } + | AliasingEffect::CreateFrom { from, into } + | AliasingEffect::MaybeAlias { from, into } => { + let is_mutated_or_reassigned = env.identifiers[into.identifier.0 as usize] + .mutable_range + .end + > eval_order; + if is_mutated_or_reassigned { + operand_effects.insert(from.identifier, Effect::Capture); + operand_effects.insert(into.identifier, Effect::Store); + } else { + operand_effects.insert(from.identifier, Effect::Read); + operand_effects.insert(into.identifier, Effect::Store); + } + } + AliasingEffect::CreateFunction { .. } | AliasingEffect::Create { .. } => { + // no-op + } + AliasingEffect::Mutate { value, .. } => { + operand_effects.insert(value.identifier, Effect::Store); + } + AliasingEffect::Apply { .. } => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "[AnalyzeFunctions] Expected Apply effects to be replaced with more \ + precise effects", + None, + )); + } + AliasingEffect::MutateTransitive { value, .. } + | AliasingEffect::MutateConditionally { value } + | AliasingEffect::MutateTransitiveConditionally { value } => { + operand_effects.insert(value.identifier, Effect::ConditionallyMutate); + } + AliasingEffect::Freeze { value, .. } => { + operand_effects.insert(value.identifier, Effect::Freeze); + } + AliasingEffect::ImmutableCapture { .. } => { + // no-op, Read is the default + } + AliasingEffect::Impure { .. } + | AliasingEffect::Render { .. } + | AliasingEffect::MutateFrozen { .. } + | AliasingEffect::MutateGlobal { .. 
} => { + // no-op + } + } + } + + // Apply operand effects to top-level lvalue + let instr = &mut func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + if let Some(&effect) = operand_effects.get(&lvalue_id) { + instr.lvalue.effect = effect; + } + // Apply operand effects to value-level lvalues + for_each_instruction_value_lvalue_mut(&mut instr.value, &mut |place| { + if let Some(&effect) = operand_effects.get(&place.identifier) { + place.effect = effect; + } + }); + + // Apply operand effects to value operands and fix up mutable ranges + { + let mut apply = |place: &mut Place| { + // Fix up mutable range start + let ident = &env.identifiers[place.identifier.0 as usize]; + if ident.mutable_range.end > eval_order + && ident.mutable_range.start == EvaluationOrder(0) + { + env.identifiers[place.identifier.0 as usize] + .mutable_range + .start = eval_order; + } + // Apply effect + if let Some(&effect) = operand_effects.get(&place.identifier) { + place.effect = effect; + } + }; + for_each_instruction_value_operand_mut(&mut instr.value, &mut apply); + + // FunctionExpression/ObjectMethod context variables are operands that + // require env access (they live in env.functions[func_id].context). + if let InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. 
} = &instr.value + { + let func_id = lowered_func.func; + let ctx_ids: Vec = env.functions[func_id.0 as usize] + .context + .iter() + .map(|c| c.identifier) + .collect(); + for ctx_id in &ctx_ids { + let ident = &env.identifiers[ctx_id.0 as usize]; + if ident.mutable_range.end > eval_order + && ident.mutable_range.start == EvaluationOrder(0) + { + env.identifiers[ctx_id.0 as usize].mutable_range.start = eval_order; + } + let effect = operand_effects.get(ctx_id).copied().unwrap_or(Effect::Read); + let inner_func = &mut env.functions[func_id.0 as usize]; + for ctx_place in &mut inner_func.context { + if ctx_place.identifier == *ctx_id { + ctx_place.effect = effect; + } + } + } + } + } + + // Handle StoreContext case: extend rvalue range if needed + let instr = &func.instructions[instr_id.0 as usize]; + if let InstructionValue::StoreContext { value, .. } = &instr.value { + let val_id = value.identifier; + let val_range_end = env.identifiers[val_id.0 as usize].mutable_range.end; + if val_range_end <= eval_order { + env.identifiers[val_id.0 as usize].mutable_range.end = + EvaluationOrder(eval_order.0 + 1); + } + } + } + + // Set terminal operand effects + let block = func.body.blocks.get_mut(&block_id).unwrap(); + match &mut block.terminal { + react_compiler_hir::Terminal::Return { value, .. 
} => { + value.effect = if is_function_expression { + Effect::Read + } else { + Effect::Freeze + }; + } + terminal => { + for_each_terminal_operand_mut(terminal, &mut |place| { + place.effect = Effect::Read; + }); + } + } + } + + // ========================================================================= + // Part 3: Finish populating the externally visible effects + // ========================================================================= + let returns_id = func.returns.identifier; + let returns_type_id = env.identifiers[returns_id.0 as usize].type_; + let returns_type = &env.types[returns_type_id.0 as usize]; + let return_value_kind = if is_primitive_type(returns_type) { + ValueKind::Primitive + } else if is_jsx_type(returns_type) { + ValueKind::Frozen + } else { + ValueKind::Mutable + }; + + function_effects.push(AliasingEffect::Create { + into: func.returns.clone(), + value: return_value_kind, + reason: ValueReason::KnownReturnSignature, + }); + + // Determine precise data-flow effects by simulating transitive mutations + let mut tracked: Vec = Vec::new(); + for param in &func.params { + let place = match param { + react_compiler_hir::ParamPattern::Place(p) => p.clone(), + react_compiler_hir::ParamPattern::Spread(s) => s.place.clone(), + }; + tracked.push(place); + } + for ctx in &func.context { + tracked.push(ctx.clone()); + } + tracked.push(func.returns.clone()); + + let returns_identifier_id = func.returns.identifier; + + for i in 0..tracked.len() { + let into = tracked[i].clone(); + let mutation_index = index; + index += 1; + + state.mutate( + mutation_index, + into.identifier, + None, // simulated mutation + true, + MutationKind::Conditional, + into.loc, + None, + env, + false, // never record errors for simulated mutations + ); + + for j in 0..tracked.len() { + let from = &tracked[j]; + if from.identifier == into.identifier || from.identifier == returns_identifier_id { + continue; + } + + let from_node = state.nodes.get(&from.identifier); + assert!( + 
from_node.is_some(), + "Expected a node to exist for all parameters and context variables" + ); + let from_node = from_node.unwrap(); + + if from_node.last_mutated == mutation_index { + if into.identifier == returns_identifier_id { + function_effects.push(AliasingEffect::Alias { + from: from.clone(), + into: into.clone(), + }); + } else { + function_effects.push(AliasingEffect::Capture { + from: from.clone(), + into: into.clone(), + }); + } + } + } + } + + Ok(function_effects) +} + +// ============================================================================= +// Helper: collect param/context mutation effects +// ============================================================================= + +fn collect_param_effects( + state: &AliasingState, + place: &Place, + function_effects: &mut Vec, +) { + let node = match state.nodes.get(&place.identifier) { + Some(n) => n, + None => return, + }; + + if let Some(ref local) = node.local { + match local.kind { + MutationKind::Conditional => { + function_effects.push(AliasingEffect::MutateConditionally { + value: Place { + loc: local.loc, + ..place.clone() + }, + }); + } + MutationKind::Definite => { + function_effects.push(AliasingEffect::Mutate { + value: Place { + loc: local.loc, + ..place.clone() + }, + reason: node.mutation_reason.clone(), + }); + } + MutationKind::None => {} + } + } + + if let Some(ref transitive) = node.transitive { + match transitive.kind { + MutationKind::Conditional => { + function_effects.push(AliasingEffect::MutateTransitiveConditionally { + value: Place { + loc: transitive.loc, + ..place.clone() + }, + }); + } + MutationKind::Definite => { + function_effects.push(AliasingEffect::MutateTransitive { + value: Place { + loc: transitive.loc, + ..place.clone() + }, + }); + } + MutationKind::None => {} + } + } +} diff --git a/crates/react_compiler_inference/src/infer_reactive_places.rs b/crates/react_compiler_inference/src/infer_reactive_places.rs new file mode 100644 index 000000000000..41153864d095 
--- /dev/null +++ b/crates/react_compiler_inference/src/infer_reactive_places.rs @@ -0,0 +1,783 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Infers which `Place`s are reactive. +//! +//! Ported from TypeScript `src/Inference/InferReactivePlaces.ts`. +//! +//! A place is reactive if it derives from any source of reactivity: +//! 1. Props (component parameters may change between renders) +//! 2. Hooks (can access state or context) +//! 3. `use` operator (can access context) +//! 4. Mutation with reactive operands +//! 5. Conditional assignment based on reactive control flow + +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory}; +use react_compiler_hir::{ + dominator::post_dominator_frontier, environment::Environment, object_shape::HookKind, visitors, + BlockId, Effect, FunctionId, HirFunction, IdentifierId, InstructionValue, ParamPattern, + Terminal, Type, +}; +use react_compiler_utils::DisjointSet; + +use crate::infer_reactive_scope_variables::find_disjoint_mutable_values; + +// ============================================================================= +// Public API +// ============================================================================= + +/// Infer which places in a function are reactive. +/// +/// Corresponds to TS `inferReactivePlaces(fn: HIRFunction): void`. 
+pub fn infer_reactive_places( + func: &mut HirFunction, + env: &mut Environment, +) -> Result<(), CompilerDiagnostic> { + let mut aliased_identifiers = find_disjoint_mutable_values(func, env); + let mut reactive_map = ReactivityMap::new(&mut aliased_identifiers); + let mut stable_sidemap = StableSidemap::new(); + + // Mark all function parameters as reactive + for param in &func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + reactive_map.mark_reactive(place.identifier); + } + + // Compute control dominators + let post_dominators = react_compiler_hir::dominator::compute_post_dominator_tree( + func, + env.next_block_id().0, + false, + )?; + + // Collect block IDs for iteration + let block_ids: Vec = func.body.blocks.keys().copied().collect(); + + // Track phi operand reactive flags during fixpoint. + // In TS, isReactive() sets place.reactive as a side effect. But when a phi + // is already reactive, the TS `continue`s and skips operand processing. + // We track which phi operand Places should be marked reactive. + // Key: (block_id, phi_idx, operand_idx), Value: should be reactive + let mut phi_operand_reactive: HashMap<(BlockId, usize, usize), bool> = HashMap::new(); + + // Fixpoint iteration — compute reactive set + loop { + for block_id in &block_ids { + let block = func.body.blocks.get(block_id).unwrap(); + let has_reactive_control = + is_reactive_controlled_block(block.id, func, &post_dominators, &mut reactive_map); + + // Process phi nodes + let block = func.body.blocks.get(block_id).unwrap(); + for (phi_idx, phi) in block.phis.iter().enumerate() { + if reactive_map.is_reactive(phi.place.identifier) { + // TS does `continue` here — skips operand isReactive calls. + // phi operand reactive flags stay as they were from last visit. 
+ continue; + } + let mut is_phi_reactive = false; + for (op_idx, (_pred, operand)) in phi.operands.iter().enumerate() { + let op_reactive = reactive_map.is_reactive(operand.identifier); + // Record the reactive state for this operand at this point + phi_operand_reactive.insert((*block_id, phi_idx, op_idx), op_reactive); + if op_reactive { + is_phi_reactive = true; + break; // TS breaks here — remaining operands NOT + // visited + } + } + if is_phi_reactive { + reactive_map.mark_reactive(phi.place.identifier); + } else { + for (pred, _operand) in &phi.operands { + if is_reactive_controlled_block( + *pred, + func, + &post_dominators, + &mut reactive_map, + ) { + reactive_map.mark_reactive(phi.place.identifier); + break; + } + } + } + } + + // Process instructions + let block = func.body.blocks.get(block_id).unwrap(); + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + + // Handle stable identifier sources + stable_sidemap.handle_instruction(instr, env); + + let value = &instr.value; + + // Check if any operand is reactive + let mut has_reactive_input = false; + let operands: Vec = + visitors::each_instruction_value_operand(value, env) + .into_iter() + .map(|p| p.identifier) + .collect(); + for &op_id in &operands { + let reactive = reactive_map.is_reactive(op_id); + has_reactive_input = has_reactive_input || reactive; + } + + // Hooks and `use` operator are sources of reactivity + match value { + InstructionValue::CallExpression { callee, .. } => { + let callee_ty = &env.types + [env.identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if get_hook_kind_for_type(env, callee_ty)?.is_some() + || is_use_operator_type(callee_ty) + { + has_reactive_input = true; + } + } + InstructionValue::MethodCall { property, .. 
} => { + let property_ty = &env.types + [env.identifiers[property.identifier.0 as usize].type_.0 as usize]; + if get_hook_kind_for_type(env, property_ty)?.is_some() + || is_use_operator_type(property_ty) + { + has_reactive_input = true; + } + } + _ => {} + } + + if has_reactive_input { + // Mark lvalues reactive (unless stable) + let lvalue_ids: Vec = visitors::each_instruction_lvalue(instr) + .into_iter() + .map(|p| p.identifier) + .collect(); + for lvalue_id in lvalue_ids { + if stable_sidemap.is_stable(lvalue_id) { + continue; + } + reactive_map.mark_reactive(lvalue_id); + } + } + + if has_reactive_input || has_reactive_control { + // Mark mutable operands reactive + let operand_places = visitors::each_instruction_value_operand(value, env); + for op_place in &operand_places { + match op_place.effect { + Effect::Capture + | Effect::Store + | Effect::ConditionallyMutate + | Effect::ConditionallyMutateIterator + | Effect::Mutate => { + let op_range = + &env.identifiers[op_place.identifier.0 as usize].mutable_range; + if op_range.contains(instr.id) { + reactive_map.mark_reactive(op_place.identifier); + } + } + Effect::Freeze | Effect::Read => { + // no-op + } + Effect::Unknown => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + &format!("Unexpected unknown effect at {:?}", op_place.loc), + None, + )); + } + } + } + } + } + + // Process terminal operands (just to mark them reactive for output) + for op in visitors::each_terminal_operand(&block.terminal) { + reactive_map.is_reactive(op.identifier); + } + } + + if !reactive_map.snapshot() { + break; + } + } + + // Propagate reactivity to inner functions (read-only phase, just queries + // reactive_map) + propagate_reactivity_to_inner_functions_outer(func, env, &mut reactive_map); + + // Now apply reactive flags by replaying the traversal pattern. 
+ apply_reactive_flags_replay( + func, + env, + &mut reactive_map, + &mut stable_sidemap, + &phi_operand_reactive, + ); + + Ok(()) +} + +// ============================================================================= +// ReactivityMap +// ============================================================================= + +struct ReactivityMap<'a> { + has_changes: bool, + reactive: HashSet, + aliased_identifiers: &'a mut DisjointSet, +} + +impl<'a> ReactivityMap<'a> { + fn new(aliased_identifiers: &'a mut DisjointSet) -> Self { + ReactivityMap { + has_changes: false, + reactive: HashSet::new(), + aliased_identifiers, + } + } + + fn is_reactive(&mut self, id: IdentifierId) -> bool { + let canonical = self.aliased_identifiers.find_opt(id).unwrap_or(id); + self.reactive.contains(&canonical) + } + + fn mark_reactive(&mut self, id: IdentifierId) { + let canonical = self.aliased_identifiers.find_opt(id).unwrap_or(id); + if self.reactive.insert(canonical) { + self.has_changes = true; + } + } + + /// Reset change tracking, returns true if there were changes. + fn snapshot(&mut self) -> bool { + let had_changes = self.has_changes; + self.has_changes = false; + had_changes + } +} + +// ============================================================================= +// StableSidemap +// ============================================================================= + +struct StableSidemap { + map: HashMap, +} + +impl StableSidemap { + fn new() -> Self { + StableSidemap { + map: HashMap::new(), + } + } + + fn handle_instruction(&mut self, instr: &react_compiler_hir::Instruction, env: &Environment) { + let lvalue_id = instr.lvalue.identifier; + let value = &instr.value; + + match value { + InstructionValue::CallExpression { callee, .. 
} => { + let callee_ty = + &env.types[env.identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if evaluates_to_stable_type_or_container(env, callee_ty) { + let lvalue_ty = + &env.types[env.identifiers[lvalue_id.0 as usize].type_.0 as usize]; + if is_stable_type(lvalue_ty) { + self.map.insert(lvalue_id, true); + } else { + self.map.insert(lvalue_id, false); + } + } + } + InstructionValue::MethodCall { property, .. } => { + let property_ty = + &env.types[env.identifiers[property.identifier.0 as usize].type_.0 as usize]; + if evaluates_to_stable_type_or_container(env, property_ty) { + let lvalue_ty = + &env.types[env.identifiers[lvalue_id.0 as usize].type_.0 as usize]; + if is_stable_type(lvalue_ty) { + self.map.insert(lvalue_id, true); + } else { + self.map.insert(lvalue_id, false); + } + } + } + InstructionValue::PropertyLoad { object, .. } => { + let source_id = object.identifier; + if self.map.contains_key(&source_id) { + let lvalue_ty = + &env.types[env.identifiers[lvalue_id.0 as usize].type_.0 as usize]; + if is_stable_type_container(lvalue_ty) { + self.map.insert(lvalue_id, false); + } else if is_stable_type(lvalue_ty) { + self.map.insert(lvalue_id, true); + } + } + } + InstructionValue::Destructure { value: val, .. } => { + let source_id = val.identifier; + if self.map.contains_key(&source_id) { + let lvalue_ids: Vec = visitors::each_instruction_lvalue(instr) + .into_iter() + .map(|p| p.identifier) + .collect(); + for lid in lvalue_ids { + let lid_ty = &env.types[env.identifiers[lid.0 as usize].type_.0 as usize]; + if is_stable_type_container(lid_ty) { + self.map.insert(lid, false); + } else if is_stable_type(lid_ty) { + self.map.insert(lid, true); + } + } + } + } + InstructionValue::StoreLocal { + lvalue, value: val, .. + } => { + if let Some(&entry) = self.map.get(&val.identifier) { + self.map.insert(lvalue_id, entry); + self.map.insert(lvalue.place.identifier, entry); + } + } + InstructionValue::LoadLocal { place, .. 
} => { + if let Some(&entry) = self.map.get(&place.identifier) { + self.map.insert(lvalue_id, entry); + } + } + _ => {} + } + } + + fn is_stable(&self, id: IdentifierId) -> bool { + self.map.get(&id).copied().unwrap_or(false) + } +} + +// ============================================================================= +// Control dominators (ported from ControlDominators.ts) +// ============================================================================= + +fn is_reactive_controlled_block( + block_id: BlockId, + func: &HirFunction, + post_dominators: &react_compiler_hir::dominator::PostDominator, + reactive_map: &mut ReactivityMap, +) -> bool { + let frontier = post_dominator_frontier(func, post_dominators, block_id); + for frontier_block_id in &frontier { + let control_block = func.body.blocks.get(frontier_block_id).unwrap(); + match &control_block.terminal { + Terminal::If { test, .. } | Terminal::Branch { test, .. } => { + if reactive_map.is_reactive(test.identifier) { + return true; + } + } + Terminal::Switch { test, cases, .. } => { + if reactive_map.is_reactive(test.identifier) { + return true; + } + for case in cases { + if let Some(ref case_test) = case.test { + if reactive_map.is_reactive(case_test.identifier) { + return true; + } + } + } + } + _ => {} + } + } + false +} + +// ============================================================================= +// Type helpers (ported from HIR.ts) +// ============================================================================= + +use react_compiler_hir::is_use_operator_type; + +fn get_hook_kind_for_type<'a>( + env: &'a Environment, + ty: &Type, +) -> Result, CompilerDiagnostic> { + env.get_hook_kind_for_type(ty) +} + +fn is_stable_type(ty: &Type) -> bool { + match ty { + Type::Function { + shape_id: Some(id), .. 
+ } => { + matches!( + id.as_str(), + "BuiltInSetState" + | "BuiltInSetActionState" + | "BuiltInDispatch" + | "BuiltInStartTransition" + | "BuiltInSetOptimistic" + ) + } + Type::Object { shape_id: Some(id) } => { + matches!(id.as_str(), "BuiltInUseRefId") + } + _ => false, + } +} + +fn is_stable_type_container(ty: &Type) -> bool { + match ty { + Type::Object { shape_id: Some(id) } => { + matches!( + id.as_str(), + "BuiltInUseState" + | "BuiltInUseActionState" + | "BuiltInUseReducer" + | "BuiltInUseOptimistic" + | "BuiltInUseTransition" + ) + } + _ => false, + } +} + +fn evaluates_to_stable_type_or_container(env: &Environment, callee_ty: &Type) -> bool { + if let Some(hook_kind) = get_hook_kind_for_type(env, callee_ty).ok().flatten() { + matches!( + hook_kind, + HookKind::UseState + | HookKind::UseReducer + | HookKind::UseActionState + | HookKind::UseRef + | HookKind::UseTransition + | HookKind::UseOptimistic + ) + } else { + false + } +} + +// ============================================================================= +// Propagate reactivity to inner functions +// ============================================================================= + +fn propagate_reactivity_to_inner_functions_outer( + func: &HirFunction, + env: &Environment, + reactive_map: &mut ReactivityMap, +) { + for (_block_id, block) in &func.body.blocks { + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. 
} => { + propagate_reactivity_to_inner_functions_inner( + lowered_func.func, + env, + reactive_map, + ); + } + _ => {} + } + } + } +} + +fn propagate_reactivity_to_inner_functions_inner( + func_id: FunctionId, + env: &Environment, + reactive_map: &mut ReactivityMap, +) { + let inner_func = &env.functions[func_id.0 as usize]; + + for (_block_id, block) in &inner_func.body.blocks { + for instr_id in &block.instructions { + let instr = &inner_func.instructions[instr_id.0 as usize]; + + for op in visitors::each_instruction_value_operand(&instr.value, env) { + reactive_map.is_reactive(op.identifier); + } + + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + propagate_reactivity_to_inner_functions_inner( + lowered_func.func, + env, + reactive_map, + ); + } + _ => {} + } + } + + for op in visitors::each_terminal_operand(&block.terminal) { + reactive_map.is_reactive(op.identifier); + } + } +} + +// ============================================================================= +// Apply reactive flags to the HIR (replay pass) +// ============================================================================= + +fn apply_reactive_flags_replay( + func: &mut HirFunction, + env: &mut Environment, + reactive_map: &mut ReactivityMap, + stable_sidemap: &mut StableSidemap, + phi_operand_reactive: &HashMap<(BlockId, usize, usize), bool>, +) { + let reactive_ids = build_reactive_id_set(reactive_map); + + // 1. Mark params + for param in &mut func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &mut s.place, + }; + place.reactive = true; + } + + // 2. Walk blocks + let block_ids: Vec = func.body.blocks.keys().copied().collect(); + + for block_id in &block_ids { + let block = func.body.blocks.get(block_id).unwrap(); + + // 2a. 
Phi nodes + let phi_count = block.phis.len(); + for phi_idx in 0..phi_count { + let block = func.body.blocks.get_mut(block_id).unwrap(); + let phi = &mut block.phis[phi_idx]; + + if reactive_ids.contains(&phi.place.identifier) { + phi.place.reactive = true; + } + + for (op_idx, (_pred, operand)) in phi.operands.iter_mut().enumerate() { + if let Some(&is_reactive) = phi_operand_reactive.get(&(*block_id, phi_idx, op_idx)) + { + if is_reactive { + operand.reactive = true; + } + } + } + } + + // 2b. Instructions + let block = func.body.blocks.get(block_id).unwrap(); + let instr_ids: Vec = block.instructions.clone(); + + for instr_id in &instr_ids { + let instr = &func.instructions[instr_id.0 as usize]; + + // Compute hasReactiveInput by checking value operands + let value_operand_ids: Vec = + visitors::each_instruction_value_operand(&instr.value, env) + .into_iter() + .map(|p| p.identifier) + .collect(); + let mut has_reactive_input = false; + for &op_id in &value_operand_ids { + if reactive_ids.contains(&op_id) { + has_reactive_input = true; + } + } + + // Check hooks/use + match &instr.value { + InstructionValue::CallExpression { callee, .. } => { + let callee_ty = + &env.types[env.identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if get_hook_kind_for_type(env, callee_ty) + .ok() + .flatten() + .is_some() + || is_use_operator_type(callee_ty) + { + has_reactive_input = true; + } + } + InstructionValue::MethodCall { property, .. 
} => { + let property_ty = &env.types + [env.identifiers[property.identifier.0 as usize].type_.0 as usize]; + if get_hook_kind_for_type(env, property_ty) + .ok() + .flatten() + .is_some() + || is_use_operator_type(property_ty) + { + has_reactive_input = true; + } + } + _ => {} + } + + // Value operands: set reactive flag using canonical visitor + let instr = &mut func.instructions[instr_id.0 as usize]; + visitors::for_each_instruction_value_operand_mut(&mut instr.value, &mut |place| { + if reactive_ids.contains(&place.identifier) { + place.reactive = true; + } + }); + // FunctionExpression/ObjectMethod context variables require env access + if let InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } = &mut instr.value + { + let inner_func = &mut env.functions[lowered_func.func.0 as usize]; + for ctx in &mut inner_func.context { + if reactive_ids.contains(&ctx.identifier) { + ctx.reactive = true; + } + } + } + + // Lvalues: markReactive is called only when hasReactiveInput + if has_reactive_input { + let lvalue_id = instr.lvalue.identifier; + if !stable_sidemap.is_stable(lvalue_id) && reactive_ids.contains(&lvalue_id) { + instr.lvalue.reactive = true; + } + // Handle value lvalues — includes DeclareContext/StoreContext which + // for_each_instruction_lvalue_mut skips, so we use a direct match. + match &mut instr.value { + InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::DeclareContext { lvalue, .. } + | InstructionValue::StoreLocal { lvalue, .. } + | InstructionValue::StoreContext { lvalue, .. } => { + let id = lvalue.place.identifier; + if !stable_sidemap.is_stable(id) && reactive_ids.contains(&id) { + lvalue.place.reactive = true; + } + } + InstructionValue::Destructure { lvalue, .. 
} => { + visitors::for_each_pattern_operand_mut(&mut lvalue.pattern, &mut |place| { + if !stable_sidemap.is_stable(place.identifier) + && reactive_ids.contains(&place.identifier) + { + place.reactive = true; + } + }); + } + InstructionValue::PrefixUpdate { lvalue, .. } + | InstructionValue::PostfixUpdate { lvalue, .. } => { + let id = lvalue.identifier; + if !stable_sidemap.is_stable(id) && reactive_ids.contains(&id) { + lvalue.reactive = true; + } + } + _ => {} + } + } + } + + // 2c. Terminal operands + let block = func.body.blocks.get_mut(block_id).unwrap(); + visitors::for_each_terminal_operand_mut(&mut block.terminal, &mut |place| { + if reactive_ids.contains(&place.identifier) { + place.reactive = true; + } + }); + } + + // 3. Apply to inner functions + apply_reactive_flags_to_inner_functions(func, env, &reactive_ids); +} + +fn build_reactive_id_set(reactive_map: &mut ReactivityMap) -> HashSet { + let mut result = HashSet::new(); + for &id in &reactive_map.reactive { + result.insert(id); + } + let reactive = &reactive_map.reactive; + reactive_map.aliased_identifiers.for_each(|id, canonical| { + if reactive.contains(&canonical) { + result.insert(id); + } + }); + result +} + +fn apply_reactive_flags_to_inner_functions( + func: &HirFunction, + env: &mut Environment, + reactive_ids: &HashSet, +) { + for (_block_id, block) in &func.body.blocks { + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. 
} => { + apply_reactive_flags_to_inner_func(lowered_func.func, env, reactive_ids); + } + _ => {} + } + } + } +} + +fn apply_reactive_flags_to_inner_func( + func_id: FunctionId, + env: &mut Environment, + reactive_ids: &HashSet, +) { + // Collect nested function IDs first to avoid borrow issues + let nested_func_ids: Vec = { + let func = &env.functions[func_id.0 as usize]; + let mut ids = Vec::new(); + for (_block_id, block) in &func.body.blocks { + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + ids.push(lowered_func.func); + } + _ => {} + } + } + } + ids + }; + + // Apply reactive flags using canonical visitors + let inner_func = &mut env.functions[func_id.0 as usize]; + for (_block_id, block) in &mut inner_func.body.blocks { + for instr_id in &block.instructions { + let instr = &mut inner_func.instructions[instr_id.0 as usize]; + visitors::for_each_instruction_value_operand_mut(&mut instr.value, &mut |place| { + if reactive_ids.contains(&place.identifier) { + place.reactive = true; + } + }); + } + visitors::for_each_terminal_operand_mut(&mut block.terminal, &mut |place| { + if reactive_ids.contains(&place.identifier) { + place.reactive = true; + } + }); + } + + // Recurse into nested functions, and set reactive on their context variables + for nested_id in nested_func_ids { + let nested_func = &mut env.functions[nested_id.0 as usize]; + for ctx in &mut nested_func.context { + if reactive_ids.contains(&ctx.identifier) { + ctx.reactive = true; + } + } + apply_reactive_flags_to_inner_func(nested_id, env, reactive_ids); + } +} diff --git a/crates/react_compiler_inference/src/infer_reactive_scope_variables.rs b/crates/react_compiler_inference/src/infer_reactive_scope_variables.rs new file mode 100644 index 000000000000..ea47fa8c32a3 --- /dev/null +++ 
b/crates/react_compiler_inference/src/infer_reactive_scope_variables.rs @@ -0,0 +1,389 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Infers which variables belong to reactive scopes. +//! +//! Ported from TypeScript `src/ReactiveScopes/InferReactiveScopeVariables.ts`. +//! +//! This is the 1st of 4 passes that determine how to break a function into +//! discrete reactive scopes (independently memoizable units of code): +//! 1. InferReactiveScopeVariables (this pass, on HIR) determines operands that +//! mutate together and assigns them a unique reactive scope. +//! 2. AlignReactiveScopesToBlockScopes aligns reactive scopes to block scopes. +//! 3. MergeOverlappingReactiveScopes ensures scopes do not overlap. +//! 4. BuildReactiveBlocks groups the statements for each scope. + +use std::collections::HashMap; + +use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, visitors, DeclarationId, EvaluationOrder, HirFunction, IdentifierId, + InstructionValue, Pattern, Position, SourceLocation, +}; +use react_compiler_utils::DisjointSet; + +// ============================================================================= +// Public API +// ============================================================================= + +/// Infer reactive scope variables for a function. +/// +/// For each mutable variable, infers a reactive scope which will construct that +/// variable. Variables that co-mutate are assigned to the same reactive scope. +/// +/// Corresponds to TS `inferReactiveScopeVariables(fn: HIRFunction): void`. 
+pub fn infer_reactive_scope_variables( + func: &mut HirFunction, + env: &mut Environment, +) -> Result<(), CompilerDiagnostic> { + // Phase 1: find disjoint sets of co-mutating identifiers + let mut scope_identifiers = find_disjoint_mutable_values(func, env); + + // Phase 2: assign scopes + // Maps each group root identifier to the ScopeId assigned to that group. + let mut scopes: HashMap = HashMap::new(); + + scope_identifiers.for_each(|identifier_id, group_id| { + let ident_range = env.identifiers[identifier_id.0 as usize] + .mutable_range + .clone(); + let ident_loc = env.identifiers[identifier_id.0 as usize].loc; + + let state = scopes.entry(group_id).or_insert_with(|| { + let scope_id = env.next_scope_id(); + // Initialize scope range from the first member + let scope = &mut env.scopes[scope_id.0 as usize]; + scope.range = ident_range.clone(); + ScopeState { + scope_id, + loc: ident_loc, + } + }); + + // Update scope range + let scope = &mut env.scopes[state.scope_id.0 as usize]; + + // If this is not the first identifier (scope was already created), merge ranges + if scope.range.start != ident_range.start || scope.range.end != ident_range.end { + if scope.range.start == EvaluationOrder(0) { + scope.range.start = ident_range.start; + } else if ident_range.start != EvaluationOrder(0) { + scope.range.start = EvaluationOrder(scope.range.start.0.min(ident_range.start.0)); + } + scope.range.end = EvaluationOrder(scope.range.end.0.max(ident_range.end.0)); + } + + // Merge location + state.loc = merge_location(state.loc, ident_loc); + + // Assign the scope to this identifier + let scope_id = state.scope_id; + env.identifiers[identifier_id.0 as usize].scope = Some(scope_id); + }); + + // Set loc on each scope + for (_group_id, state) in &scopes { + env.scopes[state.scope_id.0 as usize].loc = state.loc; + } + + // Update each identifier's mutable_range to match its scope's range + for (&_identifier_id, state) in &scopes { + let scope_range = 
env.scopes[state.scope_id.0 as usize].range.clone(); + // Find all identifiers with this scope and update their mutable_range + // We iterate through all identifiers and check their scope + for ident in &mut env.identifiers { + if ident.scope == Some(state.scope_id) { + ident.mutable_range = scope_range.clone(); + } + } + } + + // Validate scope ranges + let mut max_instruction = EvaluationOrder(0); + for (_block_id, block) in &func.body.blocks { + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + max_instruction = EvaluationOrder(max_instruction.0.max(instr.id.0)); + } + max_instruction = + EvaluationOrder(max_instruction.0.max(block.terminal.evaluation_order().0)); + } + + for (_group_id, state) in &scopes { + let scope = &env.scopes[state.scope_id.0 as usize]; + if scope.range.start == EvaluationOrder(0) + || scope.range.end == EvaluationOrder(0) + || max_instruction == EvaluationOrder(0) + || scope.range.end.0 > max_instruction.0 + 1 + { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + &format!( + "Invalid mutable range for scope: Scope @{} has range [{}:{}] but the valid \ + range is [1:{}]", + scope.id.0, + scope.range.start.0, + scope.range.end.0, + max_instruction.0 + 1, + ), + None, + )); + } + } + + Ok(()) +} + +struct ScopeState { + scope_id: react_compiler_hir::ScopeId, + loc: Option, +} + +/// Merge two source locations, preferring non-None values. +/// Corresponds to TS `mergeLocation`. 
+fn merge_location(l: Option, r: Option) -> Option { + match (l, r) { + (None, r) => r, + (l, None) => l, + (Some(l), Some(r)) => Some(SourceLocation { + start: Position { + line: l.start.line.min(r.start.line), + column: l.start.column.min(r.start.column), + index: match (l.start.index, r.start.index) { + (Some(a), Some(b)) => Some(a.min(b)), + (a, b) => a.or(b), + }, + }, + end: Position { + line: l.end.line.max(r.end.line), + column: l.end.column.max(r.end.column), + index: match (l.end.index, r.end.index) { + (Some(a), Some(b)) => Some(a.max(b)), + (a, b) => a.or(b), + }, + }, + }), + } +} + +// ============================================================================= +// is_mutable / in_range helpers +// ============================================================================= + +// ============================================================================= +// may_allocate +// ============================================================================= + +/// Check if an instruction may allocate. Corresponds to TS `mayAllocate`. +fn may_allocate(value: &InstructionValue, lvalue_type_is_primitive: bool) -> bool { + match value { + InstructionValue::Destructure { lvalue, .. } => { + visitors::does_pattern_contain_spread_element(&lvalue.pattern) + } + InstructionValue::PostfixUpdate { .. } + | InstructionValue::PrefixUpdate { .. } + | InstructionValue::Await { .. } + | InstructionValue::DeclareLocal { .. } + | InstructionValue::DeclareContext { .. } + | InstructionValue::StoreLocal { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::TypeCastExpression { .. } + | InstructionValue::LoadLocal { .. } + | InstructionValue::LoadContext { .. } + | InstructionValue::StoreContext { .. } + | InstructionValue::PropertyDelete { .. } + | InstructionValue::ComputedLoad { .. } + | InstructionValue::ComputedDelete { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::TemplateLiteral { .. 
} + | InstructionValue::Primitive { .. } + | InstructionValue::GetIterator { .. } + | InstructionValue::IteratorNext { .. } + | InstructionValue::NextPropertyOf { .. } + | InstructionValue::Debugger { .. } + | InstructionValue::StartMemoize { .. } + | InstructionValue::FinishMemoize { .. } + | InstructionValue::UnaryExpression { .. } + | InstructionValue::BinaryExpression { .. } + | InstructionValue::PropertyLoad { .. } + | InstructionValue::StoreGlobal { .. } => false, + + InstructionValue::TaggedTemplateExpression { .. } + | InstructionValue::CallExpression { .. } + | InstructionValue::MethodCall { .. } => !lvalue_type_is_primitive, + + InstructionValue::RegExpLiteral { .. } + | InstructionValue::PropertyStore { .. } + | InstructionValue::ComputedStore { .. } + | InstructionValue::ArrayExpression { .. } + | InstructionValue::JsxExpression { .. } + | InstructionValue::JsxFragment { .. } + | InstructionValue::NewExpression { .. } + | InstructionValue::ObjectExpression { .. } + | InstructionValue::UnsupportedNode { .. } + | InstructionValue::ObjectMethod { .. } + | InstructionValue::FunctionExpression { .. } => true, + } +} + +// ============================================================================= +// Pattern helpers +// ============================================================================= + +/// Collect all Place identifiers from a destructure pattern. +/// Corresponds to TS `eachPatternOperand`. +fn each_pattern_operand(pattern: &Pattern) -> Vec { + visitors::each_pattern_operand(pattern) + .into_iter() + .map(|p| p.identifier) + .collect() +} + +/// Collect all operand identifiers from an instruction value. +/// Corresponds to TS `eachInstructionValueOperand`. 
+fn each_instruction_value_operand( + value: &InstructionValue, + env: &Environment, +) -> Vec { + visitors::each_instruction_value_operand(value, env) + .into_iter() + .map(|p| p.identifier) + .collect() +} + +// ============================================================================= +// findDisjointMutableValues +// ============================================================================= + +/// Find disjoint sets of co-mutating identifier IDs. +/// +/// Corresponds to TS `findDisjointMutableValues(fn: HIRFunction): +/// DisjointSet`. +pub(crate) fn find_disjoint_mutable_values( + func: &HirFunction, + env: &Environment, +) -> DisjointSet { + let mut scope_identifiers = DisjointSet::::new(); + let mut declarations: HashMap = HashMap::new(); + + let enable_forest = env.config.enable_forest; + + for (_block_id, block) in &func.body.blocks { + // Handle phi nodes + for phi in &block.phis { + let phi_id = phi.place.identifier; + let phi_range = &env.identifiers[phi_id.0 as usize].mutable_range; + let phi_decl_id = env.identifiers[phi_id.0 as usize].declaration_id; + + let first_instr_id = block + .instructions + .first() + .map(|iid| func.instructions[iid.0 as usize].id) + .unwrap_or(block.terminal.evaluation_order()); + + if phi_range.start.0 + 1 != phi_range.end.0 && phi_range.end > first_instr_id { + let mut operands = vec![phi_id]; + if let Some(&decl_id) = declarations.get(&phi_decl_id) { + operands.push(decl_id); + } + for (_pred_id, phi_operand) in &phi.operands { + operands.push(phi_operand.identifier); + } + scope_identifiers.union(&operands); + } else if enable_forest { + for (_pred_id, phi_operand) in &phi.operands { + scope_identifiers.union(&[phi_id, phi_operand.identifier]); + } + } + } + + // Handle instructions + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let mut operands: Vec = Vec::new(); + + let lvalue_id = instr.lvalue.identifier; + let lvalue_range = &env.identifiers[lvalue_id.0 as 
usize].mutable_range; + let lvalue_type = &env.types[env.identifiers[lvalue_id.0 as usize].type_.0 as usize]; + let lvalue_type_is_primitive = react_compiler_hir::is_primitive_type(lvalue_type); + + if lvalue_range.end.0 > lvalue_range.start.0 + 1 + || may_allocate(&instr.value, lvalue_type_is_primitive) + { + operands.push(lvalue_id); + } + + match &instr.value { + InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::DeclareContext { lvalue, .. } => { + let place_id = lvalue.place.identifier; + let decl_id = env.identifiers[place_id.0 as usize].declaration_id; + declarations.entry(decl_id).or_insert(place_id); + } + InstructionValue::StoreLocal { lvalue, value, .. } + | InstructionValue::StoreContext { lvalue, value, .. } => { + let place_id = lvalue.place.identifier; + let decl_id = env.identifiers[place_id.0 as usize].declaration_id; + declarations.entry(decl_id).or_insert(place_id); + + let place_range = &env.identifiers[place_id.0 as usize].mutable_range; + if place_range.end.0 > place_range.start.0 + 1 { + operands.push(place_id); + } + + let value_range = &env.identifiers[value.identifier.0 as usize].mutable_range; + if value_range.contains(instr.id) && value_range.start.0 > 0 { + operands.push(value.identifier); + } + } + InstructionValue::Destructure { lvalue, value, .. } => { + let pattern_places = each_pattern_operand(&lvalue.pattern); + for place_id in &pattern_places { + let decl_id = env.identifiers[place_id.0 as usize].declaration_id; + declarations.entry(decl_id).or_insert(*place_id); + + let place_range = &env.identifiers[place_id.0 as usize].mutable_range; + if place_range.end.0 > place_range.start.0 + 1 { + operands.push(*place_id); + } + } + + let value_range = &env.identifiers[value.identifier.0 as usize].mutable_range; + if value_range.contains(instr.id) && value_range.start.0 > 0 { + operands.push(value.identifier); + } + } + InstructionValue::MethodCall { property, .. 
} => { + // For MethodCall: include all mutable operands plus the computed property + let all_operands = each_instruction_value_operand(&instr.value, env); + for op_id in &all_operands { + let op_range = &env.identifiers[op_id.0 as usize].mutable_range; + if op_range.contains(instr.id) && op_range.start.0 > 0 { + operands.push(*op_id); + } + } + // Ensure method property is in the same scope as the call + operands.push(property.identifier); + } + _ => { + // For all other instructions: include mutable operands + let all_operands = each_instruction_value_operand(&instr.value, env); + for op_id in &all_operands { + let op_range = &env.identifiers[op_id.0 as usize].mutable_range; + if op_range.contains(instr.id) && op_range.start.0 > 0 { + operands.push(*op_id); + } + } + } + } + + if !operands.is_empty() { + scope_identifiers.union(&operands); + } + } + } + scope_identifiers +} diff --git a/crates/react_compiler_inference/src/lib.rs b/crates/react_compiler_inference/src/lib.rs new file mode 100644 index 000000000000..ddf8596e8ba2 --- /dev/null +++ b/crates/react_compiler_inference/src/lib.rs @@ -0,0 +1,33 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. 
+#![allow(clippy::all)] + +pub mod align_method_call_scopes; +pub mod align_object_method_scopes; +pub mod align_reactive_scopes_to_block_scopes_hir; +pub mod analyse_functions; +pub mod build_reactive_scope_terminals_hir; +pub mod flatten_reactive_loops_hir; +pub mod flatten_scopes_with_hooks_or_use_hir; +pub mod infer_mutation_aliasing_effects; +pub mod infer_mutation_aliasing_ranges; +pub mod infer_reactive_places; +pub mod infer_reactive_scope_variables; +pub mod memoize_fbt_and_macro_operands_in_same_scope; +pub mod merge_overlapping_reactive_scopes_hir; +pub mod propagate_scope_dependencies_hir; + +pub use align_method_call_scopes::align_method_call_scopes; +pub use align_object_method_scopes::align_object_method_scopes; +pub use align_reactive_scopes_to_block_scopes_hir::align_reactive_scopes_to_block_scopes_hir; +pub use analyse_functions::analyse_functions; +pub use build_reactive_scope_terminals_hir::build_reactive_scope_terminals_hir; +pub use flatten_reactive_loops_hir::flatten_reactive_loops_hir; +pub use flatten_scopes_with_hooks_or_use_hir::flatten_scopes_with_hooks_or_use_hir; +pub use infer_mutation_aliasing_effects::infer_mutation_aliasing_effects; +pub use infer_mutation_aliasing_ranges::infer_mutation_aliasing_ranges; +pub use infer_reactive_places::infer_reactive_places; +pub use infer_reactive_scope_variables::infer_reactive_scope_variables; +pub use memoize_fbt_and_macro_operands_in_same_scope::memoize_fbt_and_macro_operands_in_same_scope; +pub use merge_overlapping_reactive_scopes_hir::merge_overlapping_reactive_scopes_hir; +pub use propagate_scope_dependencies_hir::propagate_scope_dependencies_hir; diff --git a/crates/react_compiler_inference/src/memoize_fbt_and_macro_operands_in_same_scope.rs b/crates/react_compiler_inference/src/memoize_fbt_and_macro_operands_in_same_scope.rs new file mode 100644 index 000000000000..5233513eadfc --- /dev/null +++ b/crates/react_compiler_inference/src/memoize_fbt_and_macro_operands_in_same_scope.rs @@ -0,0 
+1,364 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Port of MemoizeFbtAndMacroOperandsInSameScope from TypeScript. +//! +//! Ensures that FBT (Facebook Translation) expressions and their operands +//! are memoized within the same reactive scope. Also supports user-configured +//! custom macro-like APIs via `customMacros` configuration. +//! +//! The pass has two phases: +//! 1. Forward data-flow: identify all macro tags (including property loads like +//! `fbt.param`) +//! 2. Reverse data-flow: merge arguments of macro invocations into the same +//! scope + +use std::collections::{HashMap, HashSet}; + +use react_compiler_hir::{ + environment::Environment, visitors, HirFunction, IdentifierId, InstructionValue, JsxTag, + PrimitiveValue, PropertyLiteral, ScopeId, +}; + +/// Whether a macro requires its arguments to be transitively inlined (e.g., +/// fbt) or just avoids having the top-level values be converted to variables +/// (e.g., fbt.param). +#[derive(Debug, Clone)] +enum InlineLevel { + Transitive, + Shallow, +} + +/// Defines how a macro and its properties should be handled. +#[derive(Debug, Clone)] +struct MacroDefinition { + level: InlineLevel, + /// Maps property names to their own MacroDefinition. `"*"` is a wildcard. + properties: Option>, +} + +fn shallow_macro() -> MacroDefinition { + MacroDefinition { + level: InlineLevel::Shallow, + properties: None, + } +} + +fn transitive_macro() -> MacroDefinition { + MacroDefinition { + level: InlineLevel::Transitive, + properties: None, + } +} + +fn fbt_macro() -> MacroDefinition { + let mut props = HashMap::new(); + props.insert("*".to_string(), shallow_macro()); + // fbt.enum gets FBT_MACRO (recursive/transitive) + // We'll fill this in after construction since it's self-referential. + // Instead, we use a special marker and handle it in property lookup. 
+ let mut fbt = MacroDefinition { + level: InlineLevel::Transitive, + properties: Some(props), + }; + // Add "enum" as a recursive reference (same as FBT_MACRO) + // Since we can't do self-referential structs, we clone the structure. + let enum_macro = MacroDefinition { + level: InlineLevel::Transitive, + properties: Some({ + let mut p = HashMap::new(); + p.insert("*".to_string(), shallow_macro()); + // enum's enum is also recursive, but in practice the depth is bounded + p.insert("enum".to_string(), transitive_macro()); + p + }), + }; + fbt.properties + .as_mut() + .unwrap() + .insert("enum".to_string(), enum_macro); + fbt +} + +/// Built-in FBT tags and their macro definitions. +fn fbt_tags() -> HashMap { + let mut tags = HashMap::new(); + tags.insert("fbt".to_string(), fbt_macro()); + tags.insert("fbt:param".to_string(), shallow_macro()); + tags.insert("fbt:enum".to_string(), fbt_macro()); + tags.insert("fbt:plural".to_string(), shallow_macro()); + tags.insert("fbs".to_string(), fbt_macro()); + tags.insert("fbs:param".to_string(), shallow_macro()); + tags.insert("fbs:enum".to_string(), fbt_macro()); + tags.insert("fbs:plural".to_string(), shallow_macro()); + tags +} + +/// Main entry point. Returns the set of identifier IDs that are fbt/macro +/// operands. 
+pub fn memoize_fbt_and_macro_operands_in_same_scope( + func: &HirFunction, + env: &mut Environment, +) -> HashSet<IdentifierId> { + // Phase 1: Build macro kinds map from built-in FBT tags + custom macros + let mut macro_kinds: HashMap<String, MacroDefinition> = fbt_tags(); + if let Some(ref custom_macros) = env.config.custom_macros { + for name in custom_macros { + macro_kinds.insert(name.clone(), transitive_macro()); + } + } + + // Phase 2: Forward data-flow to identify all macro tags + let mut macro_tags = populate_macro_tags(func, &macro_kinds); + + // Phase 3: Reverse data-flow to merge arguments of macro invocations + let macro_values = merge_macro_arguments(func, env, &mut macro_tags, &macro_kinds); + + macro_values +} + +/// Forward data-flow analysis to identify all macro tags, including +/// things like `fbt.foo.bar(...)`. +fn populate_macro_tags( + func: &HirFunction, + macro_kinds: &HashMap<String, MacroDefinition>, +) -> HashMap<IdentifierId, MacroDefinition> { + let mut macro_tags: HashMap<IdentifierId, MacroDefinition> = HashMap::new(); + + for block in func.body.blocks.values() { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + + match &instr.value { + InstructionValue::Primitive { + value: PrimitiveValue::String(s), + .. + } => { + if let Some(macro_def) = macro_kinds.get(s.as_str()) { + // We don't distinguish between tag names and strings, so record + // all `fbt` string literals in case they are used as a jsx tag. + macro_tags.insert(lvalue_id, macro_def.clone()); + } + } + InstructionValue::LoadGlobal { binding, .. } => { + let name = binding.name(); + if let Some(macro_def) = macro_kinds.get(name) { + macro_tags.insert(lvalue_id, macro_def.clone()); + } + } + InstructionValue::PropertyLoad { + object, property, .. 
+ } => { + if let PropertyLiteral::String(prop_name) = property { + if let Some(macro_def) = macro_tags.get(&object.identifier).cloned() { + let property_macro = if let Some(ref props) = macro_def.properties { + let prop_def = + props.get(prop_name.as_str()).or_else(|| props.get("*")); + match prop_def { + Some(def) => def.clone(), + None => macro_def.clone(), + } + } else { + macro_def.clone() + }; + macro_tags.insert(lvalue_id, property_macro); + } + } + } + _ => {} + } + } + } + + macro_tags +} + +/// Reverse data-flow analysis to merge arguments to macro *invocations* +/// based on the kind of the macro. +fn merge_macro_arguments( + func: &HirFunction, + env: &mut Environment, + macro_tags: &mut HashMap, + macro_kinds: &HashMap, +) -> HashSet { + let mut macro_values: HashSet = macro_tags.keys().copied().collect(); + + // Iterate blocks in reverse order + let block_ids: Vec<_> = func.body.blocks.keys().copied().collect(); + for &block_id in block_ids.iter().rev() { + let block = &func.body.blocks[&block_id]; + + // Iterate instructions in reverse order + for &instr_id in block.instructions.iter().rev() { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + + match &instr.value { + // Instructions that never need to be merged + InstructionValue::DeclareContext { .. } + | InstructionValue::DeclareLocal { .. } + | InstructionValue::Destructure { .. } + | InstructionValue::LoadContext { .. } + | InstructionValue::LoadLocal { .. } + | InstructionValue::PostfixUpdate { .. } + | InstructionValue::PrefixUpdate { .. } + | InstructionValue::StoreContext { .. } + | InstructionValue::StoreLocal { .. } => { + // Skip these + } + + InstructionValue::CallExpression { callee, .. } + | InstructionValue::MethodCall { + property: callee, .. 
+ } => { + let scope_id = match env.identifiers[lvalue_id.0 as usize].scope { + Some(s) => s, + None => continue, + }; + + let macro_def = macro_tags + .get(&callee.identifier) + .or_else(|| macro_tags.get(&lvalue_id)) + .cloned(); + + if let Some(macro_def) = macro_def { + visit_operands( + &macro_def, + scope_id, + lvalue_id, + &instr.value, + env, + &mut macro_values, + macro_tags, + ); + } + } + + InstructionValue::JsxExpression { tag, .. } => { + let scope_id = match env.identifiers[lvalue_id.0 as usize].scope { + Some(s) => s, + None => continue, + }; + + let macro_def = match tag { + JsxTag::Place(place) => macro_tags.get(&place.identifier).cloned(), + JsxTag::Builtin(builtin) => macro_kinds.get(builtin.name.as_str()).cloned(), + }; + + let macro_def = macro_def.or_else(|| macro_tags.get(&lvalue_id).cloned()); + + if let Some(macro_def) = macro_def { + visit_operands( + &macro_def, + scope_id, + lvalue_id, + &instr.value, + env, + &mut macro_values, + macro_tags, + ); + } + } + + // Default case: check if lvalue is a macro tag + _ => { + let scope_id = match env.identifiers[lvalue_id.0 as usize].scope { + Some(s) => s, + None => continue, + }; + + let macro_def = macro_tags.get(&lvalue_id).cloned(); + if let Some(macro_def) = macro_def { + visit_operands( + &macro_def, + scope_id, + lvalue_id, + &instr.value, + env, + &mut macro_values, + macro_tags, + ); + } + } + } + } + + // Handle phis + let block = &func.body.blocks[&block_id]; + for phi in &block.phis { + let scope_id = match env.identifiers[phi.place.identifier.0 as usize].scope { + Some(s) => s, + None => continue, + }; + + let macro_def = match macro_tags.get(&phi.place.identifier).cloned() { + Some(def) => def, + None => continue, + }; + + if matches!(macro_def.level, InlineLevel::Shallow) { + continue; + } + + macro_values.insert(phi.place.identifier); + + // Collect operand updates to avoid borrow issues + let operand_updates: Vec<(IdentifierId, MacroDefinition)> = phi + .operands + .values() + 
.map(|operand| (operand.identifier, macro_def.clone())) + .collect(); + + for (operand_id, def) in operand_updates { + env.identifiers[operand_id.0 as usize].scope = Some(scope_id); + expand_fbt_scope_range(env, scope_id, operand_id); + macro_tags.insert(operand_id, def); + macro_values.insert(operand_id); + } + } + } + + macro_values +} + +/// Expand the scope range on the environment, reading from identifier's +/// mutable_range. Equivalent to TS `expandFbtScopeRange`. +fn expand_fbt_scope_range(env: &mut Environment, scope_id: ScopeId, operand_id: IdentifierId) { + let extend_start = env.identifiers[operand_id.0 as usize].mutable_range.start; + if extend_start.0 != 0 { + let scope = &mut env.scopes[scope_id.0 as usize]; + scope.range.start.0 = scope.range.start.0.min(extend_start.0); + } +} + +/// Visit operands for an instruction value, merging them into the same scope +/// if the macro definition requires transitive inlining. +fn visit_operands( + macro_def: &MacroDefinition, + scope_id: ScopeId, + lvalue_id: IdentifierId, + value: &InstructionValue, + env: &mut Environment, + macro_values: &mut HashSet, + macro_tags: &mut HashMap, +) { + macro_values.insert(lvalue_id); + + // Collect operand IDs first to avoid borrow issues with env + let operand_ids: Vec = + visitors::each_instruction_value_operand_with_functions(value, &env.functions) + .into_iter() + .map(|p| p.identifier) + .collect(); + for operand_id in operand_ids { + if matches!(macro_def.level, InlineLevel::Transitive) { + env.identifiers[operand_id.0 as usize].scope = Some(scope_id); + expand_fbt_scope_range(env, scope_id, operand_id); + macro_tags.insert(operand_id, macro_def.clone()); + } + macro_values.insert(operand_id); + } +} diff --git a/crates/react_compiler_inference/src/merge_overlapping_reactive_scopes_hir.rs b/crates/react_compiler_inference/src/merge_overlapping_reactive_scopes_hir.rs new file mode 100644 index 000000000000..ad57f4aee7f2 --- /dev/null +++ 
b/crates/react_compiler_inference/src/merge_overlapping_reactive_scopes_hir.rs @@ -0,0 +1,420 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Merges reactive scopes that have overlapping ranges. +//! +//! While previous passes ensure that reactive scopes span valid sets of program +//! blocks, pairs of reactive scopes may still be inconsistent with respect to +//! each other. Two scopes must either be entirely disjoint or one must be +//! nested within the other. This pass detects overlapping scopes and merges +//! them. +//! +//! Additionally, if an instruction mutates an outer scope while a different +//! scope is active, those scopes are merged. +//! +//! Ported from TypeScript `src/HIR/MergeOverlappingReactiveScopesHIR.ts`. + +use std::{cmp, collections::HashMap}; + +use react_compiler_hir::{ + environment::Environment, + visitors, + visitors::{each_instruction_lvalue_ids, each_terminal_operand_ids}, + EvaluationOrder, HirFunction, IdentifierId, InstructionValue, ScopeId, Type, +}; +use react_compiler_utils::DisjointSet; + +// ============================================================================= +// ScopeInfo +// ============================================================================= + +struct ScopeStartEntry { + id: EvaluationOrder, + scopes: Vec, +} + +struct ScopeEndEntry { + id: EvaluationOrder, + scopes: Vec, +} + +struct ScopeInfo { + /// Sorted descending by id (so we can pop from the end for smallest) + scope_starts: Vec, + /// Sorted descending by id (so we can pop from the end for smallest) + scope_ends: Vec, + /// Maps IdentifierId -> ScopeId for all places that have a scope + place_scopes: HashMap, +} + +// ============================================================================= +// TraversalState +// ============================================================================= + 
+struct TraversalState { + joined: DisjointSet, + active_scopes: Vec, +} + +// ============================================================================= +// Helper functions +// ============================================================================= + +/// Check if a scope is active at the given instruction id. +/// Corresponds to TS `isScopeActive(scope, id)`. +fn is_scope_active(env: &Environment, scope_id: ScopeId, id: EvaluationOrder) -> bool { + env.scopes[scope_id.0 as usize].range.contains(id) +} + +/// Get the scope for a place if it's active at the given instruction. +/// Corresponds to TS `getPlaceScope(id, place)`. +fn get_place_scope( + env: &Environment, + id: EvaluationOrder, + identifier_id: IdentifierId, +) -> Option { + let scope_id = env.identifiers[identifier_id.0 as usize].scope?; + if is_scope_active(env, scope_id, id) { + Some(scope_id) + } else { + None + } +} + +/// Check if a place is mutable at the given instruction. +/// Corresponds to TS `isMutable({id}, place)`. 
+fn is_mutable(env: &Environment, id: EvaluationOrder, identifier_id: IdentifierId) -> bool { + let range = &env.identifiers[identifier_id.0 as usize].mutable_range; + range.contains(id) +} + +// ============================================================================= +// collectScopeInfo +// ============================================================================= + +fn collect_scope_info(func: &HirFunction, env: &Environment) -> ScopeInfo { + let mut scope_starts_map: HashMap> = HashMap::new(); + let mut scope_ends_map: HashMap> = HashMap::new(); + let mut place_scopes: HashMap = HashMap::new(); + + let mut collect_place_scope = |identifier_id: IdentifierId, env: &Environment| { + let scope_id = match env.identifiers[identifier_id.0 as usize].scope { + Some(s) => s, + None => return, + }; + place_scopes.insert(identifier_id, scope_id); + let range = &env.scopes[scope_id.0 as usize].range; + if range.start != range.end { + scope_starts_map + .entry(range.start) + .or_default() + .push(scope_id); + scope_ends_map.entry(range.end).or_default().push(scope_id); + } + }; + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + // lvalues + let lvalue_ids = each_instruction_lvalue_ids(instr); + for id in lvalue_ids { + collect_place_scope(id, env); + } + // operands + let operand_ids: Vec = visitors::each_instruction_operand(instr, env) + .into_iter() + .map(|p| p.identifier) + .collect(); + for id in operand_ids { + collect_place_scope(id, env); + } + } + // terminal operands + let terminal_op_ids = each_terminal_operand_ids(&block.terminal); + for id in terminal_op_ids { + collect_place_scope(id, env); + } + } + + // Deduplicate scope IDs in each entry, preserving insertion order. + // The TS uses Set which preserves insertion order and + // deduplicates. 
We must NOT sort by ScopeId here — the insertion order + // determines which scope becomes the root in the disjoint set union. + fn dedup_preserve_order(scopes: &mut Vec) { + let mut seen = std::collections::HashSet::new(); + scopes.retain(|s| seen.insert(*s)); + } + for scopes in scope_starts_map.values_mut() { + dedup_preserve_order(scopes); + } + for scopes in scope_ends_map.values_mut() { + dedup_preserve_order(scopes); + } + + // Convert to sorted vecs (descending by id for pop-from-end) + let mut scope_starts: Vec = scope_starts_map + .into_iter() + .map(|(id, scopes)| ScopeStartEntry { id, scopes }) + .collect(); + scope_starts.sort_by(|a, b| b.id.cmp(&a.id)); + + let mut scope_ends: Vec = scope_ends_map + .into_iter() + .map(|(id, scopes)| ScopeEndEntry { id, scopes }) + .collect(); + scope_ends.sort_by(|a, b| b.id.cmp(&a.id)); + + ScopeInfo { + scope_starts, + scope_ends, + place_scopes, + } +} + +// ============================================================================= +// visitInstructionId +// ============================================================================= + +fn visit_instruction_id( + id: EvaluationOrder, + scope_info: &mut ScopeInfo, + state: &mut TraversalState, + env: &Environment, +) { + // Handle all scopes that end at this instruction + if let Some(top) = scope_info.scope_ends.last() { + if top.id <= id { + let scope_end_entry = scope_info.scope_ends.pop().unwrap(); + + // Sort scopes by start descending (matching active_scopes order) + let mut scopes_sorted = scope_end_entry.scopes; + scopes_sorted.sort_by(|a, b| { + let a_start = env.scopes[a.0 as usize].range.start; + let b_start = env.scopes[b.0 as usize].range.start; + b_start.cmp(&a_start) + }); + + for scope in &scopes_sorted { + let idx = state.active_scopes.iter().position(|s| s == scope); + if let Some(idx) = idx { + // Detect and merge all overlapping scopes + if idx != state.active_scopes.len() - 1 { + let mut to_union: Vec = vec![*scope]; + 
to_union.extend_from_slice(&state.active_scopes[idx + 1..]); + state.joined.union(&to_union); + } + state.active_scopes.remove(idx); + } + } + } + } + + // Handle all scopes that begin at this instruction + if let Some(top) = scope_info.scope_starts.last() { + if top.id <= id { + let scope_start_entry = scope_info.scope_starts.pop().unwrap(); + + // Sort by end descending + let mut scopes_sorted = scope_start_entry.scopes; + scopes_sorted.sort_by(|a, b| { + let a_end = env.scopes[a.0 as usize].range.end; + let b_end = env.scopes[b.0 as usize].range.end; + b_end.cmp(&a_end) + }); + + state.active_scopes.extend_from_slice(&scopes_sorted); + + // Merge all identical scopes (same start and end) + for i in 1..scopes_sorted.len() { + let prev = scopes_sorted[i - 1]; + let curr = scopes_sorted[i]; + if env.scopes[prev.0 as usize].range.end == env.scopes[curr.0 as usize].range.end { + state.joined.union(&[prev, curr]); + } + } + } + } +} + +// ============================================================================= +// visitPlace +// ============================================================================= + +fn visit_place( + id: EvaluationOrder, + identifier_id: IdentifierId, + state: &mut TraversalState, + env: &Environment, +) { + // If an instruction mutates an outer scope, flatten all scopes from top + // of the stack to the mutated outer scope + let place_scope = get_place_scope(env, id, identifier_id); + if let Some(scope_id) = place_scope { + if is_mutable(env, id, identifier_id) { + let place_scope_idx = state.active_scopes.iter().position(|s| *s == scope_id); + if let Some(idx) = place_scope_idx { + if idx != state.active_scopes.len() - 1 { + let mut to_union: Vec = vec![scope_id]; + to_union.extend_from_slice(&state.active_scopes[idx + 1..]); + state.joined.union(&to_union); + } + } + } + } +} + +// ============================================================================= +// getOverlappingReactiveScopes +// 
============================================================================= + +fn get_overlapping_reactive_scopes( + func: &HirFunction, + env: &Environment, + mut scope_info: ScopeInfo, +) -> DisjointSet { + let mut state = TraversalState { + joined: DisjointSet::::new(), + active_scopes: Vec::new(), + }; + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + visit_instruction_id(instr.id, &mut scope_info, &mut state, env); + + // Visit operands + let is_func_or_method = matches!( + &instr.value, + InstructionValue::FunctionExpression { .. } | InstructionValue::ObjectMethod { .. } + ); + let operand_ids = each_instruction_operand_ids_with_types(instr, env); + for (op_id, type_) in &operand_ids { + if is_func_or_method && matches!(type_, Type::Primitive) { + continue; + } + visit_place(instr.id, *op_id, &mut state, env); + } + + // Visit lvalues + let lvalue_ids = each_instruction_lvalue_ids(instr); + for lvalue_id in lvalue_ids { + visit_place(instr.id, lvalue_id, &mut state, env); + } + } + + let terminal_id = block.terminal.evaluation_order(); + visit_instruction_id(terminal_id, &mut scope_info, &mut state, env); + + let terminal_op_ids = each_terminal_operand_ids(&block.terminal); + for op_id in terminal_op_ids { + visit_place(terminal_id, op_id, &mut state, env); + } + } + + state.joined +} + +// ============================================================================= +// Public API +// ============================================================================= + +/// Merges reactive scopes that have overlapping ranges. +/// +/// Corresponds to TS `mergeOverlappingReactiveScopesHIR(fn: HIRFunction): +/// void`. 
+pub fn merge_overlapping_reactive_scopes_hir(func: &mut HirFunction, env: &mut Environment) { + // Collect scope info + let scope_info = collect_scope_info(func, env); + + // Save place_scopes before moving scope_info + let place_scopes = scope_info.place_scopes.clone(); + + // Find overlapping scopes + let mut joined_scopes = get_overlapping_reactive_scopes(func, env, scope_info); + + // Merge scope ranges: collect all (scope, root) pairs, then update root ranges + // by accumulating min start / max end from all members of each group. + // This matches TS behavior where groupScope.range is updated in-place during + // iteration. + let mut scope_groups: Vec<(ScopeId, ScopeId)> = Vec::new(); + joined_scopes.for_each(|scope_id, root_id| { + if scope_id != root_id { + scope_groups.push((scope_id, root_id)); + } + }); + // Collect root scopes' ORIGINAL ranges BEFORE updating them. + // In TS, identifier.mutableRange shares the same object reference as + // scope.range. When scope.range is updated, ALL identifiers referencing + // that range object automatically see the new values — even identifiers + // whose scope was later set to null. In Rust, we must explicitly find and + // update identifiers whose mutable_range matches a root scope's original + // range. 
+ let mut original_root_ranges: HashMap = + HashMap::new(); + for (_, root_id) in &scope_groups { + if !original_root_ranges.contains_key(root_id) { + let range = &env.scopes[root_id.0 as usize].range; + original_root_ranges.insert(*root_id, (range.start, range.end)); + } + } + + // Update root scope ranges + for (scope_id, root_id) in &scope_groups { + let scope_start = env.scopes[scope_id.0 as usize].range.start; + let scope_end = env.scopes[scope_id.0 as usize].range.end; + let root_range = &mut env.scopes[root_id.0 as usize].range; + root_range.start = EvaluationOrder(cmp::min(root_range.start.0, scope_start.0)); + root_range.end = EvaluationOrder(cmp::max(root_range.end.0, scope_end.0)); + } + // Sync mutable_range for ALL identifiers whose mutable_range matches the + // ORIGINAL range of a root scope that was updated. In TS, + // identifier.mutableRange shares the same object reference as scope.range, + // so when scope.range is updated, all identifiers referencing that range + // object automatically see the new values — even identifiers whose scope + // was later set to null. In Rust, we must explicitly find and update these. + for ident in &mut env.identifiers { + for (root_id, (orig_start, orig_end)) in &original_root_ranges { + if ident.mutable_range.start == *orig_start && ident.mutable_range.end == *orig_end { + let new_range = &env.scopes[root_id.0 as usize].range; + ident.mutable_range.start = new_range.start; + ident.mutable_range.end = new_range.end; + break; + } + } + } + + // Rewrite all references: for each place that had a scope, point to the merged + // root. Note: we intentionally do NOT update mutable_range for repointed + // identifiers, matching TS behavior where identifier.mutableRange still + // references the old scope's range object after scope repointing. 
+ for (identifier_id, original_scope) in &place_scopes { + let next_scope = joined_scopes.find(*original_scope); + if next_scope != *original_scope { + env.identifiers[identifier_id.0 as usize].scope = Some(next_scope); + } + } +} + +// ============================================================================= +// Instruction visitor helpers (delegating to canonical visitors) +// ============================================================================= + +/// Collect operand IdentifierIds with their types from an instruction value. +/// Used to check for Primitive type on FunctionExpression/ObjectMethod +/// operands. +fn each_instruction_operand_ids_with_types( + instr: &react_compiler_hir::Instruction, + env: &Environment, +) -> Vec<(IdentifierId, Type)> { + visitors::each_instruction_operand(instr, env) + .into_iter() + .map(|p| { + let type_ = + env.types[env.identifiers[p.identifier.0 as usize].type_.0 as usize].clone(); + (p.identifier, type_) + }) + .collect() +} diff --git a/crates/react_compiler_inference/src/propagate_scope_dependencies_hir.rs b/crates/react_compiler_inference/src/propagate_scope_dependencies_hir.rs new file mode 100644 index 000000000000..9445a2a56d60 --- /dev/null +++ b/crates/react_compiler_inference/src/propagate_scope_dependencies_hir.rs @@ -0,0 +1,2337 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Propagates scope dependencies through the HIR, computing which values each +//! reactive scope depends on. +//! +//! Ported from TypeScript: +//! - `src/HIR/PropagateScopeDependenciesHIR.ts` +//! - `src/HIR/CollectOptionalChainDependencies.ts` +//! - `src/HIR/CollectHoistablePropertyLoads.ts` +//! 
- `src/HIR/DeriveMinimalDependenciesHIR.ts` + +use std::collections::{BTreeSet, HashMap, HashSet}; + +use indexmap::IndexMap; +use react_compiler_hir::{ + environment::Environment, + visitors, + visitors::{ScopeBlockInfo, ScopeBlockTraversal}, + BasicBlock, BlockId, DeclarationId, DependencyPathEntry, EvaluationOrder, FunctionId, + GotoVariant, HirFunction, IdentifierId, Instruction, InstructionId, InstructionKind, + InstructionValue, MutableRange, ParamPattern, Place, PlaceOrSpread, PropertyLiteral, + ReactFunctionType, ReactiveScopeDependency, ScopeId, Terminal, Type, +}; + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Main entry point: propagate scope dependencies through the HIR. +/// Corresponds to TS `propagateScopeDependenciesHIR(fn)`. +pub fn propagate_scope_dependencies_hir(func: &mut HirFunction, env: &mut Environment) { + let used_outside_declaring_scope = find_temporaries_used_outside_declaring_scope(func, env); + let temporaries = collect_temporaries_sidemap(func, env, &used_outside_declaring_scope); + + let OptionalChainSidemap { + temporaries_read_in_optional, + processed_instrs_in_optional, + hoistable_objects, + } = collect_optional_chain_sidemap(func, env); + + let hoistable_property_loads = { + let (working, registry) = + collect_hoistable_and_propagate(func, env, &temporaries, &hoistable_objects); + // Convert to scope-keyed map with full dependency paths + let mut keyed: HashMap> = HashMap::new(); + for (_block_id, block) in &func.body.blocks { + if let Terminal::Scope { + scope, + block: inner_block, + .. 
+ } = &block.terminal + { + if let Some(node_indices) = working.get(inner_block) { + let deps: Vec = node_indices + .iter() + .map(|&idx| registry.nodes[idx].full_path.clone()) + .collect(); + keyed.insert(*scope, deps); + } + } + } + keyed + }; + + // Merge temporaries + temporariesReadInOptional + let mut merged_temporaries = temporaries; + for (k, v) in temporaries_read_in_optional { + merged_temporaries.insert(k, v); + } + + let scope_deps = collect_dependencies( + func, + env, + &used_outside_declaring_scope, + &merged_temporaries, + &processed_instrs_in_optional, + ); + + // Derive the minimal set of hoistable dependencies for each scope. + for (scope_id, deps) in &scope_deps { + if deps.is_empty() { + continue; + } + + let hoistables = hoistable_property_loads.get(scope_id); + let hoistables = + hoistables.expect("[PropagateScopeDependencies] Scope not found in tracked blocks"); + + // Step 2: Calculate hoistable dependencies using the tree. + let mut tree = ReactiveScopeDependencyTreeHIR::new(hoistables.iter(), env); + for dep in deps { + tree.add_dependency(dep.clone(), env); + } + + // Step 3: Reduce dependencies to a minimal set. 
+ let candidates = tree.derive_minimal_dependencies(env); + let scope = &mut env.scopes[scope_id.0 as usize]; + for candidate_dep in candidates { + let already_exists = scope.dependencies.iter().any(|existing_dep| { + let existing_decl_id = + env.identifiers[existing_dep.identifier.0 as usize].declaration_id; + let candidate_decl_id = + env.identifiers[candidate_dep.identifier.0 as usize].declaration_id; + existing_decl_id == candidate_decl_id + && are_equal_paths(&existing_dep.path, &candidate_dep.path) + }); + if !already_exists { + scope.dependencies.push(candidate_dep); + } + } + } +} + +fn are_equal_paths(a: &[DependencyPathEntry], b: &[DependencyPathEntry]) -> bool { + a.len() == b.len() + && a.iter() + .zip(b.iter()) + .all(|(ai, bi)| ai.property == bi.property && ai.optional == bi.optional) +} + +// ============================================================================= +// findTemporariesUsedOutsideDeclaringScope +// ============================================================================= + +/// Corresponds to TS `findTemporariesUsedOutsideDeclaringScope`. 
+fn find_temporaries_used_outside_declaring_scope( + func: &HirFunction, + env: &Environment, +) -> HashSet { + let mut declarations: HashMap = HashMap::new(); + let mut pruned_scopes: HashSet = HashSet::new(); + let mut traversal = ScopeBlockTraversal::new(); + let mut used_outside_declaring_scope: HashSet = HashSet::new(); + + let handle_place = |place_id: IdentifierId, + declarations: &HashMap, + traversal: &ScopeBlockTraversal, + pruned_scopes: &HashSet, + used_outside: &mut HashSet, + env: &Environment| { + let decl_id = env.identifiers[place_id.0 as usize].declaration_id; + if let Some(&declaring_scope) = declarations.get(&decl_id) { + if !traversal.is_scope_active(declaring_scope) + && !pruned_scopes.contains(&declaring_scope) + { + used_outside.insert(decl_id); + } + } + }; + + for (block_id, block) in &func.body.blocks { + // recordScopes + traversal.record_scopes(block); + + let scope_start_info = traversal.block_infos.get(block_id); + if let Some(ScopeBlockInfo::Begin { + scope, + pruned: true, + .. + }) = scope_start_info + { + pruned_scopes.insert(*scope); + } + + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + // Handle operands + for op_id in visitors::each_instruction_operand(instr, env) + .into_iter() + .map(|p| p.identifier) + .collect::>() + { + handle_place( + op_id, + &declarations, + &traversal, + &pruned_scopes, + &mut used_outside_declaring_scope, + env, + ); + } + // Handle instruction (track declarations) + let current_scope = traversal.current_scope(); + if let Some(scope) = current_scope { + if !pruned_scopes.contains(&scope) { + match &instr.value { + InstructionValue::LoadLocal { .. } + | InstructionValue::LoadContext { .. } + | InstructionValue::PropertyLoad { .. 
} => { + let decl_id = + env.identifiers[instr.lvalue.identifier.0 as usize].declaration_id; + declarations.insert(decl_id, scope); + } + _ => {} + } + } + } + } + + // Terminal operands + for op_id in visitors::each_terminal_operand(&block.terminal) + .into_iter() + .map(|p| p.identifier) + .collect::>() + { + handle_place( + op_id, + &declarations, + &traversal, + &pruned_scopes, + &mut used_outside_declaring_scope, + env, + ); + } + } + + used_outside_declaring_scope +} + +// ============================================================================= +// collectTemporariesSidemap +// ============================================================================= + +/// Corresponds to TS `collectTemporariesSidemap`. +fn collect_temporaries_sidemap( + func: &HirFunction, + env: &Environment, + used_outside_declaring_scope: &HashSet, +) -> HashMap { + let mut temporaries = HashMap::new(); + collect_temporaries_sidemap_impl( + func, + env, + used_outside_declaring_scope, + &mut temporaries, + None, + ); + temporaries +} + +/// Corresponds to TS `isLoadContextMutable`. +fn is_load_context_mutable( + value: &InstructionValue, + id: EvaluationOrder, + env: &Environment, +) -> bool { + if let InstructionValue::LoadContext { place, .. } = value { + if let Some(scope_id) = env.identifiers[place.identifier.0 as usize].scope { + let scope_range = &env.scopes[scope_id.0 as usize].range; + return id >= scope_range.end; + } + } + false +} + +/// Corresponds to TS `convertHoistedLValueKind` — returns None for non-hoisted +/// kinds. +fn convert_hoisted_lvalue_kind(kind: InstructionKind) -> Option { + match kind { + InstructionKind::HoistedLet => Some(InstructionKind::Let), + InstructionKind::HoistedConst => Some(InstructionKind::Const), + InstructionKind::HoistedFunction => Some(InstructionKind::Function), + _ => None, + } +} + +/// Recursive implementation. Corresponds to TS `collectTemporariesSidemapImpl`. 
+fn collect_temporaries_sidemap_impl( + func: &HirFunction, + env: &Environment, + used_outside_declaring_scope: &HashSet, + temporaries: &mut HashMap, + inner_fn_context: Option, +) { + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let instr_eval_order = if let Some(outer_id) = inner_fn_context { + outer_id + } else { + instr.id + }; + let lvalue_decl_id = env.identifiers[instr.lvalue.identifier.0 as usize].declaration_id; + let used_outside = used_outside_declaring_scope.contains(&lvalue_decl_id); + + match &instr.value { + InstructionValue::PropertyLoad { + object, + property, + loc, + .. + } if !used_outside => { + if inner_fn_context.is_none() || temporaries.contains_key(&object.identifier) { + let prop = get_property(object, property, false, *loc, temporaries, env); + temporaries.insert(instr.lvalue.identifier, prop); + } + } + InstructionValue::LoadLocal { place, loc, .. } + if env.identifiers[instr.lvalue.identifier.0 as usize] + .name + .is_none() + && env.identifiers[place.identifier.0 as usize].name.is_some() + && !used_outside => + { + if inner_fn_context.is_none() + || func + .context + .iter() + .any(|ctx| ctx.identifier == place.identifier) + { + temporaries.insert( + instr.lvalue.identifier, + ReactiveScopeDependency { + identifier: place.identifier, + reactive: place.reactive, + path: vec![], + loc: *loc, + }, + ); + } + } + value @ InstructionValue::LoadContext { place, loc, .. 
} + if is_load_context_mutable(value, instr_eval_order, env) + && env.identifiers[instr.lvalue.identifier.0 as usize] + .name + .is_none() + && env.identifiers[place.identifier.0 as usize].name.is_some() + && !used_outside => + { + if inner_fn_context.is_none() + || func + .context + .iter() + .any(|ctx| ctx.identifier == place.identifier) + { + temporaries.insert( + instr.lvalue.identifier, + ReactiveScopeDependency { + identifier: place.identifier, + reactive: place.reactive, + path: vec![], + loc: *loc, + }, + ); + } + } + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + let inner_func = &env.functions[lowered_func.func.0 as usize]; + let ctx = inner_fn_context.unwrap_or(instr.id); + collect_temporaries_sidemap_impl( + inner_func, + env, + used_outside_declaring_scope, + temporaries, + Some(ctx), + ); + } + _ => {} + } + } + } +} + +/// Corresponds to TS `getProperty`. +fn get_property( + object: &Place, + property_name: &PropertyLiteral, + optional: bool, + loc: Option, + temporaries: &HashMap, + _env: &Environment, +) -> ReactiveScopeDependency { + let resolved = temporaries.get(&object.identifier); + if let Some(resolved) = resolved { + let mut path = resolved.path.clone(); + path.push(DependencyPathEntry { + property: property_name.clone(), + optional, + loc, + }); + ReactiveScopeDependency { + identifier: resolved.identifier, + reactive: resolved.reactive, + path, + loc, + } + } else { + ReactiveScopeDependency { + identifier: object.identifier, + reactive: object.reactive, + path: vec![DependencyPathEntry { + property: property_name.clone(), + optional, + loc, + }], + loc, + } + } +} + +// ============================================================================= +// CollectOptionalChainDependencies +// ============================================================================= + +struct OptionalChainSidemap { + temporaries_read_in_optional: HashMap, + 
processed_instrs_in_optional: HashSet, + hoistable_objects: HashMap, +} + +/// We track processed instructions/terminals by their lvalue IdentifierId + +/// block id. In TS this uses reference identity (Set). +/// We use IdentifierId for instructions (globally unique across functions) and +/// BlockId for terminals. Note: EvaluationOrder (instruction id) is NOT unique +/// across functions, so we cannot use it here. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum ProcessedInstr { + Instruction(IdentifierId), + Terminal(BlockId), +} + +fn collect_optional_chain_sidemap(func: &HirFunction, env: &Environment) -> OptionalChainSidemap { + let mut ctx = OptionalTraversalContext { + seen_optionals: HashSet::new(), + processed_instrs_in_optional: HashSet::new(), + temporaries_read_in_optional: HashMap::new(), + hoistable_objects: HashMap::new(), + }; + + traverse_function_optional(func, env, &mut ctx); + + OptionalChainSidemap { + temporaries_read_in_optional: ctx.temporaries_read_in_optional, + processed_instrs_in_optional: ctx.processed_instrs_in_optional, + hoistable_objects: ctx.hoistable_objects, + } +} + +struct OptionalTraversalContext { + seen_optionals: HashSet, + processed_instrs_in_optional: HashSet, + temporaries_read_in_optional: HashMap, + hoistable_objects: HashMap, +} + +fn traverse_function_optional( + func: &HirFunction, + env: &Environment, + ctx: &mut OptionalTraversalContext, +) { + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + let inner_func = &env.functions[lowered_func.func.0 as usize]; + traverse_function_optional(inner_func, env, ctx); + } + _ => {} + } + } + if let Terminal::Optional { .. 
} = &block.terminal { + if !ctx.seen_optionals.contains(&block.id) { + traverse_optional_block(block, func, env, ctx, None); + } + } + } +} + +struct MatchConsequentResult { + consequent_id: IdentifierId, + property: PropertyLiteral, + property_id: IdentifierId, + store_local_lvalue_id: IdentifierId, + consequent_goto: BlockId, + property_load_loc: Option, +} + +fn match_optional_test_block( + test: &Terminal, + func: &HirFunction, + _env: &Environment, +) -> Option { + let (test_place, consequent_block_id, alternate_block_id) = match test { + Terminal::Branch { + test, + consequent, + alternate, + .. + } => (test, *consequent, *alternate), + _ => return None, + }; + + let consequent_block = func.body.blocks.get(&consequent_block_id)?; + if consequent_block.instructions.len() != 2 { + return None; + } + + let instr0 = &func.instructions[consequent_block.instructions[0].0 as usize]; + let instr1 = &func.instructions[consequent_block.instructions[1].0 as usize]; + + let (property_load_object, property, property_load_loc) = match &instr0.value { + InstructionValue::PropertyLoad { + object, + property, + loc, + } => (object, property, loc), + _ => return None, + }; + + let store_local_value = match &instr1.value { + InstructionValue::StoreLocal { value, lvalue, .. } => { + // Verify the store local's value matches the property load's lvalue + if value.identifier != instr0.lvalue.identifier { + return None; + } + &lvalue.place + } + _ => return None, + }; + + // Verify property load's object matches the test + if property_load_object.identifier != test_place.identifier { + return None; + } + + // Check consequent block terminal is goto break + match &consequent_block.terminal { + Terminal::Goto { + variant: GotoVariant::Break, + block: goto_block, + .. 
+ } => { + // Verify alternate block structure + let alternate_block = func.body.blocks.get(&alternate_block_id)?; + if alternate_block.instructions.len() != 2 { + return None; + } + let alt_instr0 = &func.instructions[alternate_block.instructions[0].0 as usize]; + let alt_instr1 = &func.instructions[alternate_block.instructions[1].0 as usize]; + match (&alt_instr0.value, &alt_instr1.value) { + (InstructionValue::Primitive { .. }, InstructionValue::StoreLocal { .. }) => {} + _ => return None, + } + + Some(MatchConsequentResult { + consequent_id: store_local_value.identifier, + property: property.clone(), + property_id: instr0.lvalue.identifier, + store_local_lvalue_id: instr1.lvalue.identifier, + consequent_goto: *goto_block, + property_load_loc: *property_load_loc, + }) + } + _ => None, + } +} + +fn traverse_optional_block( + optional_block: &BasicBlock, + func: &HirFunction, + env: &Environment, + ctx: &mut OptionalTraversalContext, + outer_alternate: Option, +) -> Option { + ctx.seen_optionals.insert(optional_block.id); + + let (test_block_id, is_optional, fallthrough_block_id) = match &optional_block.terminal { + Terminal::Optional { + test, + optional, + fallthrough, + .. + } => (*test, *optional, *fallthrough), + _ => return None, + }; + + let maybe_test_block = func.body.blocks.get(&test_block_id)?; + + let (test_terminal, base_object) = match &maybe_test_block.terminal { + Terminal::Branch { .. } => { + // Base case: optional must be true + if !is_optional { + return None; + } + // Match base expression that is straightforward PropertyLoad chain + if maybe_test_block.instructions.is_empty() { + return None; + } + let first_instr = &func.instructions[maybe_test_block.instructions[0].0 as usize]; + if !matches!(&first_instr.value, InstructionValue::LoadLocal { .. 
}) { + return None; + } + + let mut path: Vec = Vec::new(); + for i in 1..maybe_test_block.instructions.len() { + let curr_instr = &func.instructions[maybe_test_block.instructions[i].0 as usize]; + let prev_instr = + &func.instructions[maybe_test_block.instructions[i - 1].0 as usize]; + match &curr_instr.value { + InstructionValue::PropertyLoad { + object, + property, + loc, + .. + } if object.identifier == prev_instr.lvalue.identifier => { + path.push(DependencyPathEntry { + property: property.clone(), + optional: false, + loc: *loc, + }); + } + _ => return None, + } + } + + // Verify test expression matches last instruction's lvalue + let last_instr_id = *maybe_test_block.instructions.last().unwrap(); + let last_instr = &func.instructions[last_instr_id.0 as usize]; + let test_ident = match &maybe_test_block.terminal { + Terminal::Branch { test, .. } => test.identifier, + _ => return None, + }; + if test_ident != last_instr.lvalue.identifier { + return None; + } + + let first_place = match &first_instr.value { + InstructionValue::LoadLocal { place, .. } => place, + _ => return None, + }; + + let base = ReactiveScopeDependency { + identifier: first_place.identifier, + reactive: first_place.reactive, + path, + loc: first_place.loc, + }; + (&maybe_test_block.terminal, base) + } + Terminal::Optional { + fallthrough: inner_fallthrough, + optional: _inner_optional, + .. + } => { + let test_block = func.body.blocks.get(inner_fallthrough)?; + if !matches!(&test_block.terminal, Terminal::Branch { .. }) { + return None; + } + + // Recurse into inner optional + let inner_alternate = match &test_block.terminal { + Terminal::Branch { alternate, .. } => Some(*alternate), + _ => None, + }; + let inner_optional_result = + traverse_optional_block(maybe_test_block, func, env, ctx, inner_alternate); + let inner_optional_id = inner_optional_result?; + + // Check that inner optional is part of the same chain + let test_ident = match &test_block.terminal { + Terminal::Branch { test, .. 
} => test.identifier, + _ => return None, + }; + if test_ident != inner_optional_id { + return None; + } + + if !is_optional { + // Non-optional load: record that PropertyLoads from inner optional are + // hoistable + if let Some(inner_dep) = ctx.temporaries_read_in_optional.get(&inner_optional_id) { + ctx.hoistable_objects + .insert(optional_block.id, inner_dep.clone()); + } + } + + let base = ctx + .temporaries_read_in_optional + .get(&inner_optional_id)? + .clone(); + (&test_block.terminal, base) + } + _ => return None, + }; + + // Verify alternate matches outer_alternate if present + if let Some(outer_alt) = outer_alternate { + let test_alternate = match test_terminal { + Terminal::Branch { alternate, .. } => *alternate, + _ => return None, + }; + if test_alternate == outer_alt { + // Verify optional block has no instructions + if !optional_block.instructions.is_empty() { + return None; + } + } + } + + let match_result = match_optional_test_block(test_terminal, func, env)?; + + // Verify consequent goto matches optional fallthrough + if match_result.consequent_goto != fallthrough_block_id { + return None; + } + + let load = ReactiveScopeDependency { + identifier: base_object.identifier, + reactive: base_object.reactive, + path: { + let mut p = base_object.path.clone(); + p.push(DependencyPathEntry { + property: match_result.property.clone(), + optional: is_optional, + loc: match_result.property_load_loc, + }); + p + }, + loc: match_result.property_load_loc, + }; + + ctx.processed_instrs_in_optional + .insert(ProcessedInstr::Instruction( + match_result.store_local_lvalue_id, + )); + ctx.processed_instrs_in_optional + .insert(ProcessedInstr::Terminal(match &test_terminal { + Terminal::Branch { .. } => { + // Find the block ID for this terminal + // The terminal belongs to either maybe_test_block or the fallthrough block of + // inner optional We need to identify which block this terminal + // belongs to. For the base case, it's test_block_id. 
+ // For nested optional, it's the fallthrough block. + // We'll use the block_id approach based on what we know. + // Actually, we tracked the terminal by its block, so we need to find which + // block contains this terminal. Let's use a pragmatic approach: + // The test terminal we matched was from maybe_test_block or from the inner + // fallthrough block. We'll search for it. + + // For the base case (Branch terminal at maybe_test_block), block_id = + // test_block_id For the nested case, the test terminal is at + // the fallthrough block of inner optional In either case, we + // stored the terminal as test_terminal which comes from a known block. + // We need to find the block that owns this terminal. + + // Let's take a simpler approach: find the block whose terminal matches + // This is the block we got test_terminal from. + // In the first branch of the match, test_terminal = &maybe_test_block.terminal + // and maybe_test_block.id = test_block_id + // In the second branch, test_terminal = &test_block.terminal + // and test_block = func.body.blocks.get(inner_fallthrough) + // We can't easily tell which case we're in here since we're past the match. + + // Actually, since test_terminal is a reference to a terminal in a block, + // we can just look up which block it belongs to by finding blocks whose + // terminal pointer matches. But that's expensive. Instead, + // let's use the block approach and find the block from the + // terminal's properties. + + // For simplicity, use a sentinel approach: just check all blocks. + // This is O(n) but only happens for optional chains. 
+ let mut found_block = BlockId(0); + for (bid, blk) in &func.body.blocks { + if std::ptr::eq(&blk.terminal, test_terminal) { + found_block = *bid; + break; + } + } + found_block + } + _ => BlockId(0), + })); + ctx.temporaries_read_in_optional + .insert(match_result.consequent_id, load.clone()); + ctx.temporaries_read_in_optional + .insert(match_result.property_id, load); + + Some(match_result.consequent_id) +} + +// ============================================================================= +// CollectHoistablePropertyLoads +// ============================================================================= + +#[derive(Debug, Clone)] +struct PropertyPathNode { + properties: HashMap, // index into registry + optional_properties: HashMap, // index into registry + #[allow(dead_code)] + parent: Option, + full_path: ReactiveScopeDependency, + has_optional: bool, + #[allow(dead_code)] + root: Option, +} + +struct PropertyPathRegistry { + nodes: Vec, + roots: HashMap, +} + +impl PropertyPathRegistry { + fn new() -> Self { + Self { + nodes: Vec::new(), + roots: HashMap::new(), + } + } + + fn get_or_create_identifier( + &mut self, + identifier_id: IdentifierId, + reactive: bool, + loc: Option, + ) -> usize { + if let Some(&idx) = self.roots.get(&identifier_id) { + return idx; + } + let idx = self.nodes.len(); + self.nodes.push(PropertyPathNode { + properties: HashMap::new(), + optional_properties: HashMap::new(), + parent: None, + full_path: ReactiveScopeDependency { + identifier: identifier_id, + reactive, + path: vec![], + loc, + }, + has_optional: false, + root: Some(identifier_id), + }); + self.roots.insert(identifier_id, idx); + idx + } + + fn get_or_create_property_entry( + &mut self, + parent_idx: usize, + entry: &DependencyPathEntry, + ) -> usize { + let map_key = entry.property.clone(); + let existing = if entry.optional { + self.nodes[parent_idx] + .optional_properties + .get(&map_key) + .copied() + } else { + 
self.nodes[parent_idx].properties.get(&map_key).copied() + }; + if let Some(idx) = existing { + return idx; + } + let parent_full_path = self.nodes[parent_idx].full_path.clone(); + let parent_has_optional = self.nodes[parent_idx].has_optional; + let idx = self.nodes.len(); + let mut new_path = parent_full_path.path.clone(); + new_path.push(entry.clone()); + self.nodes.push(PropertyPathNode { + properties: HashMap::new(), + optional_properties: HashMap::new(), + parent: Some(parent_idx), + full_path: ReactiveScopeDependency { + identifier: parent_full_path.identifier, + reactive: parent_full_path.reactive, + path: new_path, + loc: entry.loc, + }, + has_optional: parent_has_optional || entry.optional, + root: None, + }); + if entry.optional { + self.nodes[parent_idx] + .optional_properties + .insert(map_key, idx); + } else { + self.nodes[parent_idx].properties.insert(map_key, idx); + } + idx + } + + fn get_or_create_property(&mut self, dep: &ReactiveScopeDependency) -> usize { + let mut curr = self.get_or_create_identifier(dep.identifier, dep.reactive, dep.loc); + for entry in &dep.path { + curr = self.get_or_create_property_entry(curr, entry); + } + curr + } +} + +/// Reduces optional chains in a set of property path nodes. +/// +/// Any two optional chains with different operations (`.` vs `?.`) but the same +/// set of property string paths de-duplicates. If unconditional reads from +/// `` are hoistable (i.e., `` is in the set), we replace +/// `?.PROPERTY` with `.PROPERTY`. +/// +/// Port of `reduceMaybeOptionalChains` from CollectHoistablePropertyLoads.ts. 
+fn reduce_maybe_optional_chains(nodes: &mut BTreeSet, registry: &mut PropertyPathRegistry) { + // Collect indices of nodes that have optional in their path + let mut optional_chain_nodes: BTreeSet = nodes + .iter() + .copied() + .filter(|&idx| registry.nodes[idx].has_optional) + .collect(); + + if optional_chain_nodes.is_empty() { + return; + } + + loop { + let mut changed = false; + + // Collect the indices to process (snapshot to avoid borrow issues) + let to_process: Vec = optional_chain_nodes.iter().copied().collect(); + + for original_idx in to_process { + let full_path = registry.nodes[original_idx].full_path.clone(); + + let mut curr_node = registry.get_or_create_identifier( + full_path.identifier, + full_path.reactive, + full_path.loc, + ); + + for entry in &full_path.path { + // If the base is known to be non-null (in the set), replace optional with + // non-optional + let next_entry = if entry.optional && nodes.contains(&curr_node) { + DependencyPathEntry { + property: entry.property.clone(), + optional: false, + loc: entry.loc, + } + } else { + entry.clone() + }; + curr_node = registry.get_or_create_property_entry(curr_node, &next_entry); + } + + if curr_node != original_idx { + changed = true; + optional_chain_nodes.remove(&original_idx); + optional_chain_nodes.insert(curr_node); + nodes.remove(&original_idx); + nodes.insert(curr_node); + } + } + + if !changed { + break; + } + } +} + +#[derive(Debug, Clone)] +struct BlockInfo { + assumed_non_null_objects: BTreeSet, // indices into PropertyPathRegistry +} + +#[allow(dead_code)] +fn collect_hoistable_property_loads( + func: &HirFunction, + env: &Environment, + temporaries: &HashMap, + hoistable_from_optionals: &HashMap, +) -> HashMap { + let mut registry = PropertyPathRegistry::new(); + let known_immutable_identifiers: HashSet = if func.fn_type + == ReactFunctionType::Component + || func.fn_type == ReactFunctionType::Hook + { + func.params + .iter() + .filter_map(|p| match p { + 
ParamPattern::Place(place) => Some(place.identifier), + _ => None, + }) + .collect() + } else { + HashSet::new() + }; + + let assumed_invoked_fns = get_assumed_invoked_functions(func, env); + let ctx = CollectHoistableContext { + temporaries, + known_immutable_identifiers: &known_immutable_identifiers, + hoistable_from_optionals, + nested_fn_immutable_context: None, + assumed_invoked_fns: &assumed_invoked_fns, + }; + + collect_hoistable_property_loads_impl(func, env, &ctx, &mut registry) +} + +struct CollectHoistableContext<'a> { + temporaries: &'a HashMap, + known_immutable_identifiers: &'a HashSet, + hoistable_from_optionals: &'a HashMap, + nested_fn_immutable_context: Option<&'a HashSet>, + assumed_invoked_fns: &'a HashSet, +} + +fn is_immutable_at_instr( + identifier_id: IdentifierId, + instr_id: EvaluationOrder, + env: &Environment, + ctx: &CollectHoistableContext, +) -> bool { + if let Some(nested_ctx) = ctx.nested_fn_immutable_context { + return nested_ctx.contains(&identifier_id); + } + let ident = &env.identifiers[identifier_id.0 as usize]; + let mutable_at_instr = ident.mutable_range.end + > EvaluationOrder(ident.mutable_range.start.0 + 1) + && ident.scope.is_some() + && { + let scope = &env.scopes[ident.scope.unwrap().0 as usize]; + in_range(instr_id, &scope.range) + }; + !mutable_at_instr || ctx.known_immutable_identifiers.contains(&identifier_id) +} + +fn in_range(id: EvaluationOrder, range: &MutableRange) -> bool { + id >= range.start && id < range.end +} + +fn get_maybe_non_null_in_instruction( + value: &InstructionValue, + temporaries: &HashMap, +) -> Option { + match value { + InstructionValue::PropertyLoad { object, .. } => Some( + temporaries + .get(&object.identifier) + .cloned() + .unwrap_or_else(|| ReactiveScopeDependency { + identifier: object.identifier, + reactive: object.reactive, + path: vec![], + loc: object.loc, + }), + ), + InstructionValue::Destructure { value: val, .. 
} => { + temporaries.get(&val.identifier).cloned() + } + InstructionValue::ComputedLoad { object, .. } => { + temporaries.get(&object.identifier).cloned() + } + _ => None, + } +} + +#[allow(dead_code)] +fn collect_hoistable_property_loads_impl( + func: &HirFunction, + env: &Environment, + ctx: &CollectHoistableContext, + registry: &mut PropertyPathRegistry, +) -> HashMap { + let nodes = collect_non_nulls_in_blocks(func, env, ctx, registry); + let working = propagate_non_null(func, &nodes, registry); + // Return the propagated results, converting HashSet back to BlockInfo + working + .into_iter() + .map(|(k, v)| { + ( + k, + BlockInfo { + assumed_non_null_objects: v, + }, + ) + }) + .collect() +} + +/// Corresponds to TS `getAssumedInvokedFunctions`. +/// Returns the set of LoweredFunction FunctionIds that are assumed to be +/// invoked. The `temporaries` map is shared across recursive calls (matching TS +/// behavior where the same Map is passed to recursive invocations for inner +/// functions). +fn get_assumed_invoked_functions(func: &HirFunction, env: &Environment) -> HashSet { + let mut temporaries: HashMap)> = HashMap::new(); + get_assumed_invoked_functions_impl(func, env, &mut temporaries) +} + +fn get_assumed_invoked_functions_impl( + func: &HirFunction, + env: &Environment, + temporaries: &mut HashMap)>, +) -> HashSet { + let mut hoistable: HashSet = HashSet::new(); + + // Step 1: Collect identifier to function expression mappings + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } => { + temporaries + .insert(instr.lvalue.identifier, (lowered_func.func, HashSet::new())); + } + InstructionValue::StoreLocal { + value: val, lvalue, .. 
+ } => { + if let Some(entry) = temporaries.get(&val.identifier).cloned() { + temporaries.insert(lvalue.place.identifier, entry); + } + } + InstructionValue::LoadLocal { place, .. } => { + if let Some(entry) = temporaries.get(&place.identifier).cloned() { + temporaries.insert(instr.lvalue.identifier, entry); + } + } + _ => {} + } + } + } + + // Step 2: Forward pass to analyze assumed function calls + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::CallExpression { callee, args, .. } => { + let callee_ty = + &env.types[env.identifiers[callee.identifier.0 as usize].type_.0 as usize]; + let maybe_hook = env.get_hook_kind_for_type(callee_ty).ok().flatten(); + if let Some(entry) = temporaries.get(&callee.identifier) { + // Direct calls + hoistable.insert(entry.0); + } else if maybe_hook.is_some() { + // Assume arguments to all hooks are safe to invoke + for arg in args { + if let PlaceOrSpread::Place(p) = arg { + if let Some(entry) = temporaries.get(&p.identifier) { + hoistable.insert(entry.0); + } + } + } + } + } + InstructionValue::JsxExpression { + props, children, .. + } => { + // Assume JSX attributes and children are safe to invoke + for prop in props { + if let react_compiler_hir::JsxAttribute::Attribute { place, .. } = prop { + if let Some(entry) = temporaries.get(&place.identifier) { + hoistable.insert(entry.0); + } + } + } + if let Some(children) = children { + for child in children { + if let Some(entry) = temporaries.get(&child.identifier) { + hoistable.insert(entry.0); + } + } + } + } + InstructionValue::JsxFragment { children, .. } => { + for child in children { + if let Some(entry) = temporaries.get(&child.identifier) { + hoistable.insert(entry.0); + } + } + } + InstructionValue::FunctionExpression { lowered_func, .. 
} => { + // Recursively traverse into other function expressions + // TS passes the shared temporaries map to the recursive call + let inner_func = &env.functions[lowered_func.func.0 as usize]; + let lambdas_called = + get_assumed_invoked_functions_impl(inner_func, env, temporaries); + if let Some(entry) = temporaries.get_mut(&instr.lvalue.identifier) { + for called in lambdas_called { + entry.1.insert(called); + } + } + } + _ => {} + } + } + + // Assume directly returned functions are safe to call + if let Terminal::Return { value, .. } = &block.terminal { + if let Some(entry) = temporaries.get(&value.identifier) { + hoistable.insert(entry.0); + } + } + } + + // Step 3: Propagate assumed-invoked status through mayInvoke chains + let mut changed = true; + while changed { + changed = false; + // Two-phase: collect then insert + let mut to_add = Vec::new(); + for (_, (func_id, may_invoke)) in temporaries.iter() { + if hoistable.contains(func_id) { + for &called in may_invoke { + if !hoistable.contains(&called) { + to_add.push(called); + } + } + } + } + for id in to_add { + changed = true; + hoistable.insert(id); + } + if !changed { + break; + } + } + + hoistable +} + +fn collect_non_nulls_in_blocks( + func: &HirFunction, + env: &Environment, + ctx: &CollectHoistableContext, + registry: &mut PropertyPathRegistry, +) -> HashMap { + // Known non-null identifiers (e.g. 
component props) + let mut known_non_null: BTreeSet = BTreeSet::new(); + if func.fn_type == ReactFunctionType::Component && !func.params.is_empty() { + if let ParamPattern::Place(place) = &func.params[0] { + let node_idx = registry.get_or_create_identifier(place.identifier, true, place.loc); + known_non_null.insert(node_idx); + } + } + + let mut nodes: HashMap = HashMap::new(); + + for (block_id, block) in &func.body.blocks { + let mut assumed = known_non_null.clone(); + + // Check hoistable from optionals + if let Some(optional_chain) = ctx.hoistable_from_optionals.get(block_id) { + let node_idx = registry.get_or_create_property(optional_chain); + assumed.insert(node_idx); + } + + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + if let Some(path) = get_maybe_non_null_in_instruction(&instr.value, ctx.temporaries) { + let path_ident = path.identifier; + if is_immutable_at_instr(path_ident, instr.id, env, ctx) { + let node_idx = registry.get_or_create_property(&path); + assumed.insert(node_idx); + } + } + + // Handle StartMemoize deps for enablePreserveExistingMemoizationGuarantees + if env.enable_preserve_existing_memoization_guarantees { + if let InstructionValue::StartMemoize { + deps: Some(deps), .. + } = &instr.value + { + for dep in deps { + if let react_compiler_hir::ManualMemoDependencyRoot::NamedLocal { + value: val, + .. + } = &dep.root + { + if !is_immutable_at_instr(val.identifier, instr.id, env, ctx) { + continue; + } + for i in 0..dep.path.len() { + if dep.path[i].optional { + break; + } + let sub_dep = ReactiveScopeDependency { + identifier: val.identifier, + reactive: val.reactive, + path: dep.path[..i].to_vec(), + loc: dep.loc, + }; + let node_idx = registry.get_or_create_property(&sub_dep); + assumed.insert(node_idx); + } + } + } + } + } + + // Handle assumed-invoked inner functions + if let InstructionValue::FunctionExpression { lowered_func, .. 
} = &instr.value { + if ctx.assumed_invoked_fns.contains(&lowered_func.func) { + let inner_func = &env.functions[lowered_func.func.0 as usize]; + // Build nested fn immutable context + let nested_fn_immutable_context: HashSet = + if ctx.nested_fn_immutable_context.is_some() { + // Already in a nested fn context, use existing + ctx.nested_fn_immutable_context.unwrap().clone() + } else { + inner_func + .context + .iter() + .filter(|place| { + is_immutable_at_instr(place.identifier, instr.id, env, ctx) + }) + .map(|place| place.identifier) + .collect() + }; + let inner_assumed = get_assumed_invoked_functions(inner_func, env); + let inner_ctx = CollectHoistableContext { + temporaries: ctx.temporaries, + known_immutable_identifiers: &HashSet::new(), + hoistable_from_optionals: ctx.hoistable_from_optionals, + nested_fn_immutable_context: Some(&nested_fn_immutable_context), + assumed_invoked_fns: &inner_assumed, + }; + let inner_nodes = + collect_non_nulls_in_blocks(inner_func, env, &inner_ctx, registry); + // Propagate non-null from inner function + let inner_working = propagate_non_null(inner_func, &inner_nodes, registry); + // Get hoistables from inner function's entry block (after propagation) + let inner_entry = inner_func.body.entry; + if let Some(inner_set) = inner_working.get(&inner_entry) { + for &node_idx in inner_set { + assumed.insert(node_idx); + } + } + } + } + } + + nodes.insert( + *block_id, + BlockInfo { + assumed_non_null_objects: assumed, + }, + ); + } + + nodes +} + +/// Recursive DFS propagation of non-null information through the CFG. +/// Uses 'active'/'done' state tracking to correctly handle cycles (backedges in +/// loops). +/// +/// Port of TS `propagateNonNull` which uses `recursivelyPropagateNonNull`. +/// Key insight: when computing the intersection of neighbor sets, only include +/// neighbors that are 'done' (not 'active'). 
Active neighbors are part of a +/// cycle and should be filtered out, allowing non-null info to propagate +/// through non-cyclic paths. +fn propagate_non_null( + func: &HirFunction, + nodes: &HashMap, + registry: &mut PropertyPathRegistry, +) -> HashMap> { + // Build successor map + let mut block_successors: HashMap> = HashMap::new(); + for (block_id, block) in &func.body.blocks { + for pred in &block.preds { + block_successors.entry(*pred).or_default().insert(*block_id); + } + } + + // Clone nodes into mutable working set + let mut working: HashMap> = nodes + .iter() + .map(|(k, v)| (*k, v.assumed_non_null_objects.clone())) + .collect(); + + let block_ids: Vec = func.body.blocks.keys().copied().collect(); + let mut reversed_block_ids = block_ids.clone(); + reversed_block_ids.reverse(); + + for _ in 0..100 { + let mut changed = false; + + // Forward pass (using predecessors) + let mut traversal_state: HashMap = HashMap::new(); + for &block_id in &block_ids { + let block_changed = recursively_propagate_non_null( + block_id, + PropagationDirection::Forward, + &mut traversal_state, + &mut working, + func, + &block_successors, + registry, + ); + changed |= block_changed; + } + + // Backward pass (using successors) + traversal_state.clear(); + for &block_id in &reversed_block_ids { + let block_changed = recursively_propagate_non_null( + block_id, + PropagationDirection::Backward, + &mut traversal_state, + &mut working, + func, + &block_successors, + registry, + ); + changed |= block_changed; + } + + if !changed { + break; + } + } + + working +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TraversalState { + Active, + Done, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PropagationDirection { + Forward, + Backward, +} + +fn recursively_propagate_non_null( + node_id: BlockId, + direction: PropagationDirection, + traversal_state: &mut HashMap, + working: &mut HashMap>, + func: &HirFunction, + block_successors: &HashMap>, + registry: &mut 
PropertyPathRegistry, +) -> bool { + // Avoid re-visiting computed or currently active nodes + if traversal_state.contains_key(&node_id) { + return false; + } + traversal_state.insert(node_id, TraversalState::Active); + + let neighbors: Vec = match direction { + PropagationDirection::Backward => block_successors + .get(&node_id) + .map(|s| s.iter().copied().collect()) + .unwrap_or_default(), + PropagationDirection::Forward => func + .body + .blocks + .get(&node_id) + .map(|b| b.preds.iter().copied().collect()) + .unwrap_or_default(), + }; + + let mut changed = false; + for &neighbor in &neighbors { + if !traversal_state.contains_key(&neighbor) { + let neighbor_changed = recursively_propagate_non_null( + neighbor, + direction, + traversal_state, + working, + func, + block_successors, + registry, + ); + changed |= neighbor_changed; + } + } + + // Compute intersection of 'done' neighbors only (filter out 'active' = cycle + // nodes) + let done_neighbor_sets: Vec> = neighbors + .iter() + .filter(|n| traversal_state.get(n) == Some(&TraversalState::Done)) + .filter_map(|n| working.get(n).cloned()) + .collect(); + + let neighbor_intersection = if done_neighbor_sets.is_empty() { + BTreeSet::new() + } else { + let mut iter = done_neighbor_sets.into_iter(); + let first = iter.next().unwrap(); + iter.fold(first, |acc, s| acc.intersection(&s).copied().collect()) + }; + + let prev_objects = working.get(&node_id).cloned().unwrap_or_default(); + let mut merged: BTreeSet = prev_objects + .union(&neighbor_intersection) + .copied() + .collect(); + reduce_maybe_optional_chains(&mut merged, registry); + + working.insert(node_id, merged.clone()); + traversal_state.insert(node_id, TraversalState::Done); + + // Compare with previous value — can't just check size due to + // reduce_maybe_optional_chains + changed |= prev_objects != merged; + changed +} + +fn collect_hoistable_and_propagate( + func: &HirFunction, + env: &Environment, + temporaries: &HashMap, + hoistable_from_optionals: 
&HashMap, +) -> (HashMap>, PropertyPathRegistry) { + let mut registry = PropertyPathRegistry::new(); + let assumed_invoked_fns = get_assumed_invoked_functions(func, env); + let known_immutable_identifiers: HashSet = if func.fn_type + == ReactFunctionType::Component + || func.fn_type == ReactFunctionType::Hook + { + func.params + .iter() + .filter_map(|p| match p { + ParamPattern::Place(place) => Some(place.identifier), + _ => None, + }) + .collect() + } else { + HashSet::new() + }; + + let ctx = CollectHoistableContext { + temporaries, + known_immutable_identifiers: &known_immutable_identifiers, + hoistable_from_optionals, + nested_fn_immutable_context: None, + assumed_invoked_fns: &assumed_invoked_fns, + }; + + let nodes = collect_non_nulls_in_blocks(func, env, &ctx, &mut registry); + let working = propagate_non_null(func, &nodes, &mut registry); + + (working, registry) +} + +// Restructured version used by the main entry point +#[allow(dead_code)] +fn key_by_scope_id( + func: &HirFunction, + block_keyed: &HashMap, +) -> HashMap { + let mut keyed: HashMap = HashMap::new(); + for (_block_id, block) in &func.body.blocks { + if let Terminal::Scope { + scope, + block: inner_block, + .. 
+ } = &block.terminal + { + if let Some(info) = block_keyed.get(inner_block) { + keyed.insert(*scope, info.clone()); + } + } + } + keyed +} + +// ============================================================================= +// DeriveMinimalDependenciesHIR +// ============================================================================= + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PropertyAccessType { + OptionalAccess, + UnconditionalAccess, + OptionalDependency, + UnconditionalDependency, +} + +fn is_optional_access(access: PropertyAccessType) -> bool { + matches!( + access, + PropertyAccessType::OptionalAccess | PropertyAccessType::OptionalDependency + ) +} + +fn is_dependency_access(access: PropertyAccessType) -> bool { + matches!( + access, + PropertyAccessType::OptionalDependency | PropertyAccessType::UnconditionalDependency + ) +} + +fn merge_access(a: PropertyAccessType, b: PropertyAccessType) -> PropertyAccessType { + let is_unconditional = !(is_optional_access(a) && is_optional_access(b)); + let is_dep = is_dependency_access(a) || is_dependency_access(b); + match (is_unconditional, is_dep) { + (true, true) => PropertyAccessType::UnconditionalDependency, + (true, false) => PropertyAccessType::UnconditionalAccess, + (false, true) => PropertyAccessType::OptionalDependency, + (false, false) => PropertyAccessType::OptionalAccess, + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum HoistableAccessType { + Optional, + NonNull, +} + +struct HoistableNode { + properties: HashMap>, + access_type: HoistableAccessType, +} + +struct HoistableNodeEntry { + node: HoistableNode, +} + +struct DependencyNode { + properties: IndexMap>, + access_type: PropertyAccessType, + loc: Option, +} + +struct DependencyNodeEntry { + node: DependencyNode, +} + +struct ReactiveScopeDependencyTreeHIR { + hoistable_roots: HashMap, // node + reactive + dep_roots: IndexMap, /* node + reactive (preserves + * insertion order like JS + * Map) */ +} + +impl 
ReactiveScopeDependencyTreeHIR { + fn new<'a>( + hoistable_objects: impl Iterator, + _env: &Environment, + ) -> Self { + let mut hoistable_roots: HashMap = HashMap::new(); + + for dep in hoistable_objects { + let root = hoistable_roots.entry(dep.identifier).or_insert_with(|| { + let access_type = if !dep.path.is_empty() && dep.path[0].optional { + HoistableAccessType::Optional + } else { + HoistableAccessType::NonNull + }; + ( + HoistableNode { + properties: HashMap::new(), + access_type, + }, + dep.reactive, + ) + }); + + let mut curr = &mut root.0; + for i in 0..dep.path.len() { + let access_type = if i + 1 < dep.path.len() && dep.path[i + 1].optional { + HoistableAccessType::Optional + } else { + HoistableAccessType::NonNull + }; + let entry = curr + .properties + .entry(dep.path[i].property.clone()) + .or_insert_with(|| { + Box::new(HoistableNodeEntry { + node: HoistableNode { + properties: HashMap::new(), + access_type, + }, + }) + }); + curr = &mut entry.node; + } + } + + Self { + hoistable_roots, + dep_roots: IndexMap::new(), + } + } + + fn add_dependency(&mut self, dep: ReactiveScopeDependency, _env: &Environment) { + let root = self.dep_roots.entry(dep.identifier).or_insert_with(|| { + ( + DependencyNode { + properties: IndexMap::new(), + access_type: PropertyAccessType::UnconditionalAccess, + loc: dep.loc, + }, + dep.reactive, + ) + }); + + let mut dep_cursor = &mut root.0; + let hoistable_cursor_root = self.hoistable_roots.get(&dep.identifier); + let mut hoistable_ptr: Option<&HoistableNode> = hoistable_cursor_root.map(|(n, _)| n); + + for entry in &dep.path { + let next_hoistable: Option<&HoistableNode>; + let access_type: PropertyAccessType; + + if entry.optional { + next_hoistable = + hoistable_ptr.and_then(|h| h.properties.get(&entry.property).map(|e| &e.node)); + + if hoistable_ptr.is_some() + && hoistable_ptr.unwrap().access_type == HoistableAccessType::NonNull + { + access_type = PropertyAccessType::UnconditionalAccess; + } else { + access_type = 
PropertyAccessType::OptionalAccess; + } + } else if hoistable_ptr.is_some() + && hoistable_ptr.unwrap().access_type == HoistableAccessType::NonNull + { + next_hoistable = + hoistable_ptr.and_then(|h| h.properties.get(&entry.property).map(|e| &e.node)); + access_type = PropertyAccessType::UnconditionalAccess; + } else { + // Break: truncate dependency + break; + } + + // make_or_merge_property + let child = dep_cursor + .properties + .entry(entry.property.clone()) + .or_insert_with(|| { + Box::new(DependencyNodeEntry { + node: DependencyNode { + properties: IndexMap::new(), + access_type, + loc: entry.loc, + }, + }) + }); + child.node.access_type = merge_access(child.node.access_type, access_type); + + dep_cursor = &mut child.node; + hoistable_ptr = next_hoistable; + } + + // Mark final node as dependency + dep_cursor.access_type = merge_access( + dep_cursor.access_type, + PropertyAccessType::OptionalDependency, + ); + } + + fn derive_minimal_dependencies(&self, _env: &Environment) -> Vec { + let mut results = Vec::new(); + for (&root_id, (root_node, reactive)) in &self.dep_roots { + collect_minimal_deps_in_subtree(root_node, *reactive, root_id, &[], &mut results); + } + results + } +} + +fn collect_minimal_deps_in_subtree( + node: &DependencyNode, + reactive: bool, + root_id: IdentifierId, + path: &[DependencyPathEntry], + results: &mut Vec, +) { + if is_dependency_access(node.access_type) { + results.push(ReactiveScopeDependency { + identifier: root_id, + reactive, + path: path.to_vec(), + loc: node.loc, + }); + } else { + for (child_name, child_entry) in &node.properties { + let mut new_path = path.to_vec(); + new_path.push(DependencyPathEntry { + property: child_name.clone(), + optional: is_optional_access(child_entry.node.access_type), + loc: child_entry.node.loc, + }); + collect_minimal_deps_in_subtree( + &child_entry.node, + reactive, + root_id, + &new_path, + results, + ); + } + } +} + +// 
============================================================================= +// collectDependencies +// ============================================================================= + +/// A declaration record: instruction id + scope stack at declaration time. +#[derive(Clone)] +struct Decl { + id: EvaluationOrder, + scope_stack: Vec, // copy of the scope stack at time of declaration +} + +/// Context for dependency collection. +struct DependencyCollectionContext<'a> { + declarations: HashMap, + reassignments: HashMap, + scope_stack: Vec, + dep_stack: Vec>, + deps: IndexMap>, + temporaries: &'a HashMap, + #[allow(dead_code)] + temporaries_used_outside_scope: &'a HashSet, + processed_instrs_in_optional: &'a HashSet, + inner_fn_context: Option, +} + +impl<'a> DependencyCollectionContext<'a> { + fn new( + temporaries_used_outside_scope: &'a HashSet, + temporaries: &'a HashMap, + processed_instrs_in_optional: &'a HashSet, + ) -> Self { + Self { + declarations: HashMap::new(), + reassignments: HashMap::new(), + scope_stack: Vec::new(), + dep_stack: Vec::new(), + deps: IndexMap::new(), + temporaries, + temporaries_used_outside_scope, + processed_instrs_in_optional, + inner_fn_context: None, + } + } + + fn enter_scope(&mut self, scope_id: ScopeId) { + self.dep_stack.push(Vec::new()); + self.scope_stack.push(scope_id); + } + + fn exit_scope(&mut self, scope_id: ScopeId, pruned: bool, env: &mut Environment) { + let scoped_deps = self + .dep_stack + .pop() + .expect("[PropagateScopeDeps]: Unexpected scope mismatch"); + self.scope_stack.pop(); + + // Propagate dependencies upward + for dep in &scoped_deps { + if self.check_valid_dependency(dep, env) { + if let Some(top) = self.dep_stack.last_mut() { + top.push(dep.clone()); + } + } + } + + if !pruned { + self.deps.insert(scope_id, scoped_deps); + } + } + + fn current_scope(&self) -> Option { + self.scope_stack.last().copied() + } + + fn declare(&mut self, identifier_id: IdentifierId, decl: Decl, env: &Environment) { + if 
self.inner_fn_context.is_some() { + return; + } + let decl_id = env.identifiers[identifier_id.0 as usize].declaration_id; + if !self.declarations.contains_key(&decl_id) { + self.declarations.insert(decl_id, decl.clone()); + } + self.reassignments.insert(identifier_id, decl); + } + + fn has_declared(&self, identifier_id: IdentifierId, env: &Environment) -> bool { + let decl_id = env.identifiers[identifier_id.0 as usize].declaration_id; + self.declarations.contains_key(&decl_id) + } + + fn check_valid_dependency(&self, dep: &ReactiveScopeDependency, env: &Environment) -> bool { + // Ref value is not a valid dep + let ty = &env.types[env.identifiers[dep.identifier.0 as usize].type_.0 as usize]; + if react_compiler_hir::is_ref_value_type(ty) { + return false; + } + // Object methods are not deps + if matches!(ty, Type::ObjectMethod) { + return false; + } + + let ident = &env.identifiers[dep.identifier.0 as usize]; + let current_declaration = self + .reassignments + .get(&dep.identifier) + .or_else(|| self.declarations.get(&ident.declaration_id)); + + if let Some(current_scope) = self.current_scope() { + if let Some(decl) = current_declaration { + let scope_range_start = env.scopes[current_scope.0 as usize].range.start; + return decl.id < scope_range_start; + } + } + false + } + + fn visit_operand(&mut self, place: &Place, env: &mut Environment) { + let dep = self + .temporaries + .get(&place.identifier) + .cloned() + .unwrap_or_else(|| ReactiveScopeDependency { + identifier: place.identifier, + reactive: place.reactive, + path: vec![], + loc: place.loc, + }); + self.visit_dependency(dep, env); + } + + fn visit_property( + &mut self, + object: &Place, + property: &PropertyLiteral, + optional: bool, + loc: Option, + env: &mut Environment, + ) { + let dep = get_property(object, property, optional, loc, self.temporaries, env); + self.visit_dependency(dep, env); + } + + fn visit_dependency(&mut self, dep: ReactiveScopeDependency, env: &mut Environment) { + let ident = 
&env.identifiers[dep.identifier.0 as usize]; + let decl_id = ident.declaration_id; + + // Record scope declarations for values used outside their declaring scope + if let Some(original_decl) = self.declarations.get(&decl_id) { + if !original_decl.scope_stack.is_empty() { + let orig_scope_stack = original_decl.scope_stack.clone(); + for &scope_id in &orig_scope_stack { + if !self.scope_stack.contains(&scope_id) { + // Check if already declared in this scope + let scope = &env.scopes[scope_id.0 as usize]; + let already_declared = scope.declarations.iter().any(|(_, d)| { + env.identifiers[d.identifier.0 as usize].declaration_id == decl_id + }); + if !already_declared { + let orig_scope_id = *orig_scope_stack.last().unwrap(); + let new_decl = react_compiler_hir::ReactiveScopeDeclaration { + identifier: dep.identifier, + scope: orig_scope_id, + }; + env.scopes[scope_id.0 as usize] + .declarations + .push((dep.identifier, new_decl)); + } + } + } + } + } + + // Handle ref.current access + let dep = if react_compiler_hir::is_use_ref_type( + &env.types[env.identifiers[dep.identifier.0 as usize].type_.0 as usize], + ) && dep + .path + .first() + .map(|p| p.property == PropertyLiteral::String("current".to_string())) + .unwrap_or(false) + { + ReactiveScopeDependency { + identifier: dep.identifier, + reactive: dep.reactive, + path: vec![], + loc: dep.loc, + } + } else { + dep + }; + + if self.check_valid_dependency(&dep, env) { + if let Some(top) = self.dep_stack.last_mut() { + top.push(dep); + } + } + } + + fn visit_reassignment(&mut self, place: &Place, env: &mut Environment) { + if let Some(current_scope) = self.current_scope() { + let scope = &env.scopes[current_scope.0 as usize]; + let already = scope.reassignments.iter().any(|id| { + env.identifiers[id.0 as usize].declaration_id + == env.identifiers[place.identifier.0 as usize].declaration_id + }); + if !already + && self.check_valid_dependency( + &ReactiveScopeDependency { + identifier: place.identifier, + reactive: 
place.reactive, + path: vec![], + loc: place.loc, + }, + env, + ) + { + env.scopes[current_scope.0 as usize] + .reassignments + .push(place.identifier); + } + } + } + + fn is_deferred_dependency_instr(&self, instr: &Instruction) -> bool { + self.processed_instrs_in_optional + .contains(&ProcessedInstr::Instruction(instr.lvalue.identifier)) + || self.temporaries.contains_key(&instr.lvalue.identifier) + } + + fn is_deferred_dependency_terminal(&self, block_id: BlockId) -> bool { + self.processed_instrs_in_optional + .contains(&ProcessedInstr::Terminal(block_id)) + } +} + +/// Recursively visit an inner function's blocks, processing all instructions +/// including nested FunctionExpressions. This mirrors the TS pattern of +/// `context.enterInnerFn(instr, () => handleFunction(innerFn))`. +fn visit_inner_function_blocks( + func_id: FunctionId, + ctx: &mut DependencyCollectionContext, + env: &mut Environment, +) { + // Clone inner function's instructions and block structure to avoid + // borrow conflicts when mutating env through handle_instruction. 
+ let inner_instrs: Vec = env.functions[func_id.0 as usize].instructions.clone(); + let inner_blocks: Vec<( + BlockId, + Vec, + Vec<(BlockId, IdentifierId)>, + Terminal, + )> = env.functions[func_id.0 as usize] + .body + .blocks + .iter() + .map(|(bid, blk)| { + let phi_ops: Vec<(BlockId, IdentifierId)> = blk + .phis + .iter() + .flat_map(|phi| { + phi.operands + .iter() + .map(|(pred, place)| (*pred, place.identifier)) + }) + .collect(); + ( + *bid, + blk.instructions.clone(), + phi_ops, + blk.terminal.clone(), + ) + }) + .collect(); + + for (inner_bid, inner_instr_ids, inner_phis, inner_terminal) in &inner_blocks { + for &(_pred_id, op_id) in inner_phis { + if let Some(maybe_optional) = ctx.temporaries.get(&op_id) { + ctx.visit_dependency(maybe_optional.clone(), env); + } + } + + for &iid in inner_instr_ids { + let inner_instr = &inner_instrs[iid.0 as usize]; + match &inner_instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. 
} => { + // Recursively visit nested function expressions + let scope_stack_copy = ctx.scope_stack.clone(); + ctx.declare( + inner_instr.lvalue.identifier, + Decl { + id: inner_instr.id, + scope_stack: scope_stack_copy, + }, + env, + ); + visit_inner_function_blocks(lowered_func.func, ctx, env); + } + _ => { + handle_instruction(inner_instr, ctx, env); + } + } + } + + if !ctx.is_deferred_dependency_terminal(*inner_bid) { + let terminal_ops = visitors::each_terminal_operand(inner_terminal); + for op in &terminal_ops { + ctx.visit_operand(op, env); + } + } + } +} + +fn handle_instruction( + instr: &Instruction, + ctx: &mut DependencyCollectionContext, + env: &mut Environment, +) { + let id = instr.id; + let scope_stack_copy = ctx.scope_stack.clone(); + ctx.declare( + instr.lvalue.identifier, + Decl { + id, + scope_stack: scope_stack_copy, + }, + env, + ); + + if ctx.is_deferred_dependency_instr(instr) { + return; + } + + match &instr.value { + InstructionValue::PropertyLoad { + object, + property, + loc, + .. + } => { + ctx.visit_property(object, property, false, *loc, env); + } + InstructionValue::StoreLocal { + value: val, lvalue, .. + } => { + ctx.visit_operand(val, env); + if lvalue.kind == InstructionKind::Reassign { + ctx.visit_reassignment(&lvalue.place, env); + } + let scope_stack_copy = ctx.scope_stack.clone(); + ctx.declare( + lvalue.place.identifier, + Decl { + id, + scope_stack: scope_stack_copy, + }, + env, + ); + } + InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::DeclareContext { lvalue, .. } => { + if convert_hoisted_lvalue_kind(lvalue.kind).is_none() { + let scope_stack_copy = ctx.scope_stack.clone(); + ctx.declare( + lvalue.place.identifier, + Decl { + id, + scope_stack: scope_stack_copy, + }, + env, + ); + } + } + InstructionValue::Destructure { + value: val, lvalue, .. 
+ } => { + ctx.visit_operand(val, env); + let pattern_places = visitors::each_pattern_operand(&lvalue.pattern); + for place in &pattern_places { + if lvalue.kind == InstructionKind::Reassign { + ctx.visit_reassignment(place, env); + } + let scope_stack_copy = ctx.scope_stack.clone(); + ctx.declare( + place.identifier, + Decl { + id, + scope_stack: scope_stack_copy, + }, + env, + ); + } + } + InstructionValue::StoreContext { + lvalue, value: val, .. + } => { + if !ctx.has_declared(lvalue.place.identifier, env) + || lvalue.kind != InstructionKind::Reassign + { + let scope_stack_copy = ctx.scope_stack.clone(); + ctx.declare( + lvalue.place.identifier, + Decl { + id, + scope_stack: scope_stack_copy, + }, + env, + ); + } + // Visit all operands (lvalue.place AND value) + ctx.visit_operand(&lvalue.place, env); + ctx.visit_operand(val, env); + } + _ => { + // Visit all value operands + let operands = visitors::each_instruction_value_operand(&instr.value, env); + for operand in &operands { + ctx.visit_operand(operand, env); + } + } + } +} + +fn collect_dependencies( + func: &HirFunction, + env: &mut Environment, + used_outside_declaring_scope: &HashSet, + temporaries: &HashMap, + processed_instrs_in_optional: &HashSet, +) -> IndexMap> { + let mut ctx = DependencyCollectionContext::new( + used_outside_declaring_scope, + temporaries, + processed_instrs_in_optional, + ); + + // Declare params + for param in &func.params { + match param { + ParamPattern::Place(place) => { + ctx.declare( + place.identifier, + Decl { + id: EvaluationOrder(0), + scope_stack: vec![], + }, + env, + ); + } + ParamPattern::Spread(spread) => { + ctx.declare( + spread.place.identifier, + Decl { + id: EvaluationOrder(0), + scope_stack: vec![], + }, + env, + ); + } + } + } + + let mut traversal = ScopeBlockTraversal::new(); + + handle_function_deps(func, env, &mut ctx, &mut traversal); + + ctx.deps +} + +fn handle_function_deps( + func: &HirFunction, + env: &mut Environment, + ctx: &mut 
DependencyCollectionContext, + traversal: &mut ScopeBlockTraversal, +) { + for (block_id, block) in &func.body.blocks { + // Record scopes + traversal.record_scopes(block); + + let scope_block_info = traversal.block_infos.get(block_id).cloned(); + match &scope_block_info { + Some(ScopeBlockInfo::Begin { scope, .. }) => { + ctx.enter_scope(*scope); + } + Some(ScopeBlockInfo::End { scope, pruned, .. }) => { + ctx.exit_scope(*scope, *pruned, env); + } + None => {} + } + + // Record phi operands + for phi in &block.phis { + for (_pred_id, operand) in &phi.operands { + if let Some(maybe_optional_chain) = ctx.temporaries.get(&operand.identifier) { + ctx.visit_dependency(maybe_optional_chain.clone(), env); + } + } + } + + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + let scope_stack_copy = ctx.scope_stack.clone(); + ctx.declare( + instr.lvalue.identifier, + Decl { + id: instr.id, + scope_stack: scope_stack_copy, + }, + env, + ); + + // Recursively visit inner function + let inner_func_id = lowered_func.func; + let prev_inner = ctx.inner_fn_context; + if ctx.inner_fn_context.is_none() { + ctx.inner_fn_context = Some(instr.id); + } + + visit_inner_function_blocks(inner_func_id, ctx, env); + + ctx.inner_fn_context = prev_inner; + } + _ => { + handle_instruction(instr, ctx, env); + } + } + } + + // Terminal operands + if !ctx.is_deferred_dependency_terminal(*block_id) { + let terminal_ops = visitors::each_terminal_operand(&block.terminal); + for op in &terminal_ops { + ctx.visit_operand(op, env); + } + } + } +} diff --git a/crates/react_compiler_lowering/Cargo.toml b/crates/react_compiler_lowering/Cargo.toml new file mode 100644 index 000000000000..c4dff0d79284 --- /dev/null +++ b/crates/react_compiler_lowering/Cargo.toml @@ -0,0 +1,14 @@ +[package] +description = "Vendored React 
Compiler lowering from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_lowering" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +react_compiler_ast = { path = "../react_compiler_ast" } +react_compiler_hir = { path = "../react_compiler_hir" } +react_compiler_diagnostics = { path = "../react_compiler_diagnostics" } +indexmap = { workspace = true } +serde_json = { workspace = true } diff --git a/crates/react_compiler_lowering/src/build_hir.rs b/crates/react_compiler_lowering/src/build_hir.rs new file mode 100644 index 000000000000..4d8dd382ed4b --- /dev/null +++ b/crates/react_compiler_lowering/src/build_hir.rs @@ -0,0 +1,6395 @@ +use std::collections::HashSet; + +use indexmap::{IndexMap, IndexSet}; +use react_compiler_ast::scope::{BindingId, ScopeInfo, ScopeKind}; +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, CompilerErrorDetail, ErrorCategory, +}; +use react_compiler_hir::{environment::Environment, *}; + +use crate::{ + find_context_identifiers::find_context_identifiers, + hir_builder::HirBuilder, + identifier_loc_index::{build_identifier_loc_index, IdentifierLocIndex}, + FunctionNode, +}; + +// ============================================================================= +// Source location conversion +// ============================================================================= + +/// Convert an AST SourceLocation to an HIR SourceLocation. +fn convert_loc(loc: &react_compiler_ast::common::SourceLocation) -> SourceLocation { + SourceLocation { + start: Position { + line: loc.start.line, + column: loc.start.column, + index: loc.start.index, + }, + end: Position { + line: loc.end.line, + column: loc.end.column, + index: loc.end.index, + }, + } +} + +/// Convert an optional AST SourceLocation to an optional HIR SourceLocation. 
+fn convert_opt_loc( + loc: &Option, +) -> Option { + loc.as_ref().map(convert_loc) +} + +fn pattern_like_loc( + pattern: &react_compiler_ast::patterns::PatternLike, +) -> Option { + use react_compiler_ast::patterns::PatternLike; + match pattern { + PatternLike::Identifier(id) => id.base.loc.clone(), + PatternLike::ObjectPattern(p) => p.base.loc.clone(), + PatternLike::ArrayPattern(p) => p.base.loc.clone(), + PatternLike::AssignmentPattern(p) => p.base.loc.clone(), + PatternLike::RestElement(p) => p.base.loc.clone(), + PatternLike::MemberExpression(p) => p.base.loc.clone(), + } +} + +/// Extract the HIR SourceLocation from an Expression AST node. +fn expression_loc(expr: &react_compiler_ast::expressions::Expression) -> Option { + use react_compiler_ast::expressions::Expression; + let loc = match expr { + Expression::Identifier(e) => e.base.loc.clone(), + Expression::StringLiteral(e) => e.base.loc.clone(), + Expression::NumericLiteral(e) => e.base.loc.clone(), + Expression::BooleanLiteral(e) => e.base.loc.clone(), + Expression::NullLiteral(e) => e.base.loc.clone(), + Expression::BigIntLiteral(e) => e.base.loc.clone(), + Expression::RegExpLiteral(e) => e.base.loc.clone(), + Expression::CallExpression(e) => e.base.loc.clone(), + Expression::MemberExpression(e) => e.base.loc.clone(), + Expression::OptionalCallExpression(e) => e.base.loc.clone(), + Expression::OptionalMemberExpression(e) => e.base.loc.clone(), + Expression::BinaryExpression(e) => e.base.loc.clone(), + Expression::LogicalExpression(e) => e.base.loc.clone(), + Expression::UnaryExpression(e) => e.base.loc.clone(), + Expression::UpdateExpression(e) => e.base.loc.clone(), + Expression::ConditionalExpression(e) => e.base.loc.clone(), + Expression::AssignmentExpression(e) => e.base.loc.clone(), + Expression::SequenceExpression(e) => e.base.loc.clone(), + Expression::ArrowFunctionExpression(e) => e.base.loc.clone(), + Expression::FunctionExpression(e) => e.base.loc.clone(), + Expression::ObjectExpression(e) => 
e.base.loc.clone(), + Expression::ArrayExpression(e) => e.base.loc.clone(), + Expression::NewExpression(e) => e.base.loc.clone(), + Expression::TemplateLiteral(e) => e.base.loc.clone(), + Expression::TaggedTemplateExpression(e) => e.base.loc.clone(), + Expression::AwaitExpression(e) => e.base.loc.clone(), + Expression::YieldExpression(e) => e.base.loc.clone(), + Expression::SpreadElement(e) => e.base.loc.clone(), + Expression::MetaProperty(e) => e.base.loc.clone(), + Expression::ClassExpression(e) => e.base.loc.clone(), + Expression::PrivateName(e) => e.base.loc.clone(), + Expression::Super(e) => e.base.loc.clone(), + Expression::Import(e) => e.base.loc.clone(), + Expression::ThisExpression(e) => e.base.loc.clone(), + Expression::ParenthesizedExpression(e) => e.base.loc.clone(), + Expression::JSXElement(e) => e.base.loc.clone(), + Expression::JSXFragment(e) => e.base.loc.clone(), + Expression::AssignmentPattern(e) => e.base.loc.clone(), + Expression::TSAsExpression(e) => e.base.loc.clone(), + Expression::TSSatisfiesExpression(e) => e.base.loc.clone(), + Expression::TSNonNullExpression(e) => e.base.loc.clone(), + Expression::TSTypeAssertion(e) => e.base.loc.clone(), + Expression::TSInstantiationExpression(e) => e.base.loc.clone(), + Expression::TypeCastExpression(e) => e.base.loc.clone(), + }; + convert_opt_loc(&loc) +} + +/// Get the Babel-style type name of an Expression node (e.g. "Identifier", +/// "NumericLiteral"). 
+fn expression_type_name(expr: &react_compiler_ast::expressions::Expression) -> &'static str { + use react_compiler_ast::expressions::Expression; + match expr { + Expression::Identifier(_) => "Identifier", + Expression::StringLiteral(_) => "StringLiteral", + Expression::NumericLiteral(_) => "NumericLiteral", + Expression::BooleanLiteral(_) => "BooleanLiteral", + Expression::NullLiteral(_) => "NullLiteral", + Expression::BigIntLiteral(_) => "BigIntLiteral", + Expression::RegExpLiteral(_) => "RegExpLiteral", + Expression::CallExpression(_) => "CallExpression", + Expression::MemberExpression(_) => "MemberExpression", + Expression::OptionalCallExpression(_) => "OptionalCallExpression", + Expression::OptionalMemberExpression(_) => "OptionalMemberExpression", + Expression::BinaryExpression(_) => "BinaryExpression", + Expression::LogicalExpression(_) => "LogicalExpression", + Expression::UnaryExpression(_) => "UnaryExpression", + Expression::UpdateExpression(_) => "UpdateExpression", + Expression::ConditionalExpression(_) => "ConditionalExpression", + Expression::AssignmentExpression(_) => "AssignmentExpression", + Expression::SequenceExpression(_) => "SequenceExpression", + Expression::ArrowFunctionExpression(_) => "ArrowFunctionExpression", + Expression::FunctionExpression(_) => "FunctionExpression", + Expression::ObjectExpression(_) => "ObjectExpression", + Expression::ArrayExpression(_) => "ArrayExpression", + Expression::NewExpression(_) => "NewExpression", + Expression::TemplateLiteral(_) => "TemplateLiteral", + Expression::TaggedTemplateExpression(_) => "TaggedTemplateExpression", + Expression::AwaitExpression(_) => "AwaitExpression", + Expression::YieldExpression(_) => "YieldExpression", + Expression::SpreadElement(_) => "SpreadElement", + Expression::MetaProperty(_) => "MetaProperty", + Expression::ClassExpression(_) => "ClassExpression", + Expression::PrivateName(_) => "PrivateName", + Expression::Super(_) => "Super", + Expression::Import(_) => "Import", + 
Expression::ThisExpression(_) => "ThisExpression", + Expression::ParenthesizedExpression(_) => "ParenthesizedExpression", + Expression::JSXElement(_) => "JSXElement", + Expression::JSXFragment(_) => "JSXFragment", + Expression::AssignmentPattern(_) => "AssignmentPattern", + Expression::TSAsExpression(_) => "TSAsExpression", + Expression::TSSatisfiesExpression(_) => "TSSatisfiesExpression", + Expression::TSNonNullExpression(_) => "TSNonNullExpression", + Expression::TSTypeAssertion(_) => "TSTypeAssertion", + Expression::TSInstantiationExpression(_) => "TSInstantiationExpression", + Expression::TypeCastExpression(_) => "TypeCastExpression", + } +} + +/// Extract the type annotation name from an identifier's typeAnnotation field. +/// The Babel AST stores type annotations as: +/// { "type": "TSTypeAnnotation", "typeAnnotation": { "type": "TSTypeReference", +/// ... } } or { "type": "TypeAnnotation", "typeAnnotation": { "type": +/// "GenericTypeAnnotation", ... } } We extract the inner typeAnnotation's +/// `type` field name. +fn extract_type_annotation_name( + type_annotation: &Option>, +) -> Option { + let val = type_annotation.as_ref()?; + // Navigate: typeAnnotation.typeAnnotation.type + let inner = val.get("typeAnnotation")?; + let type_name = inner.get("type")?.as_str()?; + Some(type_name.to_string()) +} + +// ============================================================================= +// Helper functions +// ============================================================================= + +fn build_temporary_place(builder: &mut HirBuilder, loc: Option) -> Place { + let id = builder.make_temporary(loc.clone()); + Place { + identifier: id, + reactive: false, + effect: Effect::Unknown, + loc, + } +} + +/// Promote a temporary identifier to a named identifier (for destructuring). +/// Corresponds to TS `promoteTemporary(identifier)`. 
+fn promote_temporary(builder: &mut HirBuilder, identifier_id: IdentifierId) { + let env = builder.environment_mut(); + let decl_id = env.identifiers[identifier_id.0 as usize].declaration_id; + env.identifiers[identifier_id.0 as usize].name = + Some(IdentifierName::Promoted(format!("#t{}", decl_id.0))); +} + +fn lower_value_to_temporary( + builder: &mut HirBuilder, + value: InstructionValue, +) -> Result { + // Optimization: if loading an unnamed temporary, skip creating a new + // instruction + if let InstructionValue::LoadLocal { ref place, .. } = value { + let ident = &builder.environment().identifiers[place.identifier.0 as usize]; + if ident.name.is_none() { + return Ok(place.clone()); + } + } + let loc = value.loc().cloned(); + let place = build_temporary_place(builder, loc.clone()); + builder.push(Instruction { + id: EvaluationOrder(0), + lvalue: place.clone(), + value, + loc, + effects: None, + }); + Ok(place) +} + +fn lower_expression_to_temporary( + builder: &mut HirBuilder, + expr: &react_compiler_ast::expressions::Expression, +) -> Result { + let value = lower_expression(builder, expr)?; + Ok(lower_value_to_temporary(builder, value)?) 
+} + +// ============================================================================= +// Operator conversion +// ============================================================================= + +fn convert_binary_operator(op: &react_compiler_ast::operators::BinaryOperator) -> BinaryOperator { + use react_compiler_ast::operators::BinaryOperator as AstOp; + match op { + AstOp::Add => BinaryOperator::Add, + AstOp::Sub => BinaryOperator::Subtract, + AstOp::Mul => BinaryOperator::Multiply, + AstOp::Div => BinaryOperator::Divide, + AstOp::Rem => BinaryOperator::Modulo, + AstOp::Exp => BinaryOperator::Exponent, + AstOp::Eq => BinaryOperator::Equal, + AstOp::StrictEq => BinaryOperator::StrictEqual, + AstOp::Neq => BinaryOperator::NotEqual, + AstOp::StrictNeq => BinaryOperator::StrictNotEqual, + AstOp::Lt => BinaryOperator::LessThan, + AstOp::Lte => BinaryOperator::LessEqual, + AstOp::Gt => BinaryOperator::GreaterThan, + AstOp::Gte => BinaryOperator::GreaterEqual, + AstOp::Shl => BinaryOperator::ShiftLeft, + AstOp::Shr => BinaryOperator::ShiftRight, + AstOp::UShr => BinaryOperator::UnsignedShiftRight, + AstOp::BitOr => BinaryOperator::BitwiseOr, + AstOp::BitXor => BinaryOperator::BitwiseXor, + AstOp::BitAnd => BinaryOperator::BitwiseAnd, + AstOp::In => BinaryOperator::In, + AstOp::Instanceof => BinaryOperator::InstanceOf, + AstOp::Pipeline => { + unreachable!("Pipeline operator is checked before calling convert_binary_operator") + } + } +} + +fn convert_unary_operator(op: &react_compiler_ast::operators::UnaryOperator) -> UnaryOperator { + use react_compiler_ast::operators::UnaryOperator as AstOp; + match op { + AstOp::Neg => UnaryOperator::Minus, + AstOp::Plus => UnaryOperator::Plus, + AstOp::Not => UnaryOperator::Not, + AstOp::BitNot => UnaryOperator::BitwiseNot, + AstOp::TypeOf => UnaryOperator::TypeOf, + AstOp::Void => UnaryOperator::Void, + AstOp::Delete | AstOp::Throw => unreachable!("delete/throw handled separately"), + } +} + +// 
============================================================================= +// lower_identifier +// ============================================================================= + +/// Resolve an identifier to a Place. +/// +/// For local/context identifiers, returns a Place referencing the binding's +/// identifier. For globals/imports, emits a LoadGlobal instruction and returns +/// the temporary Place. +fn lower_identifier( + builder: &mut HirBuilder, + name: &str, + start: u32, + loc: Option, +) -> Result { + let binding = builder.resolve_identifier(name, start, loc.clone())?; + match binding { + VariableBinding::Identifier { identifier, .. } => Ok(Place { + identifier, + effect: Effect::Unknown, + reactive: false, + loc, + }), + _ => { + if let VariableBinding::Global { ref name } = binding { + if name == "eval" { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::UnsupportedSyntax, + reason: "The 'eval' function is not supported".to_string(), + description: Some( + "Eval is an anti-pattern in JavaScript, and the code executed cannot \ + be evaluated by React Compiler" + .to_string(), + ), + loc: loc.clone(), + suggestions: None, + })?; + } + } + let non_local_binding = match binding { + VariableBinding::Global { name } => NonLocalBinding::Global { name }, + VariableBinding::ImportDefault { name, module } => { + NonLocalBinding::ImportDefault { name, module } + } + VariableBinding::ImportSpecifier { + name, + module, + imported, + } => NonLocalBinding::ImportSpecifier { + name, + module, + imported, + }, + VariableBinding::ImportNamespace { name, module } => { + NonLocalBinding::ImportNamespace { name, module } + } + VariableBinding::ModuleLocal { name } => NonLocalBinding::ModuleLocal { name }, + VariableBinding::Identifier { .. } => unreachable!(), + }; + let instr_value = InstructionValue::LoadGlobal { + binding: non_local_binding, + loc: loc.clone(), + }; + Ok(lower_value_to_temporary(builder, instr_value)?) 
+ } + } +} + +// ============================================================================= +// lower_arguments +// ============================================================================= + +fn lower_arguments( + builder: &mut HirBuilder, + args: &[react_compiler_ast::expressions::Expression], +) -> Result, CompilerError> { + use react_compiler_ast::expressions::Expression; + let mut result = Vec::new(); + for arg in args { + match arg { + Expression::SpreadElement(spread) => { + let place = lower_expression_to_temporary(builder, &spread.argument)?; + result.push(PlaceOrSpread::Spread(SpreadPattern { place })); + } + _ => { + let place = lower_expression_to_temporary(builder, arg)?; + result.push(PlaceOrSpread::Place(place)); + } + } + } + Ok(result) +} + +fn convert_update_operator(op: &react_compiler_ast::operators::UpdateOperator) -> UpdateOperator { + match op { + react_compiler_ast::operators::UpdateOperator::Increment => UpdateOperator::Increment, + react_compiler_ast::operators::UpdateOperator::Decrement => UpdateOperator::Decrement, + } +} + +// ============================================================================= +// lower_member_expression +// ============================================================================= + +enum MemberProperty { + Literal(PropertyLiteral), + Computed(Place), +} + +struct LoweredMemberExpression { + object: Place, + property: MemberProperty, + value: InstructionValue, +} + +fn lower_member_expression( + builder: &mut HirBuilder, + member: &react_compiler_ast::expressions::MemberExpression, +) -> Result { + Ok(lower_member_expression_impl(builder, member, None)?) 
+} + +fn lower_member_expression_with_object( + builder: &mut HirBuilder, + member: &react_compiler_ast::expressions::OptionalMemberExpression, + lowered_object: Place, +) -> Result { + // OptionalMemberExpression has the same shape as MemberExpression for property + // access + use react_compiler_ast::expressions::Expression; + let loc = convert_opt_loc(&member.base.loc); + let object = lowered_object; + + if !member.computed { + let prop_literal = match member.property.as_ref() { + Expression::Identifier(id) => PropertyLiteral::String(id.name.clone()), + Expression::NumericLiteral(lit) => PropertyLiteral::Number(FloatValue::new(lit.value)), + _ => { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: format!( + "(BuildHIR::lowerMemberExpression) Handle {:?} property", + member.property + ), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(LoweredMemberExpression { + object, + property: MemberProperty::Literal(PropertyLiteral::String("".to_string())), + value: InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }, + }); + } + }; + let value = InstructionValue::PropertyLoad { + object: object.clone(), + property: prop_literal.clone(), + loc, + }; + Ok(LoweredMemberExpression { + object, + property: MemberProperty::Literal(prop_literal), + value, + }) + } else { + if let Expression::NumericLiteral(lit) = member.property.as_ref() { + let prop_literal = PropertyLiteral::Number(FloatValue::new(lit.value)); + let value = InstructionValue::PropertyLoad { + object: object.clone(), + property: prop_literal.clone(), + loc, + }; + return Ok(LoweredMemberExpression { + object, + property: MemberProperty::Literal(prop_literal), + value, + }); + } + let property = lower_expression_to_temporary(builder, &member.property)?; + let value = InstructionValue::ComputedLoad { + object: object.clone(), + property: property.clone(), + loc, + }; + Ok(LoweredMemberExpression { + object, + 
property: MemberProperty::Computed(property), + value, + }) + } +} + +fn lower_member_expression_impl( + builder: &mut HirBuilder, + member: &react_compiler_ast::expressions::MemberExpression, + lowered_object: Option, +) -> Result { + use react_compiler_ast::expressions::Expression; + let loc = convert_opt_loc(&member.base.loc); + let object = match lowered_object { + Some(obj) => obj, + None => lower_expression_to_temporary(builder, &member.object)?, + }; + + if !member.computed { + // Non-computed: property must be an identifier or numeric literal + let prop_literal = match member.property.as_ref() { + Expression::Identifier(id) => PropertyLiteral::String(id.name.clone()), + Expression::NumericLiteral(lit) => PropertyLiteral::Number(FloatValue::new(lit.value)), + _ => { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: format!( + "(BuildHIR::lowerMemberExpression) Handle {:?} property", + member.property + ), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(LoweredMemberExpression { + object, + property: MemberProperty::Literal(PropertyLiteral::String("".to_string())), + value: InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }, + }); + } + }; + let value = InstructionValue::PropertyLoad { + object: object.clone(), + property: prop_literal.clone(), + loc, + }; + Ok(LoweredMemberExpression { + object, + property: MemberProperty::Literal(prop_literal), + value, + }) + } else { + // Computed: check for numeric literal first (treated as PropertyLoad in TS) + if let Expression::NumericLiteral(lit) = member.property.as_ref() { + let prop_literal = PropertyLiteral::Number(FloatValue::new(lit.value)); + let value = InstructionValue::PropertyLoad { + object: object.clone(), + property: prop_literal.clone(), + loc, + }; + return Ok(LoweredMemberExpression { + object, + property: MemberProperty::Literal(prop_literal), + value, + }); + } + // Otherwise lower property 
to temporary for ComputedLoad + let property = lower_expression_to_temporary(builder, &member.property)?; + let value = InstructionValue::ComputedLoad { + object: object.clone(), + property: property.clone(), + loc, + }; + Ok(LoweredMemberExpression { + object, + property: MemberProperty::Computed(property), + value, + }) + } +} + +// ============================================================================= +// lower_expression +// ============================================================================= + +fn lower_expression( + builder: &mut HirBuilder, + expr: &react_compiler_ast::expressions::Expression, +) -> Result { + use react_compiler_ast::expressions::Expression; + + match expr { + Expression::Identifier(ident) => { + let loc = convert_opt_loc(&ident.base.loc); + let start = ident.base.start.unwrap_or(0); + let place = lower_identifier(builder, &ident.name, start, loc.clone())?; + // Determine LoadLocal vs LoadContext based on context identifier check + if builder.is_context_identifier(&ident.name, start) { + Ok(InstructionValue::LoadContext { place, loc }) + } else { + Ok(InstructionValue::LoadLocal { place, loc }) + } + } + Expression::NullLiteral(lit) => { + let loc = convert_opt_loc(&lit.base.loc); + Ok(InstructionValue::Primitive { + value: PrimitiveValue::Null, + loc, + }) + } + Expression::BooleanLiteral(lit) => { + let loc = convert_opt_loc(&lit.base.loc); + Ok(InstructionValue::Primitive { + value: PrimitiveValue::Boolean(lit.value), + loc, + }) + } + Expression::NumericLiteral(lit) => { + let loc = convert_opt_loc(&lit.base.loc); + Ok(InstructionValue::Primitive { + value: PrimitiveValue::Number(FloatValue::new(lit.value)), + loc, + }) + } + Expression::StringLiteral(lit) => { + let loc = convert_opt_loc(&lit.base.loc); + Ok(InstructionValue::Primitive { + value: PrimitiveValue::String(lit.value.clone()), + loc, + }) + } + Expression::BinaryExpression(bin) => { + let loc = convert_opt_loc(&bin.base.loc); + // Check for pipeline operator 
before lowering operands + if matches!( + bin.operator, + react_compiler_ast::operators::BinaryOperator::Pipeline + ) { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerExpression) Pipe operator not supported".to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }); + } + let left = lower_expression_to_temporary(builder, &bin.left)?; + let right = lower_expression_to_temporary(builder, &bin.right)?; + let operator = convert_binary_operator(&bin.operator); + Ok(InstructionValue::BinaryExpression { + operator, + left, + right, + loc, + }) + } + Expression::UnaryExpression(unary) => { + let loc = convert_opt_loc(&unary.base.loc); + match &unary.operator { + react_compiler_ast::operators::UnaryOperator::Delete => { + // Delete can be on member expressions or identifiers + let loc = convert_opt_loc(&unary.base.loc); + match &*unary.argument { + Expression::MemberExpression(member) => { + let object = lower_expression_to_temporary(builder, &member.object)?; + if !member.computed { + match &*member.property { + Expression::Identifier(prop_id) => { + Ok(InstructionValue::PropertyDelete { + object, + property: PropertyLiteral::String(prop_id.name.clone()), + loc, + }) + } + _ => { + builder.record_error(CompilerErrorDetail { + reason: "Unsupported delete target".to_string(), + category: ErrorCategory::Todo, + loc: loc.clone(), + description: None, + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + } + } else { + let property = + lower_expression_to_temporary(builder, &member.property)?; + Ok(InstructionValue::ComputedDelete { + object, + property, + loc, + }) + } + } + _ => { + // delete on non-member expression (e.g., delete x) - not commonly + // supported + builder.record_error(CompilerErrorDetail { + reason: 
"Unsupported delete target".to_string(), + category: ErrorCategory::Todo, + loc: loc.clone(), + description: None, + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + } + } + react_compiler_ast::operators::UnaryOperator::Throw => { + // throw as unary operator (Babel-specific) + let loc = convert_opt_loc(&unary.base.loc); + builder.record_error(CompilerErrorDetail { + reason: "throw expressions are not supported".to_string(), + category: ErrorCategory::Todo, + loc: loc.clone(), + description: None, + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + op => { + let value = lower_expression_to_temporary(builder, &unary.argument)?; + let operator = convert_unary_operator(op); + Ok(InstructionValue::UnaryExpression { + operator, + value, + loc, + }) + } + } + } + Expression::CallExpression(call) => { + let loc = convert_opt_loc(&call.base.loc); + // Check if callee is a MemberExpression => MethodCall + if let Expression::MemberExpression(member) = call.callee.as_ref() { + let lowered = lower_member_expression(builder, member)?; + let property = lower_value_to_temporary(builder, lowered.value)?; + let args = lower_arguments(builder, &call.arguments)?; + Ok(InstructionValue::MethodCall { + receiver: lowered.object, + property, + args, + loc, + }) + } else { + let callee = lower_expression_to_temporary(builder, &call.callee)?; + let args = lower_arguments(builder, &call.arguments)?; + Ok(InstructionValue::CallExpression { callee, args, loc }) + } + } + Expression::MemberExpression(member) => { + let lowered = lower_member_expression(builder, member)?; + Ok(lowered.value) + } + Expression::OptionalCallExpression(opt_call) => { + Ok(lower_optional_call_expression(builder, opt_call)?) + } + Expression::OptionalMemberExpression(opt_member) => { + Ok(lower_optional_member_expression(builder, opt_member)?) 
+ } + Expression::LogicalExpression(expr) => { + let loc = convert_opt_loc(&expr.base.loc); + let continuation_block = builder.reserve(builder.current_block_kind()); + let continuation_id = continuation_block.id; + let test_block = builder.reserve(BlockKind::Value); + let test_block_id = test_block.id; + let place = build_temporary_place(builder, loc.clone()); + let left_loc = expression_loc(&expr.left); + let left_place = build_temporary_place(builder, left_loc); + + // Block for short-circuit case: store left value as result, goto continuation + let consequent_block = builder.try_enter(BlockKind::Value, |builder, _block_id| { + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Const, + place: place.clone(), + }, + value: left_place.clone(), + type_annotation: None, + loc: left_place.loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: left_place.loc.clone(), + }) + }); + + // Block for evaluating right side + let alternate_block = builder.try_enter(BlockKind::Value, |builder, _block_id| { + let right = lower_expression_to_temporary(builder, &expr.right)?; + let right_loc = right.loc.clone(); + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Const, + place: place.clone(), + }, + value: right, + type_annotation: None, + loc: right_loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: right_loc, + }) + }); + + let hir_op = match expr.operator { + react_compiler_ast::operators::LogicalOperator::And => LogicalOperator::And, + react_compiler_ast::operators::LogicalOperator::Or => LogicalOperator::Or, + react_compiler_ast::operators::LogicalOperator::NullishCoalescing => { + LogicalOperator::NullishCoalescing + } + }; + + builder.terminate_with_continuation( + Terminal::Logical { 
+ operator: hir_op, + test: test_block_id, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + test_block, + ); + + // Now in test block: lower left expression, copy to left_place + let left_value = lower_expression_to_temporary(builder, &expr.left)?; + builder.push(Instruction { + id: EvaluationOrder(0), + lvalue: left_place.clone(), + value: InstructionValue::LoadLocal { + place: left_value, + loc: loc.clone(), + }, + effects: None, + loc: loc.clone(), + }); + + builder.terminate_with_continuation( + Terminal::Branch { + test: left_place, + consequent: consequent_block?, + alternate: alternate_block?, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + continuation_block, + ); + + Ok(InstructionValue::LoadLocal { + place: place.clone(), + loc: place.loc.clone(), + }) + } + Expression::UpdateExpression(update) => { + let loc = convert_opt_loc(&update.base.loc); + match update.argument.as_ref() { + Expression::MemberExpression(member) => { + let binary_op = match &update.operator { + react_compiler_ast::operators::UpdateOperator::Increment => { + BinaryOperator::Add + } + react_compiler_ast::operators::UpdateOperator::Decrement => { + BinaryOperator::Subtract + } + }; + // Use the member expression's loc (not the update expression's) + // to match TS behavior where the inner operations use leftExpr.node.loc + let member_loc = convert_opt_loc(&member.base.loc); + let lowered = lower_member_expression(builder, member)?; + let object = lowered.object; + let lowered_property = lowered.property; + let prev_value = lower_value_to_temporary(builder, lowered.value)?; + + let one = lower_value_to_temporary( + builder, + InstructionValue::Primitive { + value: PrimitiveValue::Number(FloatValue::new(1.0)), + loc: None, + }, + )?; + let updated = lower_value_to_temporary( + builder, + InstructionValue::BinaryExpression { + operator: binary_op, + left: prev_value.clone(), + right: one, + loc: member_loc.clone(), + 
}, + )?; + + // Store back using the property from the lowered member expression. + // For prefix, the result is the PropertyStore/ComputedStore lvalue + // (matching TS which uses newValuePlace). For postfix, it's prev_value. + let new_value_place = match lowered_property { + MemberProperty::Literal(prop_literal) => lower_value_to_temporary( + builder, + InstructionValue::PropertyStore { + object, + property: prop_literal, + value: updated.clone(), + loc: member_loc, + }, + )?, + MemberProperty::Computed(prop_place) => lower_value_to_temporary( + builder, + InstructionValue::ComputedStore { + object, + property: prop_place, + value: updated.clone(), + loc: member_loc, + }, + )?, + }; + + // Return previous for postfix, newValuePlace for prefix + let result_place = if update.prefix { + new_value_place + } else { + prev_value + }; + Ok(InstructionValue::LoadLocal { + place: result_place.clone(), + loc: result_place.loc.clone(), + }) + } + Expression::Identifier(ident) => { + let start = ident.base.start.unwrap_or(0); + if builder.is_context_identifier(&ident.name, start) { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerExpression) Handle UpdateExpression to \ + variables captured within lambdas." + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(InstructionValue::UnsupportedNode { + node_type: Some("UpdateExpression".to_string()), + original_node: None, + loc, + }); + } + + let ident_loc = convert_opt_loc(&ident.base.loc); + let binding = + builder.resolve_identifier(&ident.name, start, ident_loc.clone())?; + match &binding { + VariableBinding::Global { .. 
} => { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "UpdateExpression where argument is a global is not yet \ + supported" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(InstructionValue::UnsupportedNode { + node_type: Some("UpdateExpression".to_string()), + original_node: None, + loc, + }); + } + _ => {} + } + let identifier = match binding { + VariableBinding::Identifier { identifier, .. } => identifier, + _ => { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerExpression) Support UpdateExpression \ + where argument is a global" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(InstructionValue::UnsupportedNode { + node_type: Some("UpdateExpression".to_string()), + original_node: None, + loc, + }); + } + }; + let lvalue_place = Place { + identifier, + effect: Effect::Unknown, + reactive: false, + loc: ident_loc.clone(), + }; + + // Load the current value + let value = lower_identifier(builder, &ident.name, start, ident_loc)?; + + let operation = convert_update_operator(&update.operator); + + if update.prefix { + Ok(InstructionValue::PrefixUpdate { + lvalue: lvalue_place, + operation, + value, + loc, + }) + } else { + Ok(InstructionValue::PostfixUpdate { + lvalue: lvalue_place, + operation, + value, + loc, + }) + } + } + _ => { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: format!("UpdateExpression with unsupported argument type"), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + } + } + Expression::ConditionalExpression(expr) => { + let loc = convert_opt_loc(&expr.base.loc); + let continuation_block = builder.reserve(builder.current_block_kind()); + let continuation_id = continuation_block.id; + let test_block 
= builder.reserve(BlockKind::Value); + let test_block_id = test_block.id; + let place = build_temporary_place(builder, loc.clone()); + + // Block for the consequent (test is truthy) + let consequent_ast_loc = expression_loc(&expr.consequent); + let consequent_block = builder.try_enter(BlockKind::Value, |builder, _block_id| { + let consequent = lower_expression_to_temporary(builder, &expr.consequent)?; + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Const, + place: place.clone(), + }, + value: consequent, + type_annotation: None, + loc: loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: consequent_ast_loc, + }) + }); + + // Block for the alternate (test is falsy) + let alternate_ast_loc = expression_loc(&expr.alternate); + let alternate_block = builder.try_enter(BlockKind::Value, |builder, _block_id| { + let alternate = lower_expression_to_temporary(builder, &expr.alternate)?; + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Const, + place: place.clone(), + }, + value: alternate, + type_annotation: None, + loc: loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: alternate_ast_loc, + }) + }); + + builder.terminate_with_continuation( + Terminal::Ternary { + test: test_block_id, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + test_block, + ); + + // Now in test block: lower test expression + let test_place = lower_expression_to_temporary(builder, &expr.test)?; + builder.terminate_with_continuation( + Terminal::Branch { + test: test_place, + consequent: consequent_block?, + alternate: alternate_block?, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + continuation_block, + ); + + 
Ok(InstructionValue::LoadLocal { + place: place.clone(), + loc: place.loc.clone(), + }) + } + Expression::AssignmentExpression(expr) => { + use react_compiler_ast::operators::AssignmentOperator; + let loc = convert_opt_loc(&expr.base.loc); + + if matches!(expr.operator, AssignmentOperator::Assign) { + // Simple `=` assignment + match &*expr.left { + react_compiler_ast::patterns::PatternLike::Identifier(ident) => { + // Handle simple identifier assignment directly + let start = ident.base.start.unwrap_or(0); + let right = lower_expression_to_temporary(builder, &expr.right)?; + let ident_loc = convert_opt_loc(&ident.base.loc); + let binding = + builder.resolve_identifier(&ident.name, start, ident_loc.clone())?; + match binding { + VariableBinding::Identifier { + identifier, + binding_kind, + } => { + // Check for const reassignment + if binding_kind == BindingKind::Const { + builder.record_error(CompilerErrorDetail { + reason: "Cannot reassign a `const` variable".to_string(), + category: ErrorCategory::Syntax, + loc: ident_loc.clone(), + description: Some(format!( + "`{}` is declared as const", + &ident.name + )), + suggestions: None, + })?; + return Ok(InstructionValue::UnsupportedNode { + node_type: Some("Identifier".to_string()), + original_node: None, + loc: ident_loc, + }); + } + let place = Place { + identifier, + reactive: false, + effect: Effect::Unknown, + loc: ident_loc, + }; + if builder.is_context_identifier(&ident.name, start) { + let temp = lower_value_to_temporary( + builder, + InstructionValue::StoreContext { + lvalue: LValue { + kind: InstructionKind::Reassign, + place: place.clone(), + }, + value: right, + loc: place.loc.clone(), + }, + )?; + Ok(InstructionValue::LoadLocal { + place: temp.clone(), + loc: temp.loc.clone(), + }) + } else { + let temp = lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Reassign, + place: place.clone(), + }, + value: right, + type_annotation: None, + loc: 
place.loc.clone(), + }, + )?; + Ok(InstructionValue::LoadLocal { + place: temp.clone(), + loc: temp.loc.clone(), + }) + } + } + _ => { + // Global or import assignment + let name = ident.name.clone(); + let temp = lower_value_to_temporary( + builder, + InstructionValue::StoreGlobal { + name, + value: right, + loc: ident_loc, + }, + )?; + Ok(InstructionValue::LoadLocal { + place: temp.clone(), + loc: temp.loc.clone(), + }) + } + } + } + react_compiler_ast::patterns::PatternLike::MemberExpression(member) => { + // Member expression assignment: a.b = value or a[b] = value + let right = lower_expression_to_temporary(builder, &expr.right)?; + let left_loc = convert_opt_loc(&member.base.loc); + let object = lower_expression_to_temporary(builder, &member.object)?; + let temp = if !member.computed + || matches!( + &*member.property, + react_compiler_ast::expressions::Expression::NumericLiteral(_) + ) { + match &*member.property { + react_compiler_ast::expressions::Expression::Identifier( + prop_id, + ) => lower_value_to_temporary( + builder, + InstructionValue::PropertyStore { + object, + property: PropertyLiteral::String(prop_id.name.clone()), + value: right, + loc: left_loc, + }, + )?, + react_compiler_ast::expressions::Expression::NumericLiteral( + num, + ) => lower_value_to_temporary( + builder, + InstructionValue::PropertyStore { + object, + property: PropertyLiteral::Number(FloatValue::new( + num.value, + )), + value: right, + loc: left_loc, + }, + )?, + _ => { + let prop = + lower_expression_to_temporary(builder, &member.property)?; + lower_value_to_temporary( + builder, + InstructionValue::ComputedStore { + object, + property: prop, + value: right, + loc: left_loc, + }, + )? + } + } + } else { + let prop = lower_expression_to_temporary(builder, &member.property)?; + lower_value_to_temporary( + builder, + InstructionValue::ComputedStore { + object, + property: prop, + value: right, + loc: left_loc, + }, + )? 
+ }; + Ok(InstructionValue::LoadLocal { + place: temp.clone(), + loc: temp.loc.clone(), + }) + } + _ => { + // Destructuring assignment + let right = lower_expression_to_temporary(builder, &expr.right)?; + let left_loc = pattern_like_hir_loc(&expr.left); + let result = lower_assignment( + builder, + left_loc, + InstructionKind::Reassign, + &expr.left, + right.clone(), + AssignmentStyle::Destructure, + )?; + match result { + Some(place) => Ok(InstructionValue::LoadLocal { + place: place.clone(), + loc: place.loc.clone(), + }), + None => Ok(InstructionValue::LoadLocal { place: right, loc }), + } + } + } + } else { + // Compound assignment operators + let binary_op = match expr.operator { + AssignmentOperator::AddAssign => Some(BinaryOperator::Add), + AssignmentOperator::SubAssign => Some(BinaryOperator::Subtract), + AssignmentOperator::MulAssign => Some(BinaryOperator::Multiply), + AssignmentOperator::DivAssign => Some(BinaryOperator::Divide), + AssignmentOperator::RemAssign => Some(BinaryOperator::Modulo), + AssignmentOperator::ExpAssign => Some(BinaryOperator::Exponent), + AssignmentOperator::ShlAssign => Some(BinaryOperator::ShiftLeft), + AssignmentOperator::ShrAssign => Some(BinaryOperator::ShiftRight), + AssignmentOperator::UShrAssign => Some(BinaryOperator::UnsignedShiftRight), + AssignmentOperator::BitOrAssign => Some(BinaryOperator::BitwiseOr), + AssignmentOperator::BitXorAssign => Some(BinaryOperator::BitwiseXor), + AssignmentOperator::BitAndAssign => Some(BinaryOperator::BitwiseAnd), + AssignmentOperator::OrAssign + | AssignmentOperator::AndAssign + | AssignmentOperator::NullishAssign => { + // Logical assignment operators (||=, &&=, ??=) - not yet supported + builder.record_error(CompilerErrorDetail { + reason: "Logical assignment operators (||=, &&=, ??=) are not yet \ + supported" + .to_string(), + category: ErrorCategory::Todo, + loc: loc.clone(), + description: None, + suggestions: None, + })?; + return Ok(InstructionValue::UnsupportedNode { + 
node_type: None, + original_node: None, + loc, + }); + } + AssignmentOperator::Assign => unreachable!(), + }; + let binary_op = match binary_op { + Some(op) => op, + None => { + return Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }); + } + }; + + match &*expr.left { + react_compiler_ast::patterns::PatternLike::Identifier(ident) => { + let start = ident.base.start.unwrap_or(0); + let left_place = lower_expression_to_temporary( + builder, + &react_compiler_ast::expressions::Expression::Identifier(ident.clone()), + )?; + let right = lower_expression_to_temporary(builder, &expr.right)?; + let binary_place = lower_value_to_temporary( + builder, + InstructionValue::BinaryExpression { + operator: binary_op, + left: left_place, + right, + loc: loc.clone(), + }, + )?; + let ident_loc = convert_opt_loc(&ident.base.loc); + let binding = + builder.resolve_identifier(&ident.name, start, ident_loc.clone())?; + match binding { + VariableBinding::Identifier { identifier, .. 
} => { + let place = Place { + identifier, + reactive: false, + effect: Effect::Unknown, + loc: ident_loc, + }; + if builder.is_context_identifier(&ident.name, start) { + lower_value_to_temporary( + builder, + InstructionValue::StoreContext { + lvalue: LValue { + kind: InstructionKind::Reassign, + place: place.clone(), + }, + value: binary_place, + loc: loc.clone(), + }, + )?; + Ok(InstructionValue::LoadContext { place, loc }) + } else { + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Reassign, + place: place.clone(), + }, + value: binary_place, + type_annotation: None, + loc: loc.clone(), + }, + )?; + Ok(InstructionValue::LoadLocal { place, loc }) + } + } + _ => { + // Global assignment + let name = ident.name.clone(); + let temp = lower_value_to_temporary( + builder, + InstructionValue::StoreGlobal { + name, + value: binary_place, + loc: loc.clone(), + }, + )?; + Ok(InstructionValue::LoadLocal { + place: temp.clone(), + loc: temp.loc.clone(), + }) + } + } + } + react_compiler_ast::patterns::PatternLike::MemberExpression(member) => { + // a.b += right: read, compute, store + // Match TS behavior: return the PropertyStore/ComputedStore value + // directly (let the caller lower it to a temporary) + let member_loc = convert_opt_loc(&member.base.loc); + let lowered = lower_member_expression(builder, member)?; + let object = lowered.object; + let lowered_property = lowered.property; + let current_value = lower_value_to_temporary(builder, lowered.value)?; + let right = lower_expression_to_temporary(builder, &expr.right)?; + let result = lower_value_to_temporary( + builder, + InstructionValue::BinaryExpression { + operator: binary_op, + left: current_value, + right, + loc: member_loc.clone(), + }, + )?; + // Return the store instruction value directly (matching TS behavior) + match lowered_property { + MemberProperty::Literal(prop_literal) => { + Ok(InstructionValue::PropertyStore { + object, + property: 
prop_literal, + value: result, + loc: member_loc, + }) + } + MemberProperty::Computed(prop_place) => { + Ok(InstructionValue::ComputedStore { + object, + property: prop_place, + value: result, + loc: member_loc, + }) + } + } + } + _ => { + builder.record_error(CompilerErrorDetail { + reason: "Compound assignment to complex pattern is not yet supported" + .to_string(), + category: ErrorCategory::Todo, + loc: loc.clone(), + description: None, + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + } + } + } + Expression::SequenceExpression(seq) => { + let loc = convert_opt_loc(&seq.base.loc); + + if seq.expressions.is_empty() { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Syntax, + reason: "Expected sequence expression to have at least one expression" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }); + } + + let continuation_block = builder.reserve(builder.current_block_kind()); + let continuation_id = continuation_block.id; + let place = build_temporary_place(builder, loc.clone()); + + let sequence_block = builder.try_enter(BlockKind::Sequence, |builder, _block_id| { + let mut last: Option = None; + for item in &seq.expressions { + last = Some(lower_expression_to_temporary(builder, item)?); + } + if let Some(last) = last { + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Const, + place: place.clone(), + }, + value: last, + type_annotation: None, + loc: loc.clone(), + }, + )?; + } + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: loc.clone(), + }) + }); + + builder.terminate_with_continuation( + Terminal::Sequence { + block: sequence_block?, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: 
loc.clone(), + }, + continuation_block, + ); + Ok(InstructionValue::LoadLocal { place, loc }) + } + Expression::ArrowFunctionExpression(_) => { + // The expression type is already known to be ArrowFunctionExpression at this + // point, so lower_function's non-function invariant cannot fail. + // Safe to unwrap. + Ok(lower_function_to_value( + builder, + expr, + FunctionExpressionType::ArrowFunctionExpression, + ) + .expect("lower_function_to_value called with ArrowFunctionExpression")) + } + Expression::FunctionExpression(_) => { + Ok( + lower_function_to_value(builder, expr, FunctionExpressionType::FunctionExpression) + .expect("lower_function_to_value called with FunctionExpression"), + ) + } + Expression::ObjectExpression(obj) => { + let loc = convert_opt_loc(&obj.base.loc); + let mut properties: Vec = Vec::new(); + for prop in &obj.properties { + match prop { + react_compiler_ast::expressions::ObjectExpressionProperty::ObjectProperty( + p, + ) => { + let key = lower_object_property_key(builder, &p.key, p.computed)?; + let key = match key { + Some(k) => k, + None => continue, + }; + let value = lower_expression_to_temporary(builder, &p.value)?; + properties.push(ObjectPropertyOrSpread::Property(ObjectProperty { + key, + property_type: ObjectPropertyType::Property, + place: value, + })); + } + react_compiler_ast::expressions::ObjectExpressionProperty::SpreadElement( + spread, + ) => { + let place = lower_expression_to_temporary(builder, &spread.argument)?; + properties.push(ObjectPropertyOrSpread::Spread(SpreadPattern { place })); + } + react_compiler_ast::expressions::ObjectExpressionProperty::ObjectMethod( + method, + ) => { + if let Some(prop) = lower_object_method(builder, method)? 
{ + properties.push(ObjectPropertyOrSpread::Property(prop)); + } + } + } + } + Ok(InstructionValue::ObjectExpression { properties, loc }) + } + Expression::ArrayExpression(arr) => { + let loc = convert_opt_loc(&arr.base.loc); + let mut elements: Vec = Vec::new(); + for element in &arr.elements { + match element { + None => { + elements.push(ArrayElement::Hole); + } + Some(Expression::SpreadElement(spread)) => { + let place = lower_expression_to_temporary(builder, &spread.argument)?; + elements.push(ArrayElement::Spread(SpreadPattern { place })); + } + Some(expr) => { + let place = lower_expression_to_temporary(builder, expr)?; + elements.push(ArrayElement::Place(place)); + } + } + } + Ok(InstructionValue::ArrayExpression { elements, loc }) + } + Expression::NewExpression(new_expr) => { + let loc = convert_opt_loc(&new_expr.base.loc); + let callee = lower_expression_to_temporary(builder, &new_expr.callee)?; + let args = lower_arguments(builder, &new_expr.arguments)?; + Ok(InstructionValue::NewExpression { callee, args, loc }) + } + Expression::TemplateLiteral(tmpl) => { + let loc = convert_opt_loc(&tmpl.base.loc); + let subexprs: Vec = tmpl + .expressions + .iter() + .map(|e| lower_expression_to_temporary(builder, e)) + .collect::, _>>()?; + let quasis: Vec = tmpl + .quasis + .iter() + .map(|q| TemplateQuasi { + raw: q.value.raw.clone(), + cooked: q.value.cooked.clone(), + }) + .collect(); + Ok(InstructionValue::TemplateLiteral { + subexprs, + quasis, + loc, + }) + } + Expression::TaggedTemplateExpression(tagged) => { + let loc = convert_opt_loc(&tagged.base.loc); + if !tagged.quasi.expressions.is_empty() { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerExpression) Handle tagged template with \ + interpolations" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(InstructionValue::UnsupportedNode { + node_type: Some("TaggedTemplateExpression".to_string()), + 
original_node: None, + loc, + }); + } + assert!( + tagged.quasi.quasis.len() == 1, + "there should be only one quasi as we don't support interpolations yet" + ); + let quasi = &tagged.quasi.quasis[0]; + // Check if raw and cooked values differ (e.g., graphql tagged templates) + if quasi.value.raw != quasi.value.cooked.clone().unwrap_or_default() { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerExpression) Handle tagged template where cooked \ + value is different from raw value" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(InstructionValue::UnsupportedNode { + node_type: Some("TaggedTemplateExpression".to_string()), + original_node: None, + loc, + }); + } + let value = TemplateQuasi { + raw: quasi.value.raw.clone(), + cooked: quasi.value.cooked.clone(), + }; + let tag = lower_expression_to_temporary(builder, &tagged.tag)?; + Ok(InstructionValue::TaggedTemplateExpression { tag, value, loc }) + } + Expression::AwaitExpression(await_expr) => { + let loc = convert_opt_loc(&await_expr.base.loc); + let value = lower_expression_to_temporary(builder, &await_expr.argument)?; + Ok(InstructionValue::Await { value, loc }) + } + Expression::YieldExpression(yld) => { + let loc = convert_opt_loc(&yld.base.loc); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerExpression) Handle YieldExpression expressions" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: Some("YieldExpression".to_string()), + original_node: None, + loc, + }) + } + Expression::SpreadElement(spread) => { + // SpreadElement should be handled by the parent context (array/object/call) + // If we reach here, just lower the argument expression + Ok(lower_expression(builder, &spread.argument)?) 
+ } + Expression::MetaProperty(meta) => { + let loc = convert_opt_loc(&meta.base.loc); + if meta.meta.name == "import" && meta.property.name == "meta" { + Ok(InstructionValue::MetaProperty { + meta: meta.meta.name.clone(), + property: meta.property.name.clone(), + loc, + }) + } else { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerExpression) Handle MetaProperty expressions other \ + than import.meta" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: Some("MetaProperty".to_string()), + original_node: None, + loc, + }) + } + } + Expression::ClassExpression(cls) => { + let loc = convert_opt_loc(&cls.base.loc); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "class expressions are not yet supported".to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + Expression::PrivateName(pn) => { + let loc = convert_opt_loc(&pn.base.loc); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "private names are not yet supported".to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + Expression::Super(sup) => { + let loc = convert_opt_loc(&sup.base.loc); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "super is not supported".to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + Expression::Import(imp) => { + let loc = convert_opt_loc(&imp.base.loc); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "dynamic 
import() is not yet supported".to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + Expression::ThisExpression(this) => { + let loc = convert_opt_loc(&this.base.loc); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "this is not supported".to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + Expression::ParenthesizedExpression(paren) => { + Ok(lower_expression(builder, &paren.expression)?) + } + Expression::JSXElement(jsx_element) => { + let loc = convert_opt_loc(&jsx_element.base.loc); + let opening_loc = convert_opt_loc(&jsx_element.opening_element.base.loc); + let closing_loc = jsx_element + .closing_element + .as_ref() + .and_then(|c| convert_opt_loc(&c.base.loc)); + + // Lower the tag name + let tag = lower_jsx_element_name(builder, &jsx_element.opening_element.name)?; + + // Lower attributes (props) + let mut props: Vec = Vec::new(); + for attr_item in &jsx_element.opening_element.attributes { + use react_compiler_ast::jsx::{ + JSXAttributeItem, JSXAttributeName, JSXAttributeValue, + }; + match attr_item { + JSXAttributeItem::JSXSpreadAttribute(spread) => { + let argument = lower_expression_to_temporary(builder, &spread.argument)?; + props.push(JsxAttribute::SpreadAttribute { argument }); + } + JSXAttributeItem::JSXAttribute(attr) => { + // Get the attribute name + let prop_name = match &attr.name { + JSXAttributeName::JSXIdentifier(id) => { + let name = &id.name; + if name.contains(':') { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: format!( + "(BuildHIR::lowerExpression) Unexpected colon in \ + attribute name `{}`", + name + ), + description: None, + loc: convert_opt_loc(&id.base.loc), + suggestions: None, + })?; + } + 
name.clone() + } + JSXAttributeName::JSXNamespacedName(ns) => { + format!("{}:{}", ns.namespace.name, ns.name.name) + } + }; + + // Get the attribute value + let value = match &attr.value { + Some(JSXAttributeValue::StringLiteral(s)) => { + let str_loc = convert_opt_loc(&s.base.loc); + lower_value_to_temporary( + builder, + InstructionValue::Primitive { + value: PrimitiveValue::String(s.value.clone()), + loc: str_loc, + }, + )? + } + Some(JSXAttributeValue::JSXExpressionContainer(container)) => { + use react_compiler_ast::jsx::JSXExpressionContainerExpr; + match &container.expression { + JSXExpressionContainerExpr::JSXEmptyExpression(_) => { + // Empty expression container - skip this attribute + continue; + } + JSXExpressionContainerExpr::Expression(expr) => { + lower_expression_to_temporary(builder, expr)? + } + } + } + Some(JSXAttributeValue::JSXElement(el)) => { + let val = lower_expression( + builder, + &react_compiler_ast::expressions::Expression::JSXElement( + el.clone(), + ), + )?; + lower_value_to_temporary(builder, val)? + } + Some(JSXAttributeValue::JSXFragment(frag)) => { + let val = lower_expression( + builder, + &react_compiler_ast::expressions::Expression::JSXFragment( + frag.clone(), + ), + )?; + lower_value_to_temporary(builder, val)? + } + None => { + // No value means boolean true (e.g.,
) + let attr_loc = convert_opt_loc(&attr.base.loc); + lower_value_to_temporary( + builder, + InstructionValue::Primitive { + value: PrimitiveValue::Boolean(true), + loc: attr_loc, + }, + )? + } + }; + + props.push(JsxAttribute::Attribute { + name: prop_name, + place: value, + }); + } + } + } + + // Check if this is an fbt/fbs tag, which requires special whitespace handling + let is_fbt = matches!(&tag, JsxTag::Builtin(b) if b.name == "fbt" || b.name == "fbs"); + + // Check that fbt/fbs tags are module-level imports, not local bindings. + // Matches TS: CompilerError.invariant(tagIdentifier.kind !== 'Identifier', ...) + if is_fbt { + let tag_name = match &tag { + JsxTag::Builtin(b) => b.name.clone(), + _ => "fbt".to_string(), + }; + // Get the opening element's name identifier and check if it's a local binding + if let react_compiler_ast::jsx::JSXElementName::JSXIdentifier(jsx_id) = + &jsx_element.opening_element.name + { + let id_loc = convert_opt_loc(&jsx_id.base.loc); + // Check if fbt/fbs tag name resolves to a local binding. + // JSX identifiers may not be in our position-based reference map, + // so check if ANY binding with this name exists in the function scope. + let is_local_binding = builder.has_local_binding(&jsx_id.name); + if is_local_binding { + // Record as a Diagnostic (not ErrorDetail) to match TS behavior + // where CompilerError.invariant creates a CompilerDiagnostic. 
+ let reason = format!("<{}> tags should be module-level imports", tag_name); + builder.record_diagnostic( + CompilerDiagnostic::new(ErrorCategory::Invariant, &reason, None) + .with_detail(CompilerDiagnosticDetail::Error { + loc: id_loc.clone(), + message: Some(reason.clone()), + identifier_name: None, + }), + ); + } + } + } + + // Check for duplicate fbt:enum, fbt:plural, fbt:pronoun tags + if is_fbt { + let tag_name = match &tag { + JsxTag::Builtin(b) => b.name.as_str(), + _ => "fbt", + }; + let mut enum_locs: Vec> = Vec::new(); + let mut plural_locs: Vec> = Vec::new(); + let mut pronoun_locs: Vec> = Vec::new(); + collect_fbt_sub_tags( + &jsx_element.children, + tag_name, + &mut enum_locs, + &mut plural_locs, + &mut pronoun_locs, + ); + + for (name, locations) in [ + ("enum", &enum_locs), + ("plural", &plural_locs), + ("pronoun", &pronoun_locs), + ] { + if locations.len() > 1 { + use react_compiler_diagnostics::CompilerDiagnosticDetail; + let details: Vec = locations + .iter() + .map(|loc| CompilerDiagnosticDetail::Error { + message: Some(format!( + "Multiple `<{}:{}>` tags found", + tag_name, name + )), + loc: loc.clone(), + identifier_name: None, + }) + .collect(); + let mut diag = react_compiler_diagnostics::CompilerDiagnostic::new( + ErrorCategory::Todo, + "Support duplicate fbt tags", + Some(format!( + "Support `<{}>` tags with multiple `<{}:{}>` values", + tag_name, tag_name, name + )), + ); + diag.details = details; + builder.environment_mut().record_diagnostic(diag); + } + } + } + + // Increment fbt counter before traversing into children, as whitespace + // in jsx text is handled differently for fbt subtrees. + if is_fbt { + builder.fbt_depth += 1; + } + + // Lower children + let children: Vec = jsx_element + .children + .iter() + .map(|child| lower_jsx_element(builder, child)) + .collect::, _>>()? 
+ .into_iter() + .flatten() + .collect(); + + if is_fbt { + builder.fbt_depth -= 1; + } + + Ok(InstructionValue::JsxExpression { + tag, + props, + children: if children.is_empty() { + None + } else { + Some(children) + }, + loc, + opening_loc, + closing_loc, + }) + } + Expression::JSXFragment(jsx_fragment) => { + let loc = convert_opt_loc(&jsx_fragment.base.loc); + + // Lower children + let children: Vec = jsx_fragment + .children + .iter() + .map(|child| lower_jsx_element(builder, child)) + .collect::, _>>()? + .into_iter() + .flatten() + .collect(); + + Ok(InstructionValue::JsxFragment { children, loc }) + } + Expression::AssignmentPattern(_) => { + let loc = convert_opt_loc(&match expr { + Expression::AssignmentPattern(p) => p.base.loc.clone(), + _ => unreachable!(), + }); + builder.record_error(CompilerErrorDetail { + reason: "AssignmentPattern in expression position is not supported".to_string(), + category: ErrorCategory::Todo, + loc: loc.clone(), + description: None, + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }) + } + Expression::TSAsExpression(ts) => { + let loc = convert_opt_loc(&ts.base.loc); + let value = lower_expression_to_temporary(builder, &ts.expression)?; + let type_annotation = &*ts.type_annotation; + let type_ = lower_type_annotation(type_annotation, builder); + let type_annotation_name = get_type_annotation_name(type_annotation); + Ok(InstructionValue::TypeCastExpression { + value, + type_, + type_annotation_name, + type_annotation_kind: Some("as".to_string()), + type_annotation: Some(ts.type_annotation.clone()), + loc, + }) + } + Expression::TSSatisfiesExpression(ts) => { + let loc = convert_opt_loc(&ts.base.loc); + let value = lower_expression_to_temporary(builder, &ts.expression)?; + let type_annotation = &*ts.type_annotation; + let type_ = lower_type_annotation(type_annotation, builder); + let type_annotation_name = get_type_annotation_name(type_annotation); + 
Ok(InstructionValue::TypeCastExpression { + value, + type_, + type_annotation_name, + type_annotation_kind: Some("satisfies".to_string()), + type_annotation: Some(ts.type_annotation.clone()), + loc, + }) + } + Expression::TSNonNullExpression(ts) => Ok(lower_expression(builder, &ts.expression)?), + Expression::TSTypeAssertion(ts) => { + let loc = convert_opt_loc(&ts.base.loc); + let value = lower_expression_to_temporary(builder, &ts.expression)?; + let type_annotation = &*ts.type_annotation; + let type_ = lower_type_annotation(type_annotation, builder); + let type_annotation_name = get_type_annotation_name(type_annotation); + Ok(InstructionValue::TypeCastExpression { + value, + type_, + type_annotation_name, + type_annotation_kind: Some("as".to_string()), + type_annotation: Some(ts.type_annotation.clone()), + loc, + }) + } + Expression::TSInstantiationExpression(ts) => Ok(lower_expression(builder, &ts.expression)?), + Expression::TypeCastExpression(tc) => { + let loc = convert_opt_loc(&tc.base.loc); + let value = lower_expression_to_temporary(builder, &tc.expression)?; + // Flow TypeCastExpression: typeAnnotation is a TypeAnnotation node wrapping the + // actual type + let inner_type = tc + .type_annotation + .get("typeAnnotation") + .unwrap_or(&*tc.type_annotation); + let type_ = lower_type_annotation(inner_type, builder); + let type_annotation_name = get_type_annotation_name(inner_type); + Ok(InstructionValue::TypeCastExpression { + value, + type_, + type_annotation_name, + type_annotation_kind: Some("cast".to_string()), + type_annotation: Some(tc.type_annotation.clone()), + loc, + }) + } + Expression::BigIntLiteral(big) => { + let loc = convert_opt_loc(&big.base.loc); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "BigInt literals are not yet supported".to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + Ok(InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, 
+ }) + } + Expression::RegExpLiteral(re) => { + let loc = convert_opt_loc(&re.base.loc); + Ok(InstructionValue::RegExpLiteral { + pattern: re.pattern.clone(), + flags: re.flags.clone(), + loc, + }) + } + } +} + +// ============================================================================= +// Statement position helpers +// ============================================================================= + +fn statement_start(stmt: &react_compiler_ast::statements::Statement) -> Option { + use react_compiler_ast::statements::Statement; + match stmt { + Statement::BlockStatement(s) => s.base.start, + Statement::ReturnStatement(s) => s.base.start, + Statement::IfStatement(s) => s.base.start, + Statement::ForStatement(s) => s.base.start, + Statement::WhileStatement(s) => s.base.start, + Statement::DoWhileStatement(s) => s.base.start, + Statement::ForInStatement(s) => s.base.start, + Statement::ForOfStatement(s) => s.base.start, + Statement::SwitchStatement(s) => s.base.start, + Statement::ThrowStatement(s) => s.base.start, + Statement::TryStatement(s) => s.base.start, + Statement::BreakStatement(s) => s.base.start, + Statement::ContinueStatement(s) => s.base.start, + Statement::LabeledStatement(s) => s.base.start, + Statement::ExpressionStatement(s) => s.base.start, + Statement::EmptyStatement(s) => s.base.start, + Statement::DebuggerStatement(s) => s.base.start, + Statement::WithStatement(s) => s.base.start, + Statement::VariableDeclaration(s) => s.base.start, + Statement::FunctionDeclaration(s) => s.base.start, + Statement::ClassDeclaration(s) => s.base.start, + Statement::ImportDeclaration(s) => s.base.start, + Statement::ExportNamedDeclaration(s) => s.base.start, + Statement::ExportDefaultDeclaration(s) => s.base.start, + Statement::ExportAllDeclaration(s) => s.base.start, + Statement::TSTypeAliasDeclaration(s) => s.base.start, + Statement::TSInterfaceDeclaration(s) => s.base.start, + Statement::TSEnumDeclaration(s) => s.base.start, + 
Statement::TSModuleDeclaration(s) => s.base.start, + Statement::TSDeclareFunction(s) => s.base.start, + Statement::TypeAlias(s) => s.base.start, + Statement::OpaqueType(s) => s.base.start, + Statement::InterfaceDeclaration(s) => s.base.start, + Statement::DeclareVariable(s) => s.base.start, + Statement::DeclareFunction(s) => s.base.start, + Statement::DeclareClass(s) => s.base.start, + Statement::DeclareModule(s) => s.base.start, + Statement::DeclareModuleExports(s) => s.base.start, + Statement::DeclareExportDeclaration(s) => s.base.start, + Statement::DeclareExportAllDeclaration(s) => s.base.start, + Statement::DeclareInterface(s) => s.base.start, + Statement::DeclareTypeAlias(s) => s.base.start, + Statement::DeclareOpaqueType(s) => s.base.start, + Statement::EnumDeclaration(s) => s.base.start, + } +} + +fn statement_end(stmt: &react_compiler_ast::statements::Statement) -> Option { + use react_compiler_ast::statements::Statement; + match stmt { + Statement::BlockStatement(s) => s.base.end, + Statement::ReturnStatement(s) => s.base.end, + Statement::IfStatement(s) => s.base.end, + Statement::ForStatement(s) => s.base.end, + Statement::WhileStatement(s) => s.base.end, + Statement::DoWhileStatement(s) => s.base.end, + Statement::ForInStatement(s) => s.base.end, + Statement::ForOfStatement(s) => s.base.end, + Statement::SwitchStatement(s) => s.base.end, + Statement::ThrowStatement(s) => s.base.end, + Statement::TryStatement(s) => s.base.end, + Statement::BreakStatement(s) => s.base.end, + Statement::ContinueStatement(s) => s.base.end, + Statement::LabeledStatement(s) => s.base.end, + Statement::ExpressionStatement(s) => s.base.end, + Statement::EmptyStatement(s) => s.base.end, + Statement::DebuggerStatement(s) => s.base.end, + Statement::WithStatement(s) => s.base.end, + Statement::VariableDeclaration(s) => s.base.end, + Statement::FunctionDeclaration(s) => s.base.end, + Statement::ClassDeclaration(s) => s.base.end, + Statement::ImportDeclaration(s) => s.base.end, + 
Statement::ExportNamedDeclaration(s) => s.base.end, + Statement::ExportDefaultDeclaration(s) => s.base.end, + Statement::ExportAllDeclaration(s) => s.base.end, + Statement::TSTypeAliasDeclaration(s) => s.base.end, + Statement::TSInterfaceDeclaration(s) => s.base.end, + Statement::TSEnumDeclaration(s) => s.base.end, + Statement::TSModuleDeclaration(s) => s.base.end, + Statement::TSDeclareFunction(s) => s.base.end, + Statement::TypeAlias(s) => s.base.end, + Statement::OpaqueType(s) => s.base.end, + Statement::InterfaceDeclaration(s) => s.base.end, + Statement::DeclareVariable(s) => s.base.end, + Statement::DeclareFunction(s) => s.base.end, + Statement::DeclareClass(s) => s.base.end, + Statement::DeclareModule(s) => s.base.end, + Statement::DeclareModuleExports(s) => s.base.end, + Statement::DeclareExportDeclaration(s) => s.base.end, + Statement::DeclareExportAllDeclaration(s) => s.base.end, + Statement::DeclareInterface(s) => s.base.end, + Statement::DeclareTypeAlias(s) => s.base.end, + Statement::DeclareOpaqueType(s) => s.base.end, + Statement::EnumDeclaration(s) => s.base.end, + } +} + +/// Extract the HIR SourceLocation from a Statement AST node. 
+fn statement_loc(stmt: &react_compiler_ast::statements::Statement) -> Option { + use react_compiler_ast::statements::Statement; + let loc = match stmt { + Statement::BlockStatement(s) => s.base.loc.clone(), + Statement::ReturnStatement(s) => s.base.loc.clone(), + Statement::IfStatement(s) => s.base.loc.clone(), + Statement::ForStatement(s) => s.base.loc.clone(), + Statement::WhileStatement(s) => s.base.loc.clone(), + Statement::DoWhileStatement(s) => s.base.loc.clone(), + Statement::ForInStatement(s) => s.base.loc.clone(), + Statement::ForOfStatement(s) => s.base.loc.clone(), + Statement::SwitchStatement(s) => s.base.loc.clone(), + Statement::ThrowStatement(s) => s.base.loc.clone(), + Statement::TryStatement(s) => s.base.loc.clone(), + Statement::BreakStatement(s) => s.base.loc.clone(), + Statement::ContinueStatement(s) => s.base.loc.clone(), + Statement::LabeledStatement(s) => s.base.loc.clone(), + Statement::ExpressionStatement(s) => s.base.loc.clone(), + Statement::EmptyStatement(s) => s.base.loc.clone(), + Statement::DebuggerStatement(s) => s.base.loc.clone(), + Statement::WithStatement(s) => s.base.loc.clone(), + Statement::VariableDeclaration(s) => s.base.loc.clone(), + Statement::FunctionDeclaration(s) => s.base.loc.clone(), + Statement::ClassDeclaration(s) => s.base.loc.clone(), + Statement::ImportDeclaration(s) => s.base.loc.clone(), + Statement::ExportNamedDeclaration(s) => s.base.loc.clone(), + Statement::ExportDefaultDeclaration(s) => s.base.loc.clone(), + Statement::ExportAllDeclaration(s) => s.base.loc.clone(), + Statement::TSTypeAliasDeclaration(s) => s.base.loc.clone(), + Statement::TSInterfaceDeclaration(s) => s.base.loc.clone(), + Statement::TSEnumDeclaration(s) => s.base.loc.clone(), + Statement::TSModuleDeclaration(s) => s.base.loc.clone(), + Statement::TSDeclareFunction(s) => s.base.loc.clone(), + Statement::TypeAlias(s) => s.base.loc.clone(), + Statement::OpaqueType(s) => s.base.loc.clone(), + Statement::InterfaceDeclaration(s) => 
s.base.loc.clone(), + Statement::DeclareVariable(s) => s.base.loc.clone(), + Statement::DeclareFunction(s) => s.base.loc.clone(), + Statement::DeclareClass(s) => s.base.loc.clone(), + Statement::DeclareModule(s) => s.base.loc.clone(), + Statement::DeclareModuleExports(s) => s.base.loc.clone(), + Statement::DeclareExportDeclaration(s) => s.base.loc.clone(), + Statement::DeclareExportAllDeclaration(s) => s.base.loc.clone(), + Statement::DeclareInterface(s) => s.base.loc.clone(), + Statement::DeclareTypeAlias(s) => s.base.loc.clone(), + Statement::DeclareOpaqueType(s) => s.base.loc.clone(), + Statement::EnumDeclaration(s) => s.base.loc.clone(), + }; + convert_opt_loc(&loc) +} + +/// Collect binding names from a pattern that are declared in the given scope. +fn collect_binding_names_from_pattern( + pattern: &react_compiler_ast::patterns::PatternLike, + scope_id: react_compiler_ast::scope::ScopeId, + scope_info: &ScopeInfo, + out: &mut HashSet, +) { + use react_compiler_ast::patterns::PatternLike; + match pattern { + PatternLike::Identifier(id) => { + if let Some(&binding_id) = scope_info.scopes[scope_id.0 as usize] + .bindings + .get(&id.name) + { + out.insert(binding_id); + } + } + PatternLike::ObjectPattern(obj) => { + for prop in &obj.properties { + match prop { + react_compiler_ast::patterns::ObjectPatternProperty::ObjectProperty(p) => { + collect_binding_names_from_pattern(&p.value, scope_id, scope_info, out); + } + react_compiler_ast::patterns::ObjectPatternProperty::RestElement(r) => { + collect_binding_names_from_pattern(&r.argument, scope_id, scope_info, out); + } + } + } + } + PatternLike::ArrayPattern(arr) => { + for elem in &arr.elements { + if let Some(e) = elem { + collect_binding_names_from_pattern(e, scope_id, scope_info, out); + } + } + } + PatternLike::AssignmentPattern(assign) => { + collect_binding_names_from_pattern(&assign.left, scope_id, scope_info, out); + } + PatternLike::RestElement(rest) => { + 
collect_binding_names_from_pattern(&rest.argument, scope_id, scope_info, out); + } + PatternLike::MemberExpression(_) => {} + } +} + +// ============================================================================= +// lower_block_statement (with hoisting) +// ============================================================================= + +/// Lower a BlockStatement with hoisting support. +/// +/// Implements the TS BlockStatement hoisting pass: identifies forward +/// references to block-scoped bindings and emits DeclareContext instructions to +/// hoist them. +fn lower_block_statement( + builder: &mut HirBuilder, + block: &react_compiler_ast::statements::BlockStatement, +) -> Result<(), CompilerError> { + // Errors from lower_block_statement_inner are already recorded on the + // environment by record_error, so we intentionally drop the Result here + // to avoid double-recording diagnostics. + let _ = lower_block_statement_inner(builder, block, None); + Ok(()) +} + +fn lower_block_statement_with_scope( + builder: &mut HirBuilder, + block: &react_compiler_ast::statements::BlockStatement, + scope_override: react_compiler_ast::scope::ScopeId, +) -> Result<(), CompilerError> { + // Errors from lower_block_statement_inner are already recorded on the + // environment by record_error, so we intentionally drop the Result here + // to avoid double-recording diagnostics. + let _ = lower_block_statement_inner(builder, block, Some(scope_override)); + Ok(()) +} + +fn lower_block_statement_inner( + builder: &mut HirBuilder, + block: &react_compiler_ast::statements::BlockStatement, + scope_override: Option, +) -> Result<(), CompilerDiagnostic> { + use react_compiler_ast::{scope::BindingKind as AstBindingKind, statements::Statement}; + + // Look up the block's scope to identify hoistable bindings. + // Use the scope override if provided (for function body blocks that share the + // function's scope). 
+ let block_scope_id = scope_override.or_else(|| { + block + .base + .start + .and_then(|start| builder.scope_info().node_to_scope.get(&start).copied()) + }); + + let scope_id = match block_scope_id { + Some(id) => id, + None => { + // No scope found for this block, just lower statements normally + for body_stmt in &block.body { + lower_statement(builder, body_stmt, None)?; + } + return Ok(()); + } + }; + + // Collect hoistable bindings from this scope (non-param bindings). + // Exclude bindings whose declaration_type is "FunctionExpression" since named + // function expression names are local to the expression and should never be + // hoisted. + let hoistable: Vec<(BindingId, String, AstBindingKind, String, Option)> = builder + .scope_info() + .scope_bindings(scope_id) + .filter(|b| { + !matches!(b.kind, AstBindingKind::Param) + && b.declaration_type != "FunctionExpression" + // Skip type-only declarations (TypeAlias, OpaqueType, InterfaceDeclaration, etc.) + && !matches!(b.declaration_type.as_str(), + "TypeAlias" | "OpaqueType" | "InterfaceDeclaration" + | "DeclareVariable" | "DeclareFunction" | "DeclareClass" + | "DeclareModule" | "DeclareInterface" | "DeclareOpaqueType" + | "TSTypeAliasDeclaration" | "TSInterfaceDeclaration" + | "TSEnumDeclaration" | "TSModuleDeclaration" + ) + }) + .map(|b| { + ( + b.id, + b.name.clone(), + b.kind.clone(), + b.declaration_type.clone(), + b.declaration_start, + ) + }) + .collect(); + + if hoistable.is_empty() { + // No hoistable bindings, just lower statements normally + for body_stmt in &block.body { + lower_statement(builder, body_stmt, None)?; + } + return Ok(()); + } + + // Track which bindings have been "declared" (their declaration statement has + // been seen) + let mut declared: HashSet = HashSet::new(); + + for body_stmt in &block.body { + let stmt_start = statement_start(body_stmt).unwrap_or(0); + let stmt_end = statement_end(body_stmt).unwrap_or(u32::MAX); + let is_function_decl = matches!(body_stmt, 
Statement::FunctionDeclaration(_)); + + // Check if statement contains nested function scopes + let has_nested_functions = is_function_decl || { + let scope_info = builder.scope_info(); + scope_info.node_to_scope.iter().any(|(&pos, &sid)| { + pos > stmt_start + && pos < stmt_end + && matches!(scope_info.scopes[sid.0 as usize].kind, ScopeKind::Function) + }) + }; + + // Find references to not-yet-declared hoistable bindings within this statement + struct HoistInfo { + binding_id: BindingId, + name: String, + kind: AstBindingKind, + declaration_type: String, + first_ref_pos: u32, + } + let mut will_hoist: Vec = Vec::new(); + + for (binding_id, name, kind, decl_type, decl_start) in &hoistable { + if declared.contains(binding_id) { + continue; + } + + // Find the first reference (not declaration) to this binding in the statement's + // range. Exclude JSX identifier references since TS hoisting + // traversal only visits Identifier nodes, not JSXIdentifier nodes. + let first_ref = builder + .scope_info() + .reference_to_binding + .iter() + .filter(|(ref_start, ref_binding_id)| { + **ref_start >= stmt_start + && **ref_start < stmt_end + && **ref_binding_id == *binding_id + && Some(**ref_start) != *decl_start + && !builder.is_jsx_identifier(**ref_start) + }) + .map(|(ref_start, _)| *ref_start) + .min(); + + if let Some(first_ref_pos) = first_ref { + // Hoist if: (1) binding is "hoisted" kind (function declaration), or + // (2) reference is inside a nested function + let should_hoist = matches!(kind, AstBindingKind::Hoisted) || has_nested_functions; + if should_hoist { + will_hoist.push(HoistInfo { + binding_id: *binding_id, + name: name.clone(), + kind: kind.clone(), + declaration_type: decl_type.clone(), + first_ref_pos, + }); + } + } + } + + // Sort by first reference position to match TS traversal order + will_hoist.sort_by_key(|h| h.first_ref_pos); + + // Emit DeclareContext for hoisted bindings + for info in &will_hoist { + if builder + .environment() + 
.is_hoisted_identifier(info.binding_id.0) + { + continue; + } + + let hoist_kind = match info.kind { + AstBindingKind::Const | AstBindingKind::Var => InstructionKind::HoistedConst, + AstBindingKind::Let => InstructionKind::HoistedLet, + AstBindingKind::Hoisted => InstructionKind::HoistedFunction, + _ => { + if info.declaration_type == "FunctionDeclaration" { + InstructionKind::HoistedFunction + } else if info.declaration_type == "VariableDeclarator" { + // Unsupported hoisting for this declaration kind + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "Handle non-const declarations for hoisting".to_string(), + description: Some(format!( + "variable \"{}\" declared with {:?}", + info.name, info.kind + )), + loc: None, + suggestions: None, + })?; + continue; + } else { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "Unsupported declaration type for hoisting".to_string(), + description: Some(format!( + "variable \"{}\" declared with {}", + info.name, info.declaration_type + )), + loc: None, + suggestions: None, + })?; + continue; + } + } + }; + + // Look up the reference location for the DeclareContext instruction + let ref_loc = builder.get_identifier_loc(info.first_ref_pos); + let identifier = builder.resolve_binding(&info.name, info.binding_id)?; + let place = Place { + effect: Effect::Unknown, + identifier, + reactive: false, + loc: ref_loc.clone(), + }; + lower_value_to_temporary( + builder, + InstructionValue::DeclareContext { + lvalue: LValue { + kind: hoist_kind, + place, + }, + loc: ref_loc, + }, + )?; + builder + .environment_mut() + .add_hoisted_identifier(info.binding_id.0); + // Hoisted identifiers also become context identifiers (matching TS + // addHoistedIdentifier) + builder.add_context_identifier(info.binding_id); + } + + // After processing the statement, mark any bindings it declares as "seen". + // This must cover all statement types that can introduce bindings. 
+ match body_stmt { + Statement::FunctionDeclaration(func) => { + if let Some(id) = &func.id { + if let Some(&binding_id) = builder.scope_info().scopes[scope_id.0 as usize] + .bindings + .get(&id.name) + { + declared.insert(binding_id); + } + } + } + Statement::VariableDeclaration(var_decl) => { + for decl in &var_decl.declarations { + collect_binding_names_from_pattern( + &decl.id, + scope_id, + builder.scope_info(), + &mut declared, + ); + } + } + Statement::ClassDeclaration(cls) => { + if let Some(id) = &cls.id { + if let Some(&binding_id) = builder.scope_info().scopes[scope_id.0 as usize] + .bindings + .get(&id.name) + { + declared.insert(binding_id); + } + } + } + _ => { + // For other statement types (e.g. ForStatement with + // VariableDeclaration in init), we rely on the + // reference_to_binding check for forward references. + // Any bindings declared by child scopes won't be in this + // block's scope anyway. + } + } + + lower_statement(builder, body_stmt, None)?; + } + Ok(()) +} + +// ============================================================================= +// lower_statement +// ============================================================================= + +fn lower_statement( + builder: &mut HirBuilder, + stmt: &react_compiler_ast::statements::Statement, + label: Option<&str>, +) -> Result<(), CompilerDiagnostic> { + use react_compiler_ast::statements::Statement; + + match stmt { + Statement::EmptyStatement(_) => { + // no-op + } + Statement::DebuggerStatement(dbg) => { + let loc = convert_opt_loc(&dbg.base.loc); + let value = InstructionValue::Debugger { loc }; + lower_value_to_temporary(builder, value)?; + } + Statement::ExpressionStatement(expr_stmt) => { + lower_expression_to_temporary(builder, &expr_stmt.expression)?; + } + Statement::ReturnStatement(ret) => { + let loc = convert_opt_loc(&ret.base.loc); + let value = if let Some(arg) = &ret.argument { + lower_expression_to_temporary(builder, arg)? 
+ } else { + let undefined_value = InstructionValue::Primitive { + value: PrimitiveValue::Undefined, + loc: None, + }; + lower_value_to_temporary(builder, undefined_value)? + }; + let fallthrough = builder.reserve(BlockKind::Block); + builder.terminate_with_continuation( + Terminal::Return { + value, + return_variant: ReturnVariant::Explicit, + id: EvaluationOrder(0), + loc, + effects: None, + }, + fallthrough, + ); + } + Statement::ThrowStatement(throw) => { + let loc = convert_opt_loc(&throw.base.loc); + let value = lower_expression_to_temporary(builder, &throw.argument)?; + + // Check for throw handler (try/catch) + if let Some(_handler) = builder.resolve_throw_handler() { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerStatement) Support ThrowStatement inside of try/catch" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + } + + let fallthrough = builder.reserve(BlockKind::Block); + builder.terminate_with_continuation( + Terminal::Throw { + value, + id: EvaluationOrder(0), + loc, + }, + fallthrough, + ); + } + Statement::BlockStatement(block) => { + lower_block_statement(builder, block)?; + } + Statement::VariableDeclaration(var_decl) => { + use react_compiler_ast::{patterns::PatternLike, statements::VariableDeclarationKind}; + if matches!(var_decl.kind, VariableDeclarationKind::Var) { + builder.record_error(CompilerErrorDetail { + reason: "(BuildHIR::lowerStatement) Handle var kinds in VariableDeclaration" + .to_string(), + category: ErrorCategory::Todo, + loc: convert_opt_loc(&var_decl.base.loc), + description: None, + suggestions: None, + })?; + // Treat `var` as `let` so references to the variable don't + // break + } + let kind = match var_decl.kind { + VariableDeclarationKind::Let | VariableDeclarationKind::Var => InstructionKind::Let, + VariableDeclarationKind::Const | VariableDeclarationKind::Using => { + InstructionKind::Const + } + }; + for declarator in 
&var_decl.declarations { + let stmt_loc = convert_opt_loc(&var_decl.base.loc); + if let Some(init) = &declarator.init { + let value = lower_expression_to_temporary(builder, init)?; + let assign_style = match &declarator.id { + PatternLike::ObjectPattern(_) | PatternLike::ArrayPattern(_) => { + AssignmentStyle::Destructure + } + _ => AssignmentStyle::Assignment, + }; + lower_assignment(builder, stmt_loc, kind, &declarator.id, value, assign_style)?; + } else if let PatternLike::Identifier(id) = &declarator.id { + // No init: emit DeclareLocal or DeclareContext + let id_loc = convert_opt_loc(&id.base.loc); + let binding = builder.resolve_identifier( + &id.name, + id.base.start.unwrap_or(0), + id_loc.clone(), + )?; + match binding { + VariableBinding::Identifier { identifier, .. } => { + // Update the identifier's loc to the declaration site + // (it may have been first created at a reference site during hoisting) + builder.set_identifier_declaration_loc(identifier, &id_loc); + let place = Place { + identifier, + effect: Effect::Unknown, + reactive: false, + loc: id_loc.clone(), + }; + if builder.is_context_identifier(&id.name, id.base.start.unwrap_or(0)) { + if kind == InstructionKind::Const { + builder.record_error(CompilerErrorDetail { + reason: "Expect `const` declaration not to be reassigned" + .to_string(), + category: ErrorCategory::Syntax, + loc: id_loc.clone(), + description: None, + suggestions: None, + })?; + } + lower_value_to_temporary( + builder, + InstructionValue::DeclareContext { + lvalue: LValue { + kind: InstructionKind::Let, + place, + }, + loc: id_loc, + }, + )?; + } else { + let type_annotation = + extract_type_annotation_name(&id.type_annotation); + lower_value_to_temporary( + builder, + InstructionValue::DeclareLocal { + lvalue: LValue { kind, place }, + type_annotation, + loc: id_loc, + }, + )?; + } + } + _ => { + builder.record_error(CompilerErrorDetail { + reason: "Could not find binding for declaration".to_string(), + category: 
ErrorCategory::Invariant, + loc: id_loc, + description: None, + suggestions: None, + })?; + } + } + } else { + builder.record_error(CompilerErrorDetail { + reason: "Expected variable declaration to be an identifier if no \ + initializer was provided" + .to_string(), + category: ErrorCategory::Syntax, + loc: convert_opt_loc(&declarator.base.loc), + description: None, + suggestions: None, + })?; + } + } + } + Statement::BreakStatement(brk) => { + let loc = convert_opt_loc(&brk.base.loc); + let label_name = brk.label.as_ref().map(|l| l.name.as_str()); + let target = builder.lookup_break(label_name)?; + let fallthrough = builder.reserve(BlockKind::Block); + builder.terminate_with_continuation( + Terminal::Goto { + block: target, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc, + }, + fallthrough, + ); + } + Statement::ContinueStatement(cont) => { + let loc = convert_opt_loc(&cont.base.loc); + let label_name = cont.label.as_ref().map(|l| l.name.as_str()); + let target = builder.lookup_continue(label_name)?; + let fallthrough = builder.reserve(BlockKind::Block); + builder.terminate_with_continuation( + Terminal::Goto { + block: target, + variant: GotoVariant::Continue, + id: EvaluationOrder(0), + loc, + }, + fallthrough, + ); + } + Statement::IfStatement(if_stmt) => { + let loc = convert_opt_loc(&if_stmt.base.loc); + // Block for code following the if + let continuation_block = builder.reserve(BlockKind::Block); + let continuation_id = continuation_block.id; + + // Block for the consequent (if the test is truthy) + let consequent_loc = statement_loc(&if_stmt.consequent); + let consequent_block = builder.try_enter(BlockKind::Block, |builder, _block_id| { + lower_statement(builder, &if_stmt.consequent, None)?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: consequent_loc, + }) + })?; + + // Block for the alternate (if the test is not truthy) + let alternate_block = if let Some(alternate) = 
&if_stmt.alternate { + let alternate_loc = statement_loc(alternate); + builder.try_enter(BlockKind::Block, |builder, _block_id| { + lower_statement(builder, alternate, None)?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: alternate_loc, + }) + })? + } else { + // If there is no else clause, use the continuation directly + continuation_id + }; + + let test = lower_expression_to_temporary(builder, &if_stmt.test)?; + builder.terminate_with_continuation( + Terminal::If { + test, + consequent: consequent_block, + alternate: alternate_block, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc, + }, + continuation_block, + ); + } + Statement::ForStatement(for_stmt) => { + let loc = convert_opt_loc(&for_stmt.base.loc); + + let test_block = builder.reserve(BlockKind::Loop); + let test_block_id = test_block.id; + // Block for code following the loop + let continuation_block = builder.reserve(BlockKind::Block); + let continuation_id = continuation_block.id; + + // Init block: lower init expression/declaration, then goto test + let init_block = builder.try_enter(BlockKind::Loop, |builder, _block_id| { + let init_loc = match &for_stmt.init { + None => { + // No init expression (e.g., `for (; ...)`), add a placeholder + let placeholder = InstructionValue::Primitive { + value: PrimitiveValue::Undefined, + loc: loc.clone(), + }; + lower_value_to_temporary(builder, placeholder)?; + loc.clone() + } + Some(init) => match init.as_ref() { + react_compiler_ast::statements::ForInit::VariableDeclaration(var_decl) => { + let init_loc = convert_opt_loc(&var_decl.base.loc); + lower_statement( + builder, + &Statement::VariableDeclaration(var_decl.clone()), + None, + )?; + init_loc + } + react_compiler_ast::statements::ForInit::Expression(expr) => { + let init_loc = expression_loc(expr); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerStatement) Handle 
non-variable \ + initialization in ForStatement" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + lower_expression_to_temporary(builder, expr)?; + init_loc + } + }, + }; + Ok(Terminal::Goto { + block: test_block_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: init_loc, + }) + })?; + + // Update block (optional) + let update_block_id = if let Some(update) = &for_stmt.update { + let update_loc = expression_loc(update); + Some(builder.try_enter(BlockKind::Loop, |builder, _block_id| { + lower_expression_to_temporary(builder, update)?; + Ok(Terminal::Goto { + block: test_block_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: update_loc, + }) + })?) + } else { + None + }; + + // Loop body block + let continue_target = update_block_id.unwrap_or(test_block_id); + let body_loc = statement_loc(&for_stmt.body); + let body_block = builder.try_enter(BlockKind::Block, |builder, _block_id| { + builder.loop_scope( + label.map(|s| s.to_string()), + continue_target, + continuation_id, + |builder| { + lower_statement(builder, &for_stmt.body, None)?; + Ok(Terminal::Goto { + block: continue_target, + variant: GotoVariant::Continue, + id: EvaluationOrder(0), + loc: body_loc, + }) + }, + ) + })?; + + // Emit For terminal, then fill in the test block + builder.terminate_with_continuation( + Terminal::For { + init: init_block, + test: test_block_id, + update: update_block_id, + loop_block: body_block, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + test_block, + ); + + // Fill in the test block + if let Some(test_expr) = &for_stmt.test { + let test = lower_expression_to_temporary(builder, test_expr)?; + builder.terminate_with_continuation( + Terminal::Branch { + test, + consequent: body_block, + alternate: continuation_id, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + continuation_block, + ); + } else { + 
builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerStatement) Handle empty test in ForStatement" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + // Treat `for(;;)` as `while(true)` to keep the builder state consistent + let true_val = InstructionValue::Primitive { + value: PrimitiveValue::Boolean(true), + loc: loc.clone(), + }; + let test = lower_value_to_temporary(builder, true_val)?; + builder.terminate_with_continuation( + Terminal::Branch { + test, + consequent: body_block, + alternate: continuation_id, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc, + }, + continuation_block, + ); + } + } + Statement::WhileStatement(while_stmt) => { + let loc = convert_opt_loc(&while_stmt.base.loc); + // Block used to evaluate whether to (re)enter or exit the loop + let conditional_block = builder.reserve(BlockKind::Loop); + let conditional_id = conditional_block.id; + // Block for code following the loop + let continuation_block = builder.reserve(BlockKind::Block); + let continuation_id = continuation_block.id; + + // Loop body + let body_loc = statement_loc(&while_stmt.body); + let loop_block = builder.try_enter(BlockKind::Block, |builder, _block_id| { + builder.loop_scope( + label.map(|s| s.to_string()), + conditional_id, + continuation_id, + |builder| { + lower_statement(builder, &while_stmt.body, None)?; + Ok(Terminal::Goto { + block: conditional_id, + variant: GotoVariant::Continue, + id: EvaluationOrder(0), + loc: body_loc, + }) + }, + ) + })?; + + // Emit While terminal, jumping to the conditional block + builder.terminate_with_continuation( + Terminal::While { + test: conditional_id, + loop_block, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + conditional_block, + ); + + // Fill in the conditional block: lower test, branch + let test = lower_expression_to_temporary(builder, &while_stmt.test)?; + 
builder.terminate_with_continuation( + Terminal::Branch { + test, + consequent: loop_block, + alternate: continuation_id, + fallthrough: conditional_id, + id: EvaluationOrder(0), + loc, + }, + continuation_block, + ); + } + Statement::DoWhileStatement(do_while_stmt) => { + let loc = convert_opt_loc(&do_while_stmt.base.loc); + // Block used to evaluate whether to (re)enter or exit the loop + let conditional_block = builder.reserve(BlockKind::Loop); + let conditional_id = conditional_block.id; + // Block for code following the loop + let continuation_block = builder.reserve(BlockKind::Block); + let continuation_id = continuation_block.id; + + // Loop body, executed at least once unconditionally prior to exit + let body_loc = statement_loc(&do_while_stmt.body); + let loop_block = builder.try_enter(BlockKind::Block, |builder, _block_id| { + builder.loop_scope( + label.map(|s| s.to_string()), + conditional_id, + continuation_id, + |builder| { + lower_statement(builder, &do_while_stmt.body, None)?; + Ok(Terminal::Goto { + block: conditional_id, + variant: GotoVariant::Continue, + id: EvaluationOrder(0), + loc: body_loc, + }) + }, + ) + })?; + + // Jump to the conditional block + builder.terminate_with_continuation( + Terminal::DoWhile { + loop_block, + test: conditional_id, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + conditional_block, + ); + + // Fill in the conditional block: lower test, branch + let test = lower_expression_to_temporary(builder, &do_while_stmt.test)?; + builder.terminate_with_continuation( + Terminal::Branch { + test, + consequent: loop_block, + alternate: continuation_id, + fallthrough: conditional_id, + id: EvaluationOrder(0), + loc, + }, + continuation_block, + ); + } + Statement::ForInStatement(for_in) => { + let loc = convert_opt_loc(&for_in.base.loc); + let continuation_block = builder.reserve(BlockKind::Block); + let continuation_id = continuation_block.id; + let init_block = 
builder.reserve(BlockKind::Loop); + let init_block_id = init_block.id; + + let body_loc = statement_loc(&for_in.body); + let loop_block = builder.try_enter(BlockKind::Block, |builder, _block_id| { + builder.loop_scope( + label.map(|s| s.to_string()), + init_block_id, + continuation_id, + |builder| { + lower_statement(builder, &for_in.body, None)?; + Ok(Terminal::Goto { + block: init_block_id, + variant: GotoVariant::Continue, + id: EvaluationOrder(0), + loc: body_loc, + }) + }, + ) + })?; + + let value = lower_expression_to_temporary(builder, &for_in.right)?; + builder.terminate_with_continuation( + Terminal::ForIn { + init: init_block_id, + loop_block, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + init_block, + ); + + // Lower the init: NextPropertyOf + assignment + let left_loc = match for_in.left.as_ref() { + react_compiler_ast::statements::ForInOfLeft::VariableDeclaration(var_decl) => { + convert_opt_loc(&var_decl.base.loc).or(loc.clone()) + } + react_compiler_ast::statements::ForInOfLeft::Pattern(pat) => { + pattern_like_hir_loc(pat).or(loc.clone()) + } + }; + let next_property = lower_value_to_temporary( + builder, + InstructionValue::NextPropertyOf { + value, + loc: left_loc.clone(), + }, + )?; + + let assign_result = match for_in.left.as_ref() { + react_compiler_ast::statements::ForInOfLeft::VariableDeclaration(var_decl) => { + if var_decl.declarations.len() != 1 { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Invariant, + reason: format!( + "Expected only one declaration in ForInStatement init, got {}", + var_decl.declarations.len() + ), + description: None, + loc: left_loc.clone(), + suggestions: None, + })?; + } + if let Some(declarator) = var_decl.declarations.first() { + lower_assignment( + builder, + left_loc.clone(), + InstructionKind::Let, + &declarator.id, + next_property.clone(), + AssignmentStyle::Assignment, + )? 
+ } else { + None + } + } + react_compiler_ast::statements::ForInOfLeft::Pattern(pattern) => lower_assignment( + builder, + left_loc.clone(), + InstructionKind::Reassign, + pattern, + next_property.clone(), + AssignmentStyle::Assignment, + )?, + }; + // Use the assign result (StoreLocal temp) as the test, matching TS behavior + let test_value = assign_result.unwrap_or(next_property); + let test = lower_value_to_temporary( + builder, + InstructionValue::LoadLocal { + place: test_value, + loc: left_loc.clone(), + }, + )?; + builder.terminate_with_continuation( + Terminal::Branch { + test, + consequent: loop_block, + alternate: continuation_id, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + continuation_block, + ); + } + Statement::ForOfStatement(for_of) => { + let loc = convert_opt_loc(&for_of.base.loc); + let continuation_block = builder.reserve(BlockKind::Block); + let continuation_id = continuation_block.id; + let init_block = builder.reserve(BlockKind::Loop); + let init_block_id = init_block.id; + let test_block = builder.reserve(BlockKind::Loop); + let test_block_id = test_block.id; + + if for_of.is_await { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerStatement) Handle for-await loops".to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(()); + } + + let body_loc = statement_loc(&for_of.body); + let loop_block = builder.try_enter(BlockKind::Block, |builder, _block_id| { + builder.loop_scope( + label.map(|s| s.to_string()), + init_block_id, + continuation_id, + |builder| { + lower_statement(builder, &for_of.body, None)?; + Ok(Terminal::Goto { + block: init_block_id, + variant: GotoVariant::Continue, + id: EvaluationOrder(0), + loc: body_loc, + }) + }, + ) + })?; + + let value = lower_expression_to_temporary(builder, &for_of.right)?; + builder.terminate_with_continuation( + Terminal::ForOf { + init: init_block_id, + test: 
test_block_id, + loop_block, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + init_block, + ); + + // Init block: GetIterator, goto test + let iterator = lower_value_to_temporary( + builder, + InstructionValue::GetIterator { + collection: value.clone(), + loc: value.loc.clone(), + }, + )?; + builder.terminate_with_continuation( + Terminal::Goto { + block: test_block_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + test_block, + ); + + // Test block: IteratorNext, assign, branch + let left_loc = match for_of.left.as_ref() { + react_compiler_ast::statements::ForInOfLeft::VariableDeclaration(var_decl) => { + convert_opt_loc(&var_decl.base.loc).or(loc.clone()) + } + react_compiler_ast::statements::ForInOfLeft::Pattern(pat) => { + pattern_like_hir_loc(pat).or(loc.clone()) + } + }; + let advance_iterator = lower_value_to_temporary( + builder, + InstructionValue::IteratorNext { + iterator: iterator.clone(), + collection: value.clone(), + loc: left_loc.clone(), + }, + )?; + + let assign_result = match for_of.left.as_ref() { + react_compiler_ast::statements::ForInOfLeft::VariableDeclaration(var_decl) => { + if var_decl.declarations.len() != 1 { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Invariant, + reason: format!( + "Expected only one declaration in ForOfStatement init, got {}", + var_decl.declarations.len() + ), + description: None, + loc: left_loc.clone(), + suggestions: None, + })?; + } + if let Some(declarator) = var_decl.declarations.first() { + lower_assignment( + builder, + left_loc.clone(), + InstructionKind::Let, + &declarator.id, + advance_iterator.clone(), + AssignmentStyle::Assignment, + )? 
+ } else { + None + } + } + react_compiler_ast::statements::ForInOfLeft::Pattern(pattern) => lower_assignment( + builder, + left_loc.clone(), + InstructionKind::Reassign, + pattern, + advance_iterator.clone(), + AssignmentStyle::Assignment, + )?, + }; + // Use the assign result (StoreLocal temp) as the test, matching TS behavior + let test_value = assign_result.unwrap_or(advance_iterator); + let test = lower_value_to_temporary( + builder, + InstructionValue::LoadLocal { + place: test_value, + loc: left_loc.clone(), + }, + )?; + builder.terminate_with_continuation( + Terminal::Branch { + test, + consequent: loop_block, + alternate: continuation_id, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + continuation_block, + ); + } + Statement::SwitchStatement(switch_stmt) => { + let loc = convert_opt_loc(&switch_stmt.base.loc); + let continuation_block = builder.reserve(BlockKind::Block); + let continuation_id = continuation_block.id; + + // Iterate through cases in reverse order so that previous blocks can + // fallthrough to successors + let mut fallthrough = continuation_id; + let mut cases: Vec = Vec::new(); + let mut has_default = false; + + for ii in (0..switch_stmt.cases.len()).rev() { + let case = &switch_stmt.cases[ii]; + let case_loc = convert_opt_loc(&case.base.loc); + + if case.test.is_none() { + if has_default { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Syntax, + reason: "Expected at most one `default` branch in a switch statement" + .to_string(), + description: None, + loc: case_loc.clone(), + suggestions: None, + })?; + break; + } + has_default = true; + } + + let fallthrough_target = fallthrough; + let block = builder.try_enter(BlockKind::Block, |builder, _block_id| { + builder.switch_scope(label.map(|s| s.to_string()), continuation_id, |builder| { + for consequent in &case.consequent { + lower_statement(builder, consequent, None)?; + } + Ok(Terminal::Goto { + block: fallthrough_target, + 
variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: case_loc.clone(), + }) + }) + })?; + + let test = if let Some(test_expr) = &case.test { + Some(lower_reorderable_expression(builder, test_expr)?) + } else { + None + }; + + cases.push(Case { test, block }); + fallthrough = block; + } + + // Reverse back to original order + cases.reverse(); + + // If no default case, add one that jumps to continuation + if !has_default { + cases.push(Case { + test: None, + block: continuation_id, + }); + } + + let test = lower_expression_to_temporary(builder, &switch_stmt.discriminant)?; + builder.terminate_with_continuation( + Terminal::Switch { + test, + cases, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc, + }, + continuation_block, + ); + } + Statement::TryStatement(try_stmt) => { + let loc = convert_opt_loc(&try_stmt.base.loc); + let continuation_block = builder.reserve(BlockKind::Block); + let continuation_id = continuation_block.id; + + let handler_clause = match &try_stmt.handler { + Some(h) => h, + None => { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerStatement) Handle TryStatement without a catch \ + clause" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(()); + } + }; + + if try_stmt.finalizer.is_some() { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "(BuildHIR::lowerStatement) Handle TryStatement with a finalizer \ + ('finally') clause" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + } + + // Set up handler binding if catch has a param + let handler_binding_info: Option<(Place, react_compiler_ast::patterns::PatternLike)> = + if let Some(param) = &handler_clause.param { + // Check for destructuring in catch clause params. 
+ // Match TS behavior: Babel doesn't register destructured catch bindings + // in its scope, so resolveIdentifier fails and records an invariant error. + let is_destructuring = matches!( + param, + react_compiler_ast::patterns::PatternLike::ObjectPattern(_) + | react_compiler_ast::patterns::PatternLike::ArrayPattern(_) + ); + if is_destructuring { + // Iterate the pattern to find all identifier locs for error reporting + fn collect_identifier_locs( + pat: &react_compiler_ast::patterns::PatternLike, + locs: &mut Vec>, + ) { + match pat { + react_compiler_ast::patterns::PatternLike::Identifier(id) => { + locs.push(convert_opt_loc(&id.base.loc)); + } + react_compiler_ast::patterns::PatternLike::ObjectPattern(obj) => { + for prop in &obj.properties { + match prop { + react_compiler_ast::patterns::ObjectPatternProperty::ObjectProperty(p) => { + collect_identifier_locs(&p.value, locs); + } + react_compiler_ast::patterns::ObjectPatternProperty::RestElement(r) => { + collect_identifier_locs(&r.argument, locs); + } + } + } + } + react_compiler_ast::patterns::PatternLike::ArrayPattern(arr) => { + for elem in &arr.elements { + if let Some(e) = elem { + collect_identifier_locs(e, locs); + } + } + } + _ => {} + } + } + let mut id_locs = Vec::new(); + collect_identifier_locs(param, &mut id_locs); + for id_loc in id_locs { + builder.record_error(CompilerErrorDetail { + reason: "(BuildHIR::lowerAssignment) Could not find binding for \ + declaration." 
+ .to_string(), + category: ErrorCategory::Invariant, + loc: id_loc, + description: None, + suggestions: None, + })?; + } + None + } else { + let param_loc = convert_opt_loc(&pattern_like_loc(param)); + let id = builder.make_temporary(param_loc.clone()); + promote_temporary(builder, id); + let place = Place { + identifier: id, + effect: Effect::Unknown, + reactive: false, + loc: param_loc.clone(), + }; + // Emit DeclareLocal for the catch binding + lower_value_to_temporary( + builder, + InstructionValue::DeclareLocal { + lvalue: LValue { + kind: InstructionKind::Catch, + place: place.clone(), + }, + type_annotation: None, + loc: param_loc, + }, + )?; + Some((place, param.clone())) + } + } else { + None + }; + + // Create the handler (catch) block + let handler_binding_for_block = handler_binding_info.clone(); + let handler_loc = convert_opt_loc(&handler_clause.base.loc); + // Use the catch param's loc for the assignment, matching TS: + // handlerBinding.path.node.loc + let handler_param_loc = handler_clause + .param + .as_ref() + .and_then(|p| convert_opt_loc(&pattern_like_loc(p))); + let handler_block = builder.try_enter(BlockKind::Catch, |builder, _block_id| { + if let Some((ref place, ref pattern)) = handler_binding_for_block { + lower_assignment( + builder, + handler_param_loc.clone().or_else(|| handler_loc.clone()), + InstructionKind::Catch, + pattern, + place.clone(), + AssignmentStyle::Assignment, + )?; + } + // Lower the catch body using lower_block_statement to get hoisting support. + // Match TS behavior where `lowerStatement(builder, handlerPath.get('body'))` + // processes the catch body as a BlockStatement (with hoisting). + // Use the catch clause's scope since the catch body block shares + // the CatchClause scope in Babel (contains the catch param binding). + // Use the catch clause's scope (which contains the catch param binding). + // Fall back to the body block's own scope if the catch clause scope is missing. 
+ let catch_scope = handler_clause + .base + .start + .and_then(|start| builder.scope_info().node_to_scope.get(&start).copied()) + .or_else(|| { + handler_clause.body.base.start.and_then(|start| { + builder.scope_info().node_to_scope.get(&start).copied() + }) + }); + if let Some(scope_id) = catch_scope { + lower_block_statement_with_scope(builder, &handler_clause.body, scope_id)?; + } else { + // No scope found — this shouldn't happen with well-formed Babel output. + // Fall back to plain block lowering (no hoisting) rather than panicking, + // since this is a non-critical degradation. + lower_block_statement(builder, &handler_clause.body)?; + } + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: handler_loc.clone(), + }) + })?; + + // Create the try block + let try_body_loc = convert_opt_loc(&try_stmt.block.base.loc); + let try_block = builder.try_enter(BlockKind::Block, |builder, _block_id| { + builder.try_enter_try_catch(handler_block, |builder| { + for stmt in &try_stmt.block.body { + lower_statement(builder, stmt, None)?; + } + Ok(()) + })?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Try, + id: EvaluationOrder(0), + loc: try_body_loc.clone(), + }) + })?; + + builder.terminate_with_continuation( + Terminal::Try { + block: try_block, + handler_binding: handler_binding_info.map(|(place, _)| place), + handler: handler_block, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc, + }, + continuation_block, + ); + } + Statement::LabeledStatement(labeled_stmt) => { + let label_name = &labeled_stmt.label.name; + let loc = convert_opt_loc(&labeled_stmt.base.loc); + + // Check if the body is a loop statement - if so, delegate with label + match labeled_stmt.body.as_ref() { + Statement::ForStatement(_) + | Statement::WhileStatement(_) + | Statement::DoWhileStatement(_) + | Statement::ForInStatement(_) + | Statement::ForOfStatement(_) => { + // Labeled loops are special 
because of continue, push the label down + lower_statement(builder, &labeled_stmt.body, Some(label_name))?; + } + _ => { + // All other statements create a continuation block to allow `break` + let continuation_block = builder.reserve(BlockKind::Block); + let continuation_id = continuation_block.id; + let body_loc = statement_loc(&labeled_stmt.body); + + let block = builder.try_enter(BlockKind::Block, |builder, _block_id| { + builder.label_scope(label_name.clone(), continuation_id, |builder| { + lower_statement(builder, &labeled_stmt.body, None)?; + Ok(()) + })?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: body_loc, + }) + })?; + + builder.terminate_with_continuation( + Terminal::Label { + block, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc, + }, + continuation_block, + ); + } + } + } + Statement::WithStatement(with_stmt) => { + let loc = convert_opt_loc(&with_stmt.base.loc); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::UnsupportedSyntax, + reason: "JavaScript 'with' syntax is not supported".to_string(), + description: Some( + "'with' syntax is considered deprecated and removed from JavaScript \ + standards, consider alternatives" + .to_string(), + ), + loc: loc.clone(), + suggestions: None, + })?; + lower_value_to_temporary( + builder, + InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }, + )?; + } + Statement::FunctionDeclaration(func_decl) => { + lower_function_declaration(builder, func_decl)?; + } + Statement::ClassDeclaration(cls) => { + let loc = convert_opt_loc(&cls.base.loc); + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::UnsupportedSyntax, + reason: "Inline `class` declarations are not supported".to_string(), + description: Some( + "Move class declarations outside of components/hooks".to_string(), + ), + loc: loc.clone(), + suggestions: None, + })?; + lower_value_to_temporary( + 
builder, + InstructionValue::UnsupportedNode { + node_type: Some("ClassDeclaration".to_string()), + original_node: None, + loc, + }, + )?; + } + Statement::ImportDeclaration(_) + | Statement::ExportNamedDeclaration(_) + | Statement::ExportDefaultDeclaration(_) + | Statement::ExportAllDeclaration(_) => { + let loc = match stmt { + Statement::ImportDeclaration(s) => convert_opt_loc(&s.base.loc), + Statement::ExportNamedDeclaration(s) => convert_opt_loc(&s.base.loc), + Statement::ExportDefaultDeclaration(s) => convert_opt_loc(&s.base.loc), + Statement::ExportAllDeclaration(s) => convert_opt_loc(&s.base.loc), + _ => unreachable!(), + }; + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Syntax, + reason: "JavaScript `import` and `export` statements may only appear at the top \ + level of a module" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + lower_value_to_temporary( + builder, + InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }, + )?; + } + // TypeScript/Flow declarations are type-only, skip them + Statement::TSEnumDeclaration(e) => { + let loc = convert_opt_loc(&e.base.loc); + let original_node = serde_json::to_value( + &react_compiler_ast::statements::Statement::TSEnumDeclaration(e.clone()), + ) + .ok(); + lower_value_to_temporary( + builder, + InstructionValue::UnsupportedNode { + node_type: Some("TSEnumDeclaration".to_string()), + original_node, + loc, + }, + )?; + } + Statement::EnumDeclaration(e) => { + let loc = convert_opt_loc(&e.base.loc); + let original_node = serde_json::to_value( + &react_compiler_ast::statements::Statement::EnumDeclaration(e.clone()), + ) + .ok(); + lower_value_to_temporary( + builder, + InstructionValue::UnsupportedNode { + node_type: Some("EnumDeclaration".to_string()), + original_node, + loc, + }, + )?; + } + // TypeScript/Flow type declarations are type-only, skip them + Statement::TSTypeAliasDeclaration(_) + | 
Statement::TSInterfaceDeclaration(_) + | Statement::TSModuleDeclaration(_) + | Statement::TSDeclareFunction(_) + | Statement::TypeAlias(_) + | Statement::OpaqueType(_) + | Statement::InterfaceDeclaration(_) + | Statement::DeclareVariable(_) + | Statement::DeclareFunction(_) + | Statement::DeclareClass(_) + | Statement::DeclareModule(_) + | Statement::DeclareModuleExports(_) + | Statement::DeclareExportDeclaration(_) + | Statement::DeclareExportAllDeclaration(_) + | Statement::DeclareInterface(_) + | Statement::DeclareTypeAlias(_) + | Statement::DeclareOpaqueType(_) => {} + } + Ok(()) +} + +// ============================================================================= +// lower() entry point +// ============================================================================= + +/// Body of the function being lowered: a block statement, or the bare +/// expression of an arrow function with an expression body. +enum FunctionBody<'a> { + Block(&'a react_compiler_ast::statements::BlockStatement), + Expression(&'a react_compiler_ast::expressions::Expression), +} + +/// Main entry point: lower a function AST node into HIR. +/// +/// Receives a `FunctionNode` (discovered by the entrypoint) and lowers it to +/// HIR. The `_id` parameter may carry a name inferred from the surrounding code +/// (e.g. the variable declarator in `const Foo = () => {}`); it is not read here, +/// and only the function node's own AST id is forwarded to `lower_inner`. +pub fn lower( + func: &FunctionNode<'_>, + _id: Option<&str>, + scope_info: &ScopeInfo, + env: &mut Environment, +) -> Result { + // Extract params, body, generator, is_async, loc, start position, and the AST + // function's own id. Note: the `_id` param may include inferred names (e.g., from + // `const Foo = () => {}`), but the HIR function's `id` field should only + // include the function's own AST id (FunctionDeclaration.id or + // FunctionExpression.id, NOT arrow functions). 
+ let (params, body, generator, is_async, loc, start, ast_id) = match func { + FunctionNode::FunctionDeclaration(decl) => ( + &decl.params[..], + FunctionBody::Block(&decl.body), + decl.generator, + decl.is_async, + convert_opt_loc(&decl.base.loc), + decl.base.start.unwrap_or(0), + decl.id.as_ref().map(|id| id.name.as_str()), + ), + FunctionNode::FunctionExpression(expr) => ( + &expr.params[..], + FunctionBody::Block(&expr.body), + expr.generator, + expr.is_async, + convert_opt_loc(&expr.base.loc), + expr.base.start.unwrap_or(0), + expr.id.as_ref().map(|id| id.name.as_str()), + ), + FunctionNode::ArrowFunctionExpression(arrow) => { + let body = match arrow.body.as_ref() { + react_compiler_ast::expressions::ArrowFunctionBody::BlockStatement(block) => { + FunctionBody::Block(block) + } + react_compiler_ast::expressions::ArrowFunctionBody::Expression(expr) => { + FunctionBody::Expression(expr) + } + }; + ( + &arrow.params[..], + body, + arrow.generator, + arrow.is_async, + convert_opt_loc(&arrow.base.loc), + arrow.base.start.unwrap_or(0), + None, // Arrow functions never have an AST id + ) + } + }; + + let scope_id = scope_info + .node_to_scope + .get(&start) + .copied() + .unwrap_or(scope_info.program_scope); + + // Pre-compute context identifiers: variables captured across function + // boundaries + let context_identifiers = find_context_identifiers(func, scope_info); + + // Build identifier location index from the AST (replaces serialized + // referenceLocs/jsxReferencePositions) + let identifier_locs = build_identifier_loc_index(func, scope_info); + + // For top-level functions, context is empty (no captured refs) + let context_map: IndexMap> = + IndexMap::new(); + + let (hir_func, _used_names, _child_bindings) = lower_inner( + params, + body, + ast_id, + generator, + is_async, + loc, + scope_info, + env, + None, // no pre-existing bindings for top-level + None, // no pre-existing used_names for top-level + context_map, + scope_id, + scope_id, // component_scope = 
function_scope for top-level + &context_identifiers, + true, // is_top_level + &identifier_locs, + )?; + + Ok(hir_func) +} + +// ============================================================================= +// Stubs for future milestones +// ============================================================================= + +/// Result of resolving an identifier for assignment. +enum IdentifierForAssignment { + /// A local place (identifier binding) + Place(Place), + /// A non-local target identified by name (global bindings, plus any other non-identifier binding being reassigned) + Global { name: String }, +} + +/// Resolve an identifier for use as an assignment target. +/// Returns `Ok(None)` (error recorded) if a declaration has no local identifier binding or a `const` binding is reassigned. +fn lower_identifier_for_assignment( + builder: &mut HirBuilder, + loc: Option, + ident_loc: Option, + kind: InstructionKind, + name: &str, + start: u32, +) -> Result, CompilerError> { + let binding = builder.resolve_identifier(name, start, ident_loc.clone())?; + match binding { + VariableBinding::Identifier { + identifier, + binding_kind, + .. 
+ } => { + // Set the identifier's loc from the declaration site (not for reassignments, + // which should keep the original declaration loc) + if kind != InstructionKind::Reassign { + builder.set_identifier_declaration_loc(identifier, &ident_loc); + } + if binding_kind == BindingKind::Const && kind == InstructionKind::Reassign { + builder.record_error(CompilerErrorDetail { + reason: "Cannot reassign a `const` variable".to_string(), + category: ErrorCategory::Syntax, + loc: loc.clone(), + description: Some(format!("`{}` is declared as const", name)), + suggestions: None, + })?; + return Ok(None); + } + Ok(Some(IdentifierForAssignment::Place(Place { + identifier, + effect: Effect::Unknown, + reactive: false, + loc, + }))) + } + VariableBinding::Global { name: gname } => { + if kind == InstructionKind::Reassign { + Ok(Some(IdentifierForAssignment::Global { name: gname })) + } else { + builder.record_error(CompilerErrorDetail { + reason: "Could not find binding for declaration".to_string(), + category: ErrorCategory::Invariant, + loc, + description: None, + suggestions: None, + })?; + Ok(None) + } + } + _ => { + // Any other binding kind (e.g. imports): a reassignment is returned by name as a `Global` target; a declaration is an invariant error + if kind == InstructionKind::Reassign { + Ok(Some(IdentifierForAssignment::Global { + name: name.to_string(), + })) + } else { + builder.record_error(CompilerErrorDetail { + reason: "Could not find binding for declaration".to_string(), + category: ErrorCategory::Invariant, + loc, + description: None, + suggestions: None, + })?; + Ok(None) + } + } + } +} + +fn lower_assignment( + builder: &mut HirBuilder, + loc: Option, + kind: InstructionKind, + target: &react_compiler_ast::patterns::PatternLike, + value: Place, + assignment_style: AssignmentStyle, +) -> Result, CompilerError> { + use react_compiler_ast::patterns::PatternLike; + + match target { + PatternLike::Identifier(id) => { + let id_loc = convert_opt_loc(&id.base.loc); + let result = lower_identifier_for_assignment( + builder, + loc.clone(), + id_loc, + kind, + 
&id.name, + id.base.start.unwrap_or(0), + )?; + match result { + None => { + // Error already recorded + return Ok(None); + } + Some(IdentifierForAssignment::Global { name }) => { + let temp = lower_value_to_temporary( + builder, + InstructionValue::StoreGlobal { name, value, loc }, + )?; + return Ok(Some(temp)); + } + Some(IdentifierForAssignment::Place(place)) => { + let start = id.base.start.unwrap_or(0); + if builder.is_context_identifier(&id.name, start) { + // Check if the binding is hoisted before flagging const reassignment + let is_hoisted = builder + .scope_info() + .resolve_reference(start) + .map(|b| builder.environment().is_hoisted_identifier(b.id.0)) + .unwrap_or(false); + if kind == InstructionKind::Const && !is_hoisted { + builder.record_error(CompilerErrorDetail { + reason: "Expected `const` declaration not to be reassigned" + .to_string(), + category: ErrorCategory::Syntax, + loc: loc.clone(), + suggestions: None, + description: None, + })?; + } + if kind != InstructionKind::Const + && kind != InstructionKind::Reassign + && kind != InstructionKind::Let + && kind != InstructionKind::Function + { + builder.record_error(CompilerErrorDetail { + reason: "Unexpected context variable kind".to_string(), + category: ErrorCategory::Syntax, + loc: loc.clone(), + suggestions: None, + description: None, + })?; + let temp = lower_value_to_temporary( + builder, + InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }, + )?; + return Ok(Some(temp)); + } + let temp = lower_value_to_temporary( + builder, + InstructionValue::StoreContext { + lvalue: LValue { place, kind }, + value, + loc, + }, + )?; + return Ok(Some(temp)); + } else { + let type_annotation = extract_type_annotation_name(&id.type_annotation); + let temp = lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { place, kind }, + value, + type_annotation, + loc, + }, + )?; + return Ok(Some(temp)); + } + } + } + } + + 
PatternLike::MemberExpression(member) => { + // MemberExpression may only appear in an assignment expression (Reassign) + if kind != InstructionKind::Reassign { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Invariant, + reason: "MemberExpression may only appear in an assignment expression" + .to_string(), + description: None, + loc: loc.clone(), + suggestions: None, + })?; + return Ok(None); + } + let object = lower_expression_to_temporary(builder, &member.object)?; + let temp = if !member.computed + || matches!( + &*member.property, + react_compiler_ast::expressions::Expression::NumericLiteral(_) + ) { + match &*member.property { + react_compiler_ast::expressions::Expression::Identifier(prop_id) => { + lower_value_to_temporary( + builder, + InstructionValue::PropertyStore { + object, + property: PropertyLiteral::String(prop_id.name.clone()), + value, + loc, + }, + )? + } + react_compiler_ast::expressions::Expression::NumericLiteral(num) => { + lower_value_to_temporary( + builder, + InstructionValue::PropertyStore { + object, + property: PropertyLiteral::Number(FloatValue::new(num.value)), + value, + loc, + }, + )? + } + _ => { + builder.record_error(CompilerErrorDetail { + reason: format!( + "(BuildHIR::lowerAssignment) Handle {} properties in \ + MemberExpression", + expression_type_name(&member.property) + ), + category: ErrorCategory::Todo, + loc: expression_loc(&member.property), + description: None, + suggestions: None, + })?; + lower_value_to_temporary( + builder, + InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }, + )? 
+ } + } + } else { + if matches!( + &*member.property, + react_compiler_ast::expressions::Expression::PrivateName(_) + ) { + builder.record_error(CompilerErrorDetail { + reason: "(BuildHIR::lowerAssignment) Expected private name to appear as a \ + non-computed property" + .to_string(), + category: ErrorCategory::Todo, + loc: expression_loc(&member.property), + description: None, + suggestions: None, + })?; + lower_value_to_temporary( + builder, + InstructionValue::UnsupportedNode { + node_type: None, + original_node: None, + loc, + }, + )? + } else { + let property_place = lower_expression_to_temporary(builder, &member.property)?; + lower_value_to_temporary( + builder, + InstructionValue::ComputedStore { + object, + property: property_place, + value, + loc, + }, + )? + } + }; + Ok(Some(temp)) + } + + PatternLike::ArrayPattern(pattern) => { + let mut items: Vec = Vec::new(); + let mut followups: Vec<(Place, &PatternLike)> = Vec::new(); + + // Compute forceTemporaries: when kind is Reassign and any element is + // non-identifier, a context variable, or a non-local binding + let force_temporaries = if kind == InstructionKind::Reassign { + let mut found = false; + for elem in &pattern.elements { + match elem { + Some(PatternLike::Identifier(id)) => { + let start = id.base.start.unwrap_or(0); + if builder.is_context_identifier(&id.name, start) { + found = true; + break; + } + let ident_loc = convert_opt_loc(&id.base.loc); + match builder.resolve_identifier(&id.name, start, ident_loc)? { + VariableBinding::Identifier { .. 
} => {} + _ => { + found = true; + break; + } + } + } + _ => { + // Non-identifier element (including None/holes) or RestElement + // Only non-None non-identifier elements trigger forceTemporaries + if elem.is_some() && !matches!(elem, Some(PatternLike::Identifier(_))) { + found = true; + break; + } + } + } + } + found + } else { + false + }; + + for element in &pattern.elements { + match element { + None => { + items.push(ArrayPatternElement::Hole); + } + Some(PatternLike::RestElement(rest)) => { + match &*rest.argument { + PatternLike::Identifier(id) => { + let start = id.base.start.unwrap_or(0); + let is_context = builder.is_context_identifier(&id.name, start); + let can_use_direct = !force_temporaries + && (matches!(assignment_style, AssignmentStyle::Assignment) + || !is_context); + if can_use_direct { + match lower_identifier_for_assignment( + builder, + convert_opt_loc(&rest.base.loc), + convert_opt_loc(&id.base.loc), + kind, + &id.name, + start, + )? { + Some(IdentifierForAssignment::Place(place)) => { + items.push(ArrayPatternElement::Spread( + SpreadPattern { place }, + )); + } + Some(IdentifierForAssignment::Global { .. 
}) => { + let temp = build_temporary_place( + builder, + convert_opt_loc(&rest.base.loc), + ); + promote_temporary(builder, temp.identifier); + items.push(ArrayPatternElement::Spread( + SpreadPattern { + place: temp.clone(), + }, + )); + followups.push((temp, &rest.argument)); + } + None => { + // Error already recorded + } + } + } else { + let temp = build_temporary_place( + builder, + convert_opt_loc(&rest.base.loc), + ); + promote_temporary(builder, temp.identifier); + items.push(ArrayPatternElement::Spread(SpreadPattern { + place: temp.clone(), + })); + followups.push((temp, &rest.argument)); + } + } + _ => { + let temp = + build_temporary_place(builder, convert_opt_loc(&rest.base.loc)); + promote_temporary(builder, temp.identifier); + items.push(ArrayPatternElement::Spread(SpreadPattern { + place: temp.clone(), + })); + followups.push((temp, &rest.argument)); + } + } + } + Some(PatternLike::Identifier(id)) => { + let start = id.base.start.unwrap_or(0); + let is_context = builder.is_context_identifier(&id.name, start); + let can_use_direct = !force_temporaries + && (matches!(assignment_style, AssignmentStyle::Assignment) + || !is_context); + if can_use_direct { + match lower_identifier_for_assignment( + builder, + convert_opt_loc(&id.base.loc), + convert_opt_loc(&id.base.loc), + kind, + &id.name, + start, + )? { + Some(IdentifierForAssignment::Place(place)) => { + items.push(ArrayPatternElement::Place(place)); + } + Some(IdentifierForAssignment::Global { .. 
}) => { + let temp = build_temporary_place( + builder, + convert_opt_loc(&id.base.loc), + ); + promote_temporary(builder, temp.identifier); + items.push(ArrayPatternElement::Place(temp.clone())); + followups.push((temp, element.as_ref().unwrap())); + } + None => { + items.push(ArrayPatternElement::Hole); + } + } + } else { + // Context variable or force_temporaries: use promoted temporary + let temp = + build_temporary_place(builder, convert_opt_loc(&id.base.loc)); + promote_temporary(builder, temp.identifier); + items.push(ArrayPatternElement::Place(temp.clone())); + followups.push((temp, element.as_ref().unwrap())); + } + } + Some(other) => { + // Nested pattern: use temporary + followup + let elem_loc = pattern_like_hir_loc(other); + let temp = build_temporary_place(builder, elem_loc); + promote_temporary(builder, temp.identifier); + items.push(ArrayPatternElement::Place(temp.clone())); + followups.push((temp, other)); + } + } + } + + let temporary = lower_value_to_temporary( + builder, + InstructionValue::Destructure { + lvalue: LValuePattern { + pattern: Pattern::Array(ArrayPattern { + items, + loc: convert_opt_loc(&pattern.base.loc), + }), + kind, + }, + value: value.clone(), + loc: loc.clone(), + }, + )?; + + for (place, path) in followups { + let followup_loc = pattern_like_hir_loc(path).or(loc.clone()); + lower_assignment(builder, followup_loc, kind, path, place, assignment_style)?; + } + Ok(Some(temporary)) + } + + PatternLike::ObjectPattern(pattern) => { + let mut properties: Vec = Vec::new(); + let mut followups: Vec<(Place, &PatternLike)> = Vec::new(); + + // Compute forceTemporaries for ObjectPattern + let force_temporaries = if kind == InstructionKind::Reassign { + use react_compiler_ast::patterns::ObjectPatternProperty; + let mut found = false; + for prop in &pattern.properties { + match prop { + ObjectPatternProperty::RestElement(_) => { + found = true; + break; + } + ObjectPatternProperty::ObjectProperty(obj_prop) => match &*obj_prop.value { + 
PatternLike::Identifier(id) => { + let start = id.base.start.unwrap_or(0); + let ident_loc = convert_opt_loc(&id.base.loc); + match builder.resolve_identifier(&id.name, start, ident_loc)? { + VariableBinding::Identifier { .. } => {} + _ => { + found = true; + break; + } + } + } + _ => { + found = true; + break; + } + }, + } + } + found + } else { + false + }; + + for prop in &pattern.properties { + match prop { + react_compiler_ast::patterns::ObjectPatternProperty::RestElement(rest) => { + match &*rest.argument { + PatternLike::Identifier(id) => { + let start = id.base.start.unwrap_or(0); + let is_context = builder.is_context_identifier(&id.name, start); + let can_use_direct = !force_temporaries + && (matches!(assignment_style, AssignmentStyle::Assignment) + || !is_context); + if can_use_direct { + match lower_identifier_for_assignment( + builder, + convert_opt_loc(&rest.base.loc), + convert_opt_loc(&id.base.loc), + kind, + &id.name, + start, + )? { + Some(IdentifierForAssignment::Place(place)) => { + properties.push(ObjectPropertyOrSpread::Spread( + SpreadPattern { place }, + )); + } + Some(IdentifierForAssignment::Global { .. 
}) => { + builder.record_error(CompilerErrorDetail { + reason: "Expected reassignment of globals to \ + enable forceTemporaries" + .to_string(), + category: ErrorCategory::Todo, + loc: convert_opt_loc(&rest.base.loc), + description: None, + suggestions: None, + })?; + } + None => {} + } + } else { + let temp = build_temporary_place( + builder, + convert_opt_loc(&rest.base.loc), + ); + promote_temporary(builder, temp.identifier); + properties.push(ObjectPropertyOrSpread::Spread( + SpreadPattern { + place: temp.clone(), + }, + )); + followups.push((temp, &rest.argument)); + } + } + _ => { + builder.record_error(CompilerErrorDetail { + reason: format!( + "(BuildHIR::lowerAssignment) Handle {} rest element in \ + ObjectPattern", + match &*rest.argument { + PatternLike::ObjectPattern(_) => "ObjectPattern", + PatternLike::ArrayPattern(_) => "ArrayPattern", + PatternLike::AssignmentPattern(_) => + "AssignmentPattern", + PatternLike::MemberExpression(_) => "MemberExpression", + _ => "unknown", + } + ), + category: ErrorCategory::Todo, + loc: convert_opt_loc(&rest.base.loc), + description: None, + suggestions: None, + })?; + } + } + } + react_compiler_ast::patterns::ObjectPatternProperty::ObjectProperty( + obj_prop, + ) => { + if obj_prop.computed { + builder.record_error(CompilerErrorDetail { + reason: "(BuildHIR::lowerAssignment) Handle computed properties \ + in ObjectPattern" + .to_string(), + category: ErrorCategory::Todo, + loc: convert_opt_loc(&obj_prop.base.loc), + description: None, + suggestions: None, + })?; + continue; + } + + let key = match lower_object_property_key(builder, &obj_prop.key, false)? 
{ + Some(k) => k, + None => continue, + }; + + match &*obj_prop.value { + PatternLike::Identifier(id) => { + let start = id.base.start.unwrap_or(0); + let is_context = builder.is_context_identifier(&id.name, start); + let can_use_direct = !force_temporaries + && (matches!(assignment_style, AssignmentStyle::Assignment) + || !is_context); + if can_use_direct { + match lower_identifier_for_assignment( + builder, + convert_opt_loc(&id.base.loc), + convert_opt_loc(&id.base.loc), + kind, + &id.name, + start, + )? { + Some(IdentifierForAssignment::Place(place)) => { + properties.push(ObjectPropertyOrSpread::Property( + ObjectProperty { + key, + property_type: ObjectPropertyType::Property, + place, + }, + )); + } + Some(IdentifierForAssignment::Global { .. }) => { + builder.record_error(CompilerErrorDetail { + reason: "Expected reassignment of globals to \ + enable forceTemporaries" + .to_string(), + category: ErrorCategory::Todo, + loc: convert_opt_loc(&id.base.loc), + description: None, + suggestions: None, + })?; + } + None => { + continue; + } + } + } else { + // Context variable or force_temporaries: use promoted temporary + let temp = build_temporary_place( + builder, + convert_opt_loc(&id.base.loc), + ); + promote_temporary(builder, temp.identifier); + properties.push(ObjectPropertyOrSpread::Property( + ObjectProperty { + key, + property_type: ObjectPropertyType::Property, + place: temp.clone(), + }, + )); + followups.push((temp, &*obj_prop.value)); + } + } + other => { + // Nested pattern: use temporary + followup + let elem_loc = pattern_like_hir_loc(other); + let temp = build_temporary_place(builder, elem_loc); + promote_temporary(builder, temp.identifier); + properties.push(ObjectPropertyOrSpread::Property(ObjectProperty { + key, + property_type: ObjectPropertyType::Property, + place: temp.clone(), + })); + followups.push((temp, other)); + } + } + } + } + } + + let temporary = lower_value_to_temporary( + builder, + InstructionValue::Destructure { + lvalue: 
LValuePattern { + pattern: Pattern::Object(ObjectPattern { + properties, + loc: convert_opt_loc(&pattern.base.loc), + }), + kind, + }, + value: value.clone(), + loc: loc.clone(), + }, + )?; + + for (place, path) in followups { + let followup_loc = pattern_like_hir_loc(path).or(loc.clone()); + lower_assignment(builder, followup_loc, kind, path, place, assignment_style)?; + } + Ok(Some(temporary)) + } + + PatternLike::AssignmentPattern(pattern) => { + // Default value: if value === undefined, use default, else use value + let pat_loc = convert_opt_loc(&pattern.base.loc); + + let temp = build_temporary_place(builder, pat_loc.clone()); + + let test_block = builder.reserve(BlockKind::Value); + let continuation_block = builder.reserve(builder.current_block_kind()); + + // Consequent: use default value + let consequent = builder.try_enter(BlockKind::Value, |builder, _| { + let default_value = lower_reorderable_expression(builder, &pattern.right)?; + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + place: temp.clone(), + kind: InstructionKind::Const, + }, + value: default_value, + type_annotation: None, + loc: pat_loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_block.id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: pat_loc.clone(), + }) + }); + + // Alternate: use the original value + let alternate = builder.try_enter(BlockKind::Value, |builder, _| { + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + place: temp.clone(), + kind: InstructionKind::Const, + }, + value: value.clone(), + type_annotation: None, + loc: pat_loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_block.id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: pat_loc.clone(), + }) + }); + + // Ternary terminal + builder.terminate_with_continuation( + Terminal::Ternary { + test: test_block.id, + fallthrough: continuation_block.id, + id: EvaluationOrder(0), 
+ loc: pat_loc.clone(), + }, + test_block, + ); + + // In test block: check if value === undefined + let undef = lower_value_to_temporary( + builder, + InstructionValue::Primitive { + value: PrimitiveValue::Undefined, + loc: pat_loc.clone(), + }, + )?; + let test = lower_value_to_temporary( + builder, + InstructionValue::BinaryExpression { + left: value, + operator: BinaryOperator::StrictEqual, + right: undef, + loc: pat_loc.clone(), + }, + )?; + builder.terminate_with_continuation( + Terminal::Branch { + test, + consequent: consequent?, + alternate: alternate?, + fallthrough: continuation_block.id, + id: EvaluationOrder(0), + loc: pat_loc.clone(), + }, + continuation_block, + ); + + // Recursively assign the resolved value to the left pattern + Ok(lower_assignment( + builder, + pat_loc, + kind, + &pattern.left, + temp, + assignment_style, + )?) + } + + PatternLike::RestElement(rest) => { + // Delegate to the argument pattern + Ok(lower_assignment( + builder, + loc, + kind, + &rest.argument, + value, + assignment_style, + )?) + } + } +} + +/// Helper to extract HIR loc from a PatternLike (converts AST loc) +fn pattern_like_hir_loc(pat: &react_compiler_ast::patterns::PatternLike) -> Option { + convert_opt_loc(&pattern_like_loc(pat)) +} + +fn lower_optional_member_expression( + builder: &mut HirBuilder, + expr: &react_compiler_ast::expressions::OptionalMemberExpression, +) -> Result { + let place = lower_optional_member_expression_impl(builder, expr, None)?.1; + Ok(InstructionValue::LoadLocal { + loc: place.loc.clone(), + place, + }) +} + +/// Returns (object, value_place) pair. +/// The `value_place` is stored into a temporary; we also return it as an +/// InstructionValue via LoadLocal for the top-level call. 
+fn lower_optional_member_expression_impl( + builder: &mut HirBuilder, + expr: &react_compiler_ast::expressions::OptionalMemberExpression, + parent_alternate: Option, +) -> Result<(Place, Place), CompilerError> { + use react_compiler_ast::expressions::Expression; + let optional = expr.optional; + let loc = convert_opt_loc(&expr.base.loc); + let place = build_temporary_place(builder, loc.clone()); + let continuation_block = builder.reserve(builder.current_block_kind()); + let continuation_id = continuation_block.id; + let consequent = builder.reserve(BlockKind::Value); + + // Block to evaluate if the callee is null/undefined — sets result to undefined. + // Only create an alternate when first entering an optional subtree. + let alternate = if let Some(parent_alt) = parent_alternate { + Ok(parent_alt) + } else { + builder.try_enter(BlockKind::Value, |builder, _block_id| { + let temp = lower_value_to_temporary( + builder, + InstructionValue::Primitive { + value: PrimitiveValue::Undefined, + loc: loc.clone(), + }, + )?; + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Const, + place: place.clone(), + }, + value: temp, + type_annotation: None, + loc: loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: loc.clone(), + }) + }) + }?; + + let mut object: Option = None; + let test_block = builder.try_enter(BlockKind::Value, |builder, _block_id| { + match expr.object.as_ref() { + Expression::OptionalMemberExpression(opt_member) => { + let (_obj, value) = + lower_optional_member_expression_impl(builder, opt_member, Some(alternate))?; + object = Some(value); + } + Expression::OptionalCallExpression(opt_call) => { + let value = + lower_optional_call_expression_impl(builder, opt_call, Some(alternate))?; + let value_place = lower_value_to_temporary(builder, value)?; + object = Some(value_place); + } + other => { + object = 
Some(lower_expression_to_temporary(builder, other)?); + } + } + let test_place = object.as_ref().unwrap().clone(); + Ok(Terminal::Branch { + test: test_place, + consequent: consequent.id, + alternate, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }) + }); + + let obj = object.unwrap(); + + // Block to evaluate if the callee is non-null/undefined + builder.try_enter_reserved(consequent, |builder| { + let lowered = lower_member_expression_with_object(builder, expr, obj.clone())?; + let temp = lower_value_to_temporary(builder, lowered.value)?; + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Const, + place: place.clone(), + }, + value: temp, + type_annotation: None, + loc: loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: loc.clone(), + }) + })?; + + builder.terminate_with_continuation( + Terminal::Optional { + optional, + test: test_block?, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + continuation_block, + ); + + Ok((obj, place)) +} + +fn lower_optional_call_expression( + builder: &mut HirBuilder, + expr: &react_compiler_ast::expressions::OptionalCallExpression, +) -> Result { + Ok(lower_optional_call_expression_impl(builder, expr, None)?) 
+} + +fn lower_optional_call_expression_impl( + builder: &mut HirBuilder, + expr: &react_compiler_ast::expressions::OptionalCallExpression, + parent_alternate: Option, +) -> Result { + use react_compiler_ast::expressions::Expression; + let optional = expr.optional; + let loc = convert_opt_loc(&expr.base.loc); + let place = build_temporary_place(builder, loc.clone()); + let continuation_block = builder.reserve(builder.current_block_kind()); + let continuation_id = continuation_block.id; + let consequent = builder.reserve(BlockKind::Value); + + // Block to evaluate if the callee is null/undefined + let alternate = if let Some(parent_alt) = parent_alternate { + Ok(parent_alt) + } else { + builder.try_enter(BlockKind::Value, |builder, _block_id| { + let temp = lower_value_to_temporary( + builder, + InstructionValue::Primitive { + value: PrimitiveValue::Undefined, + loc: loc.clone(), + }, + )?; + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Const, + place: place.clone(), + }, + value: temp, + type_annotation: None, + loc: loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: loc.clone(), + }) + }) + }?; + + // Track callee info for building the call in the consequent block + enum CalleeInfo { + CallExpression { callee: Place }, + MethodCall { receiver: Place, property: Place }, + } + + let mut callee_info: Option = None; + + let test_block = builder.try_enter(BlockKind::Value, |builder, _block_id| { + match expr.callee.as_ref() { + Expression::OptionalCallExpression(opt_call) => { + let value = + lower_optional_call_expression_impl(builder, opt_call, Some(alternate))?; + let value_place = lower_value_to_temporary(builder, value)?; + callee_info = Some(CalleeInfo::CallExpression { + callee: value_place, + }); + } + Expression::OptionalMemberExpression(opt_member) => { + let (obj, value) = + 
lower_optional_member_expression_impl(builder, opt_member, Some(alternate))?; + callee_info = Some(CalleeInfo::MethodCall { + receiver: obj, + property: value, + }); + } + Expression::MemberExpression(member) => { + let lowered = lower_member_expression(builder, member)?; + let property_place = lower_value_to_temporary(builder, lowered.value)?; + callee_info = Some(CalleeInfo::MethodCall { + receiver: lowered.object, + property: property_place, + }); + } + other => { + let callee_place = lower_expression_to_temporary(builder, other)?; + callee_info = Some(CalleeInfo::CallExpression { + callee: callee_place, + }); + } + } + + let test_place = match callee_info.as_ref().unwrap() { + CalleeInfo::CallExpression { callee } => callee.clone(), + CalleeInfo::MethodCall { property, .. } => property.clone(), + }; + + Ok(Terminal::Branch { + test: test_place, + consequent: consequent.id, + alternate, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }) + }); + + // Block to evaluate if the callee is non-null/undefined + builder.try_enter_reserved(consequent, |builder| { + let args = lower_arguments(builder, &expr.arguments)?; + let temp = build_temporary_place(builder, loc.clone()); + + match callee_info.as_ref().unwrap() { + CalleeInfo::CallExpression { callee } => { + builder.push(Instruction { + id: EvaluationOrder(0), + lvalue: temp.clone(), + value: InstructionValue::CallExpression { + callee: callee.clone(), + args, + loc: loc.clone(), + }, + loc: loc.clone(), + effects: None, + }); + } + CalleeInfo::MethodCall { receiver, property } => { + builder.push(Instruction { + id: EvaluationOrder(0), + lvalue: temp.clone(), + value: InstructionValue::MethodCall { + receiver: receiver.clone(), + property: property.clone(), + args, + loc: loc.clone(), + }, + loc: loc.clone(), + effects: None, + }); + } + } + + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Const, + place: 
place.clone(), + }, + value: temp, + type_annotation: None, + loc: loc.clone(), + }, + )?; + Ok(Terminal::Goto { + block: continuation_id, + variant: GotoVariant::Break, + id: EvaluationOrder(0), + loc: loc.clone(), + }) + })?; + + builder.terminate_with_continuation( + Terminal::Optional { + optional, + test: test_block?, + fallthrough: continuation_id, + id: EvaluationOrder(0), + loc: loc.clone(), + }, + continuation_block, + ); + + Ok(InstructionValue::LoadLocal { + place: place.clone(), + loc: place.loc, + }) +} + +fn lower_function_to_value( + builder: &mut HirBuilder, + expr: &react_compiler_ast::expressions::Expression, + expr_type: FunctionExpressionType, +) -> Result { + use react_compiler_ast::expressions::Expression; + let loc = match expr { + Expression::ArrowFunctionExpression(arrow) => convert_opt_loc(&arrow.base.loc), + Expression::FunctionExpression(func) => convert_opt_loc(&func.base.loc), + _ => None, + }; + let name = match expr { + Expression::FunctionExpression(func) => func.id.as_ref().map(|id| id.name.clone()), + _ => None, + }; + let lowered_func = lower_function(builder, expr)?; + Ok(InstructionValue::FunctionExpression { + name, + name_hint: None, + lowered_func, + expr_type, + loc, + }) +} + +fn lower_function( + builder: &mut HirBuilder, + expr: &react_compiler_ast::expressions::Expression, +) -> Result { + use react_compiler_ast::expressions::Expression; + + // Extract function parts from the AST node + let (params, body, id, generator, is_async, func_start, func_end, func_loc) = match expr { + Expression::ArrowFunctionExpression(arrow) => { + let body = match arrow.body.as_ref() { + react_compiler_ast::expressions::ArrowFunctionBody::BlockStatement(block) => { + FunctionBody::Block(block) + } + react_compiler_ast::expressions::ArrowFunctionBody::Expression(expr) => { + FunctionBody::Expression(expr) + } + }; + ( + &arrow.params[..], + body, + None::<&str>, + arrow.generator, + arrow.is_async, + arrow.base.start.unwrap_or(0), + 
arrow.base.end.unwrap_or(0), + convert_opt_loc(&arrow.base.loc), + ) + } + Expression::FunctionExpression(func) => ( + &func.params[..], + FunctionBody::Block(&func.body), + func.id.as_ref().map(|id| id.name.as_str()), + func.generator, + func.is_async, + func.base.start.unwrap_or(0), + func.base.end.unwrap_or(0), + convert_opt_loc(&func.base.loc), + ), + _ => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "lower_function called with non-function expression", + None, + )); + } + }; + + // Find the function's scope + let function_scope = builder + .scope_info() + .node_to_scope + .get(&func_start) + .copied() + .unwrap_or(builder.scope_info().program_scope); + + let component_scope = builder.component_scope(); + let scope_info = builder.scope_info(); + + // Clone parent bindings and used_names to pass to the inner lower + let parent_bindings = builder.bindings().clone(); + let parent_used_names = builder.used_names().clone(); + let context_ids = builder.context_identifiers().clone(); + let ident_locs = builder.identifier_locs(); + + // Gather captured context + let captured_context = gather_captured_context( + scope_info, + function_scope, + component_scope, + func_start, + func_end, + ident_locs, + ); + + // Merge parent context with captured context. + // The locally-gathered captured context overrides the parent's loc values, + // matching the TS behavior: `new Map([...builder.context, ...capturedContext])` + // where later entries win. 
+ let merged_context: IndexMap> = { + let parent_context = builder.context().clone(); + let mut merged = parent_context; + for (k, v) in captured_context { + merged.insert(k, v); + } + merged + }; + + // Use scope_info_and_env_mut to avoid conflicting borrows + let (scope_info, env) = builder.scope_info_and_env_mut(); + let (hir_func, child_used_names, child_bindings) = lower_inner( + params, + body, + id, + generator, + is_async, + func_loc, + scope_info, + env, + Some(parent_bindings), + Some(parent_used_names), + merged_context, + function_scope, + component_scope, + &context_ids, + false, // nested function + ident_locs, + )?; + + // Merge the child's used_names and bindings back into the parent builder. + // This ensures name deduplication works across function scopes, + // matching the TS behavior where #bindings is shared by reference. + builder.merge_used_names(child_used_names); + builder.merge_bindings(child_bindings); + + let func_id = builder.environment_mut().add_function(hir_func); + Ok(LoweredFunction { func: func_id }) +} + +/// Lower a function declaration statement to a FunctionExpression + StoreLocal. 
+fn lower_function_declaration( + builder: &mut HirBuilder, + func_decl: &react_compiler_ast::statements::FunctionDeclaration, +) -> Result<(), CompilerError> { + let loc = convert_opt_loc(&func_decl.base.loc); + let func_start = func_decl.base.start.unwrap_or(0); + let func_end = func_decl.base.end.unwrap_or(0); + + let func_name = func_decl.id.as_ref().map(|id| id.name.clone()); + + // Find the function's scope + let function_scope = builder + .scope_info() + .node_to_scope + .get(&func_start) + .copied() + .unwrap_or(builder.scope_info().program_scope); + + let component_scope = builder.component_scope(); + let scope_info = builder.scope_info(); + + let parent_bindings = builder.bindings().clone(); + let parent_used_names = builder.used_names().clone(); + let context_ids = builder.context_identifiers().clone(); + let ident_locs = builder.identifier_locs(); + + // Gather captured context + let captured_context = gather_captured_context( + scope_info, + function_scope, + component_scope, + func_start, + func_end, + ident_locs, + ); + + // Merge parent context with captured context. 
+ // The locally-gathered captured context overrides the parent's loc values, + // matching the TS behavior: `new Map([...builder.context, ...capturedContext])` + let merged_context: IndexMap> = { + let parent_context = builder.context().clone(); + let mut merged = parent_context; + for (k, v) in captured_context { + merged.insert(k, v); + } + merged + }; + + let (scope_info, env) = builder.scope_info_and_env_mut(); + let (hir_func, child_used_names, child_bindings) = lower_inner( + &func_decl.params, + FunctionBody::Block(&func_decl.body), + func_decl.id.as_ref().map(|id| id.name.as_str()), + func_decl.generator, + func_decl.is_async, + loc.clone(), + scope_info, + env, + Some(parent_bindings), + Some(parent_used_names), + merged_context, + function_scope, + component_scope, + &context_ids, + false, // nested function + ident_locs, + )?; + + builder.merge_used_names(child_used_names); + // Merge child bindings so the parent can reuse the same IdentifierIds + // for bindings that were already resolved by the child. This matches TS + // behavior where the parent and child share the same #bindings map by + // reference. + builder.merge_bindings(child_bindings); + + let func_id = builder.environment_mut().add_function(hir_func); + let lowered_func = LoweredFunction { func: func_id }; + + // Emit FunctionExpression instruction + let fn_value = InstructionValue::FunctionExpression { + name: func_name.clone(), + name_hint: None, + lowered_func, + expr_type: FunctionExpressionType::FunctionDeclaration, + loc: loc.clone(), + }; + let fn_place = lower_value_to_temporary(builder, fn_value)?; + + // Resolve the binding for the function name and store. + // Note: we must resolve from the function's INNER scope, not using + // reference_to_binding directly. This matches TS behavior where Babel's + // `path.scope.getBinding()` resolves from the function declaration's inner + // scope. 
If there's an inner variable that shadows the function name (e.g., + // `function hasErrors() { let hasErrors = ... }`), Babel's scope resolution + // finds the inner binding, not the outer function binding. + if let Some(ref name) = func_name { + if let Some(id_node) = &func_decl.id { + let start = id_node.base.start.unwrap_or(0); + let ident_loc = convert_opt_loc(&id_node.base.loc); + // Look up the binding from the function's inner scope, which may shadow + // the outer binding with the same name + let inner_binding_id = builder.scope_info().get_binding(function_scope, name); + let binding = if let Some(inner_bid) = inner_binding_id { + let binding_kind = crate::convert_binding_kind( + &builder.scope_info().bindings[inner_bid.0 as usize].kind, + ); + let identifier_id = + builder.resolve_binding_with_loc(name, inner_bid, ident_loc.clone())?; + VariableBinding::Identifier { + identifier: identifier_id, + binding_kind, + } + } else { + builder.resolve_identifier(name, start, ident_loc.clone())? + }; + match binding { + VariableBinding::Identifier { identifier, .. } => { + // Don't override the identifier's declaration loc here. + // For function redeclarations (e.g., `function x() {} function x() {}`), + // the identifier's loc should remain the first declaration's loc, + // which was already set during define_binding. 
+ // Use the full function declaration loc for the Place, + // matching the TS behavior where lowerAssignment uses stmt.node.loc + let place = Place { + identifier, + reactive: false, + effect: Effect::Unknown, + loc: loc.clone(), + }; + if builder.is_context_identifier(name, start) { + lower_value_to_temporary( + builder, + InstructionValue::StoreContext { + lvalue: LValue { + kind: InstructionKind::Function, + place, + }, + value: fn_place, + loc, + }, + )?; + } else { + lower_value_to_temporary( + builder, + InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Function, + place, + }, + value: fn_place, + type_annotation: None, + loc, + }, + )?; + } + } + _ => { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Invariant, + reason: format!( + "Could not find binding for function declaration `{}`", + name + ), + description: None, + loc, + suggestions: None, + })?; + } + } + } + } + Ok(()) +} + +/// Lower a function expression used as an object method. +fn lower_function_for_object_method( + builder: &mut HirBuilder, + method: &react_compiler_ast::expressions::ObjectMethod, +) -> Result { + let func_start = method.base.start.unwrap_or(0); + let func_end = method.base.end.unwrap_or(0); + let func_loc = convert_opt_loc(&method.base.loc); + + let function_scope = builder + .scope_info() + .node_to_scope + .get(&func_start) + .copied() + .unwrap_or(builder.scope_info().program_scope); + + let component_scope = builder.component_scope(); + let scope_info = builder.scope_info(); + + let parent_bindings = builder.bindings().clone(); + let parent_used_names = builder.used_names().clone(); + let context_ids = builder.context_identifiers().clone(); + let ident_locs = builder.identifier_locs(); + + let captured_context = gather_captured_context( + scope_info, + function_scope, + component_scope, + func_start, + func_end, + ident_locs, + ); + + // Merge parent context with captured context. 
+ // The locally-gathered captured context overrides the parent's loc values, + // matching the TS behavior: `new Map([...builder.context, ...capturedContext])` + let merged_context: IndexMap> = { + let parent_context = builder.context().clone(); + let mut merged = parent_context; + for (k, v) in captured_context { + merged.insert(k, v); + } + merged + }; + + let (scope_info, env) = builder.scope_info_and_env_mut(); + let (hir_func, child_used_names, child_bindings) = lower_inner( + &method.params, + FunctionBody::Block(&method.body), + None, + method.generator, + method.is_async, + func_loc, + scope_info, + env, + Some(parent_bindings), + Some(parent_used_names), + merged_context, + function_scope, + component_scope, + &context_ids, + false, // nested function + ident_locs, + )?; + + builder.merge_used_names(child_used_names); + builder.merge_bindings(child_bindings); + + let func_id = builder.environment_mut().add_function(hir_func); + Ok(LoweredFunction { func: func_id }) +} + +/// Internal helper: lower a function given its extracted parts. +/// Used by both the top-level `lower()` and nested `lower_function()`. 
+fn lower_inner( + params: &[react_compiler_ast::patterns::PatternLike], + body: FunctionBody<'_>, + id: Option<&str>, + generator: bool, + is_async: bool, + loc: Option, + scope_info: &ScopeInfo, + env: &mut Environment, + parent_bindings: Option>, + parent_used_names: Option>, + context_map: IndexMap>, + function_scope: react_compiler_ast::scope::ScopeId, + component_scope: react_compiler_ast::scope::ScopeId, + context_identifiers: &HashSet, + is_top_level: bool, + identifier_locs: &IdentifierLocIndex, +) -> Result< + ( + HirFunction, + IndexMap, + IndexMap, + ), + CompilerError, +> { + let mut builder = HirBuilder::new( + env, + scope_info, + function_scope, + component_scope, + context_identifiers.clone(), + parent_bindings, + Some(context_map.clone()), + None, + parent_used_names, + identifier_locs, + ); + + // Build context places from the captured refs + let mut context: Vec = Vec::new(); + for (&binding_id, ctx_loc) in &context_map { + let binding = &scope_info.bindings[binding_id.0 as usize]; + let identifier = builder.resolve_binding(&binding.name, binding_id)?; + context.push(Place { + identifier, + effect: Effect::Unknown, + reactive: false, + loc: ctx_loc.clone(), + }); + } + + // Process parameters + let mut hir_params: Vec = Vec::new(); + for param in params { + match param { + react_compiler_ast::patterns::PatternLike::Identifier(ident) => { + let start = ident.base.start.unwrap_or(0); + let param_loc = convert_opt_loc(&ident.base.loc); + let binding = builder.resolve_identifier(&ident.name, start, param_loc.clone())?; + match binding { + VariableBinding::Identifier { identifier, .. 
} => { + // Set the identifier's loc from the declaration (param) site + builder.set_identifier_declaration_loc(identifier, ¶m_loc); + let place = Place { + identifier, + effect: Effect::Unknown, + reactive: false, + loc: param_loc, + }; + hir_params.push(ParamPattern::Place(place)); + } + _ => { + builder.record_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Could not find binding", + Some(format!( + "[BuildHIR] Could not find binding for param `{}`", + ident.name + )), + ) + .with_detail( + CompilerDiagnosticDetail::Error { + loc: convert_opt_loc(&ident.base.loc), + message: Some("Could not find binding".to_string()), + identifier_name: None, + }, + ), + ); + } + } + } + react_compiler_ast::patterns::PatternLike::RestElement(rest) => { + let rest_loc = convert_opt_loc(&rest.base.loc); + // Create a temporary place for the spread param + let place = build_temporary_place(&mut builder, rest_loc.clone()); + hir_params.push(ParamPattern::Spread(SpreadPattern { + place: place.clone(), + })); + // Delegate the assignment of the rest argument + lower_assignment( + &mut builder, + rest_loc, + InstructionKind::Let, + &rest.argument, + place, + AssignmentStyle::Assignment, + )?; + } + react_compiler_ast::patterns::PatternLike::ObjectPattern(_) + | react_compiler_ast::patterns::PatternLike::ArrayPattern(_) + | react_compiler_ast::patterns::PatternLike::AssignmentPattern(_) => { + let param_loc = convert_opt_loc(&pattern_like_loc(param)); + let place = build_temporary_place(&mut builder, param_loc.clone()); + promote_temporary(&mut builder, place.identifier); + hir_params.push(ParamPattern::Place(place.clone())); + lower_assignment( + &mut builder, + param_loc, + InstructionKind::Let, + param, + place, + AssignmentStyle::Assignment, + )?; + } + react_compiler_ast::patterns::PatternLike::MemberExpression(member) => { + builder.record_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Todo, + "Handle MemberExpression parameters", + Some("[BuildHIR] 
Add support for MemberExpression parameters".to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: convert_opt_loc(&member.base.loc), + message: Some("Unsupported parameter type".to_string()), + identifier_name: None, + }), + ); + } + } + } + + // Lower the body + let mut directives: Vec = Vec::new(); + match body { + FunctionBody::Expression(expr) => { + let fallthrough = builder.reserve(BlockKind::Block); + let value = lower_expression_to_temporary(&mut builder, expr)?; + builder.terminate_with_continuation( + Terminal::Return { + value, + return_variant: ReturnVariant::Implicit, + id: EvaluationOrder(0), + loc: None, + effects: None, + }, + fallthrough, + ); + } + FunctionBody::Block(block) => { + directives = block + .directives + .iter() + .map(|d| d.value.value.clone()) + .collect(); + // Use lower_block_statement_with_scope to get hoisting support for the function + // body. Pass the function scope since in Babel, a function body + // BlockStatement shares the function's scope (node_to_scope maps + // the function node, not the block). 
+ lower_block_statement_with_scope(&mut builder, block, function_scope)?; + } + } + + // Emit final Return(Void, undefined) + let undefined_value = InstructionValue::Primitive { + value: PrimitiveValue::Undefined, + loc: None, + }; + let return_value = lower_value_to_temporary(&mut builder, undefined_value)?; + builder.terminate( + Terminal::Return { + value: return_value, + return_variant: ReturnVariant::Void, + id: EvaluationOrder(0), + loc: None, + effects: None, + }, + None, + ); + + // Build the HIR + let (hir_body, instructions, used_names, child_bindings) = builder.build()?; + + // Create the returns place + let returns = crate::hir_builder::create_temporary_place(env, loc.clone()); + + Ok(( + HirFunction { + loc, + id: id.map(|s| s.to_string()), + name_hint: None, + fn_type: if is_top_level { + env.fn_type + } else { + ReactFunctionType::Other + }, + params: hir_params, + return_type_annotation: None, + returns, + context, + body: hir_body, + instructions, + generator, + is_async, + directives, + aliasing_effects: None, + }, + used_names, + child_bindings, + )) +} + +fn lower_jsx_element_name( + builder: &mut HirBuilder, + name: &react_compiler_ast::jsx::JSXElementName, +) -> Result { + use react_compiler_ast::jsx::JSXElementName; + match name { + JSXElementName::JSXIdentifier(id) => { + let tag = &id.name; + let loc = convert_opt_loc(&id.base.loc); + let start = id.base.start.unwrap_or(0); + if tag.starts_with(|c: char| c.is_ascii_uppercase()) { + // Component tag: resolve as identifier and load + let place = lower_identifier(builder, tag, start, loc.clone())?; + let load_value = if builder.is_context_identifier(tag, start) { + InstructionValue::LoadContext { place, loc } + } else { + InstructionValue::LoadLocal { place, loc } + }; + let temp = lower_value_to_temporary(builder, load_value)?; + Ok(JsxTag::Place(temp)) + } else { + // Builtin HTML tag + Ok(JsxTag::Builtin(BuiltinTag { + name: tag.clone(), + loc, + })) + } + } + 
JSXElementName::JSXMemberExpression(member) => { + let place = lower_jsx_member_expression(builder, member)?; + Ok(JsxTag::Place(place)) + } + JSXElementName::JSXNamespacedName(ns) => { + let namespace = &ns.namespace.name; + let name = &ns.name.name; + let tag = format!("{}:{}", namespace, name); + let loc = convert_opt_loc(&ns.base.loc); + if namespace.contains(':') || name.contains(':') { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Syntax, + reason: "Expected JSXNamespacedName to have no colons in the namespace or name" + .to_string(), + description: Some(format!("Got `{}` : `{}`", namespace, name)), + loc: loc.clone(), + suggestions: None, + })?; + } + let place = lower_value_to_temporary( + builder, + InstructionValue::Primitive { + value: PrimitiveValue::String(tag), + loc: loc.clone(), + }, + )?; + Ok(JsxTag::Place(place)) + } + } +} + +fn lower_jsx_member_expression( + builder: &mut HirBuilder, + expr: &react_compiler_ast::jsx::JSXMemberExpression, +) -> Result { + use react_compiler_ast::jsx::JSXMemberExprObject; + // Use the full member expression's loc for instruction locs (matching TS: + // exprPath.node.loc) + let expr_loc = convert_opt_loc(&expr.base.loc); + let object = match &*expr.object { + JSXMemberExprObject::JSXIdentifier(id) => { + let id_loc = convert_opt_loc(&id.base.loc); + let start = id.base.start.unwrap_or(0); + // Use identifier's own loc for the place, but member expression's loc for the + // instruction + let place = lower_identifier(builder, &id.name, start, id_loc)?; + let load_value = if builder.is_context_identifier(&id.name, start) { + InstructionValue::LoadContext { + place, + loc: expr_loc.clone(), + } + } else { + InstructionValue::LoadLocal { + place, + loc: expr_loc.clone(), + } + }; + lower_value_to_temporary(builder, load_value)? + } + JSXMemberExprObject::JSXMemberExpression(inner) => { + lower_jsx_member_expression(builder, inner)? 
+ } + }; + let prop_name = &expr.property.name; + let value = InstructionValue::PropertyLoad { + object, + property: PropertyLiteral::String(prop_name.clone()), + loc: expr_loc, + }; + Ok(lower_value_to_temporary(builder, value)?) +} + +fn lower_jsx_element( + builder: &mut HirBuilder, + child: &react_compiler_ast::jsx::JSXChild, +) -> Result, CompilerError> { + use react_compiler_ast::jsx::{JSXChild, JSXExpressionContainerExpr}; + match child { + JSXChild::JSXText(text) => { + // FBT whitespace normalization differs from standard JSX. + // Since the fbt transform runs after, preserve all whitespace + // in FBT subtrees as is. + let value = if builder.fbt_depth > 0 { + Some(text.value.clone()) + } else { + trim_jsx_text(&text.value) + }; + match value { + None => Ok(None), + Some(value) => { + let loc = convert_opt_loc(&text.base.loc); + let place = lower_value_to_temporary( + builder, + InstructionValue::JSXText { value, loc }, + )?; + Ok(Some(place)) + } + } + } + JSXChild::JSXElement(element) => { + let value = lower_expression( + builder, + &react_compiler_ast::expressions::Expression::JSXElement(element.clone()), + )?; + Ok(Some(lower_value_to_temporary(builder, value)?)) + } + JSXChild::JSXFragment(fragment) => { + let value = lower_expression( + builder, + &react_compiler_ast::expressions::Expression::JSXFragment(fragment.clone()), + )?; + Ok(Some(lower_value_to_temporary(builder, value)?)) + } + JSXChild::JSXExpressionContainer(container) => match &container.expression { + JSXExpressionContainerExpr::JSXEmptyExpression(_) => Ok(None), + JSXExpressionContainerExpr::Expression(expr) => { + Ok(Some(lower_expression_to_temporary(builder, expr)?)) + } + }, + JSXChild::JSXSpreadChild(spread) => Ok(Some(lower_expression_to_temporary( + builder, + &spread.expression, + )?)), + } +} + +/// Split a string on line endings, handling \r\n, \n, and \r. 
+fn split_line_endings(s: &str) -> Vec<&str> { + let mut lines = Vec::new(); + let mut start = 0; + let bytes = s.as_bytes(); + let mut i = 0; + while i < bytes.len() { + if bytes[i] == b'\r' { + lines.push(&s[start..i]); + if i + 1 < bytes.len() && bytes[i + 1] == b'\n' { + i += 2; + } else { + i += 1; + } + start = i; + } else if bytes[i] == b'\n' { + lines.push(&s[start..i]); + i += 1; + start = i; + } else { + i += 1; + } + } + lines.push(&s[start..]); + lines +} + +/// Trims whitespace according to the JSX spec. +/// Implementation ported from Babel's cleanJSXElementLiteralChild. +fn trim_jsx_text(original: &str) -> Option { + // Split on \r\n, \n, or \r to handle all line ending styles (matching TS + // split(/\r\n|\n|\r/)) + let lines: Vec<&str> = split_line_endings(original); + + // NOTE: when builder.fbt_depth > 0, the TS skips whitespace trimming entirely. + // That check is handled by the caller (lower_jsx_element) before calling this + // function. + + let mut last_non_empty_line = 0; + for (i, line) in lines.iter().enumerate() { + if line.contains(|c: char| c != ' ' && c != '\t') { + last_non_empty_line = i; + } + } + + let mut str = String::new(); + + for (i, line) in lines.iter().enumerate() { + let is_first_line = i == 0; + let is_last_line = i == lines.len() - 1; + let is_last_non_empty_line = i == last_non_empty_line; + + // Replace rendered whitespace tabs with spaces + let mut trimmed_line = line.replace('\t', " "); + + // Trim whitespace touching a newline (leading whitespace on non-first lines) + if !is_first_line { + trimmed_line = trimmed_line.trim_start_matches(' ').to_string(); + } + + // Trim whitespace touching an endline (trailing whitespace on non-last lines) + if !is_last_line { + trimmed_line = trimmed_line.trim_end_matches(' ').to_string(); + } + + if !trimmed_line.is_empty() { + if !is_last_non_empty_line { + trimmed_line.push(' '); + } + str.push_str(&trimmed_line); + } + } + + if str.is_empty() { + None + } else { + Some(str) + } 
+} + +fn lower_object_method( + builder: &mut HirBuilder, + method: &react_compiler_ast::expressions::ObjectMethod, +) -> Result, CompilerError> { + use react_compiler_ast::expressions::ObjectMethodKind; + if !matches!(method.kind, ObjectMethodKind::Method) { + let kind_str = match method.kind { + ObjectMethodKind::Get => "get", + ObjectMethodKind::Set => "set", + ObjectMethodKind::Method => "method", + }; + builder.record_error(CompilerErrorDetail { + reason: format!( + "(BuildHIR::lowerExpression) Handle {} functions in ObjectExpression", + kind_str + ), + category: ErrorCategory::Todo, + loc: convert_opt_loc(&method.base.loc), + description: None, + suggestions: None, + })?; + return Ok(None); + } + let key = lower_object_property_key(builder, &method.key, method.computed)?.unwrap_or( + ObjectPropertyKey::String { + name: String::new(), + }, + ); + + let lowered_func = lower_function_for_object_method(builder, method)?; + + let loc = convert_opt_loc(&method.base.loc); + let method_value = InstructionValue::ObjectMethod { + loc: loc.clone(), + lowered_func, + }; + let method_place = lower_value_to_temporary(builder, method_value)?; + + Ok(Some(ObjectProperty { + key, + property_type: ObjectPropertyType::Method, + place: method_place, + })) +} + +fn lower_object_property_key( + builder: &mut HirBuilder, + key: &react_compiler_ast::expressions::Expression, + computed: bool, +) -> Result, CompilerError> { + use react_compiler_ast::expressions::Expression; + match key { + Expression::StringLiteral(lit) => Ok(Some(ObjectPropertyKey::String { + name: lit.value.clone(), + })), + Expression::Identifier(ident) if !computed => Ok(Some(ObjectPropertyKey::Identifier { + name: ident.name.clone(), + })), + Expression::NumericLiteral(lit) if !computed => Ok(Some(ObjectPropertyKey::Identifier { + name: lit.value.to_string(), + })), + _ if computed => { + let place = lower_expression_to_temporary(builder, key)?; + Ok(Some(ObjectPropertyKey::Computed { name: place })) + } + _ => { 
+ let loc = match key { + Expression::Identifier(i) => convert_opt_loc(&i.base.loc), + _ => None, + }; + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "Unsupported key type in ObjectExpression".to_string(), + description: None, + loc, + suggestions: None, + })?; + Ok(None) + } + } +} + +fn lower_reorderable_expression( + builder: &mut HirBuilder, + expr: &react_compiler_ast::expressions::Expression, +) -> Result { + if !is_reorderable_expression(builder, expr, true) { + builder.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: format!( + "(BuildHIR::node.lowerReorderableExpression) Expression type `{}` cannot be \ + safely reordered", + expression_type_name(expr) + ), + description: None, + loc: expression_loc(expr), + suggestions: None, + })?; + } + Ok(lower_expression_to_temporary(builder, expr)?) +} + +fn is_reorderable_expression( + builder: &HirBuilder, + expr: &react_compiler_ast::expressions::Expression, + allow_local_identifiers: bool, +) -> bool { + use react_compiler_ast::expressions::Expression; + match expr { + Expression::Identifier(ident) => { + let start = ident.base.start.unwrap_or(0); + let binding = builder.scope_info().resolve_reference(start); + match binding { + None => { + // global, safe to reorder + true + } + Some(b) => { + if b.scope == builder.scope_info().program_scope { + // Module-scope binding (ModuleLocal, imports), safe to reorder + true + } else { + allow_local_identifiers + } + } + } + } + Expression::RegExpLiteral(_) + | Expression::StringLiteral(_) + | Expression::NumericLiteral(_) + | Expression::NullLiteral(_) + | Expression::BooleanLiteral(_) + | Expression::BigIntLiteral(_) => true, + Expression::UnaryExpression(unary) => { + use react_compiler_ast::operators::UnaryOperator; + matches!( + unary.operator, + UnaryOperator::Not | UnaryOperator::Plus | UnaryOperator::Neg + ) && is_reorderable_expression(builder, &unary.argument, allow_local_identifiers) + } + 
Expression::LogicalExpression(logical) => { + is_reorderable_expression(builder, &logical.left, allow_local_identifiers) + && is_reorderable_expression(builder, &logical.right, allow_local_identifiers) + } + Expression::ConditionalExpression(cond) => { + is_reorderable_expression(builder, &cond.test, allow_local_identifiers) + && is_reorderable_expression(builder, &cond.consequent, allow_local_identifiers) + && is_reorderable_expression(builder, &cond.alternate, allow_local_identifiers) + } + Expression::ArrayExpression(arr) => { + arr.elements.iter().all(|element| { + match element { + Some(e) => is_reorderable_expression(builder, e, allow_local_identifiers), + None => false, // holes are not reorderable + } + }) + } + Expression::ObjectExpression(obj) => obj.properties.iter().all(|prop| match prop { + react_compiler_ast::expressions::ObjectExpressionProperty::ObjectProperty(p) => { + !p.computed && is_reorderable_expression(builder, &p.value, allow_local_identifiers) + } + _ => false, + }), + Expression::MemberExpression(member) => { + // Allow member expressions where the innermost object is a global or + // module-local + let mut inner = member.object.as_ref(); + while let Expression::MemberExpression(m) = inner { + inner = m.object.as_ref(); + } + if let Expression::Identifier(ident) = inner { + let start = ident.base.start.unwrap_or(0); + match builder.scope_info().resolve_reference(start) { + None => true, // global + Some(binding) => { + // Module-scope bindings (ModuleLocal, imports) are safe to reorder + binding.scope == builder.scope_info().program_scope + } + } + } else { + false + } + } + Expression::ArrowFunctionExpression(arrow) => { + use react_compiler_ast::expressions::ArrowFunctionBody; + match arrow.body.as_ref() { + ArrowFunctionBody::BlockStatement(block) => block.body.is_empty(), + ArrowFunctionBody::Expression(body_expr) => { + is_reorderable_expression(builder, body_expr, false) + } + } + } + Expression::CallExpression(call) => { + 
is_reorderable_expression(builder, &call.callee, allow_local_identifiers) + && call + .arguments + .iter() + .all(|arg| is_reorderable_expression(builder, arg, allow_local_identifiers)) + } + // TypeScript/Flow type wrappers: recurse into the inner expression + Expression::TSAsExpression(ts) => { + is_reorderable_expression(builder, &ts.expression, allow_local_identifiers) + } + Expression::TSSatisfiesExpression(ts) => { + is_reorderable_expression(builder, &ts.expression, allow_local_identifiers) + } + Expression::TSNonNullExpression(ts) => { + is_reorderable_expression(builder, &ts.expression, allow_local_identifiers) + } + Expression::TSInstantiationExpression(ts) => { + is_reorderable_expression(builder, &ts.expression, allow_local_identifiers) + } + Expression::TypeCastExpression(tc) => { + is_reorderable_expression(builder, &tc.expression, allow_local_identifiers) + } + Expression::TSTypeAssertion(ts) => { + is_reorderable_expression(builder, &ts.expression, allow_local_identifiers) + } + Expression::ParenthesizedExpression(p) => { + is_reorderable_expression(builder, &p.expression, allow_local_identifiers) + } + _ => false, + } +} + +/// Extract the type name from a type annotation serde_json::Value. +/// Returns the "type" field value, e.g. "TSTypeReference", +/// "GenericTypeAnnotation". +fn get_type_annotation_name(val: &serde_json::Value) -> Option { + val.get("type") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) +} + +/// Lower a type annotation JSON value to an HIR Type. +/// Mirrors the TS `lowerType` function. 
+fn lower_type_annotation(val: &serde_json::Value, builder: &mut HirBuilder) -> Type { + let type_name = match val.get("type").and_then(|v| v.as_str()) { + Some(name) => name, + None => return builder.make_type(), + }; + match type_name { + "GenericTypeAnnotation" => { + // Check if it's Array + if let Some(id) = val.get("id") { + if id.get("type").and_then(|v| v.as_str()) == Some("Identifier") { + if id.get("name").and_then(|v| v.as_str()) == Some("Array") { + return Type::Object { + shape_id: Some("BuiltInArray".to_string()), + }; + } + } + } + builder.make_type() + } + "TSTypeReference" => { + if let Some(type_name_val) = val.get("typeName") { + if type_name_val.get("type").and_then(|v| v.as_str()) == Some("Identifier") { + if type_name_val.get("name").and_then(|v| v.as_str()) == Some("Array") { + return Type::Object { + shape_id: Some("BuiltInArray".to_string()), + }; + } + } + } + builder.make_type() + } + "ArrayTypeAnnotation" | "TSArrayType" => Type::Object { + shape_id: Some("BuiltInArray".to_string()), + }, + "BooleanLiteralTypeAnnotation" + | "BooleanTypeAnnotation" + | "NullLiteralTypeAnnotation" + | "NumberLiteralTypeAnnotation" + | "NumberTypeAnnotation" + | "StringLiteralTypeAnnotation" + | "StringTypeAnnotation" + | "TSBooleanKeyword" + | "TSNullKeyword" + | "TSNumberKeyword" + | "TSStringKeyword" + | "TSSymbolKeyword" + | "TSUndefinedKeyword" + | "TSVoidKeyword" + | "VoidTypeAnnotation" => Type::Primitive, + _ => builder.make_type(), + } +} + +/// Gather captured context variables for a nested function. +/// +/// Walks through all identifier references (via `reference_to_binding`) and +/// checks which ones resolve to bindings declared in scopes between the +/// function's parent scope and the component scope. These are "free variables" +/// that become the function's `context`. 
+fn gather_captured_context( + scope_info: &ScopeInfo, + function_scope: react_compiler_ast::scope::ScopeId, + component_scope: react_compiler_ast::scope::ScopeId, + func_start: u32, + func_end: u32, + identifier_locs: &IdentifierLocIndex, +) -> IndexMap> { + let parent_scope = scope_info.scopes[function_scope.0 as usize].parent; + let pure_scopes = match parent_scope { + Some(parent) => capture_scopes(scope_info, parent, component_scope), + None => IndexSet::new(), + }; + + let mut captured = + IndexMap::>::new(); + + for (&ref_start, &binding_id) in &scope_info.reference_to_binding { + if ref_start < func_start || ref_start >= func_end { + continue; + } + let binding = &scope_info.bindings[binding_id.0 as usize]; + // Skip references that are actually the binding's own declaration site + // (e.g., the function name in `function x() {}` is mapped in referenceToBinding + // but is not a true captured reference) + if binding.declaration_start == Some(ref_start) { + continue; + } + // Skip function/class declaration names that are not expression references. + // In the TS, gatherCapturedContext traverses with an Expression visitor, so + // it never encounters function declaration names. But reference_to_binding + // includes constant violations for function redeclarations (e.g., the second + // `function x() {}` in a scope), so we must filter them out here. 
+ if let Some(entry) = identifier_locs.get(&ref_start) { + if entry.is_declaration_name { + continue; + } + } + // Skip type-only bindings (e.g., Flow/TypeScript type aliases) + // These are not runtime values and should not be captured as context + if binding.declaration_type == "TypeAlias" + || binding.declaration_type == "OpaqueType" + || binding.declaration_type == "InterfaceDeclaration" + || binding.declaration_type == "TSTypeAliasDeclaration" + || binding.declaration_type == "TSInterfaceDeclaration" + || binding.declaration_type == "TSEnumDeclaration" + { + continue; + } + if pure_scopes.contains(&binding.scope) && !captured.contains_key(&binding.id) { + let loc = identifier_locs.get(&ref_start).map(|entry| { + // For JSX identifiers that are part of an opening element name, + // use the JSXOpeningElement's loc (which spans the full tag) to match + // the TS behavior where handleMaybeDependency receives the + // JSXOpeningElement path and uses path.node.loc. + if let Some(oe_loc) = &entry.opening_element_loc { + oe_loc.clone() + } else { + entry.loc.clone() + } + }); + captured.insert(binding.id, loc); + } + } + + captured +} + +fn capture_scopes( + scope_info: &ScopeInfo, + from: react_compiler_ast::scope::ScopeId, + to: react_compiler_ast::scope::ScopeId, +) -> IndexSet { + let mut result = IndexSet::new(); + let mut current = Some(from); + while let Some(scope_id) = current { + result.insert(scope_id); + if scope_id == to { + break; + } + current = scope_info.scopes[scope_id.0 as usize].parent; + } + result +} + +/// The style of assignment (used internally by lower_assignment). +#[derive(Clone, Copy)] +pub enum AssignmentStyle { + /// Assignment via `=` + Assignment, + /// Destructuring assignment + Destructure, +} + +/// Collect locations of fbt:enum, fbt:plural, fbt:pronoun sub-tags +/// within the children of an fbt/fbs JSX element. 
+fn collect_fbt_sub_tags( + children: &[react_compiler_ast::jsx::JSXChild], + tag_name: &str, + enum_locs: &mut Vec>, + plural_locs: &mut Vec>, + pronoun_locs: &mut Vec>, +) { + use react_compiler_ast::jsx::{JSXChild, JSXElementName}; + for child in children { + match child { + JSXChild::JSXElement(el) => { + // Check if the opening element name is a namespaced name matching the fbt tag + if let JSXElementName::JSXNamespacedName(ns) = &el.opening_element.name { + if ns.namespace.name == tag_name { + let loc = convert_opt_loc(&ns.base.loc); + match ns.name.name.as_str() { + "enum" => enum_locs.push(loc), + "plural" => plural_locs.push(loc), + "pronoun" => pronoun_locs.push(loc), + _ => {} + } + } + } + // Also recurse into children + collect_fbt_sub_tags(&el.children, tag_name, enum_locs, plural_locs, pronoun_locs); + } + JSXChild::JSXFragment(frag) => { + collect_fbt_sub_tags( + &frag.children, + tag_name, + enum_locs, + plural_locs, + pronoun_locs, + ); + } + _ => {} + } + } +} diff --git a/crates/react_compiler_lowering/src/find_context_identifiers.rs b/crates/react_compiler_lowering/src/find_context_identifiers.rs new file mode 100644 index 000000000000..aa5a4c54bdfc --- /dev/null +++ b/crates/react_compiler_lowering/src/find_context_identifiers.rs @@ -0,0 +1,290 @@ +//! Rust equivalent of the TypeScript `FindContextIdentifiers` pass. +//! +//! Determines which bindings need StoreContext/LoadContext semantics by +//! walking the AST with scope tracking to find variables that cross +//! function boundaries. 
+ +use std::collections::{HashMap, HashSet}; + +use react_compiler_ast::{ + expressions::*, + patterns::*, + scope::*, + statements::FunctionDeclaration, + visitor::{AstWalker, Visitor}, +}; + +use crate::FunctionNode; + +#[derive(Default)] +struct BindingInfo { + reassigned: bool, + reassigned_by_inner_fn: bool, + referenced_by_inner_fn: bool, +} + +struct ContextIdentifierVisitor<'a> { + scope_info: &'a ScopeInfo, + /// Stack of inner function scopes encountered during traversal. + /// Empty when at the top level of the function being compiled. + function_stack: Vec, + binding_info: HashMap, +} + +impl<'a> ContextIdentifierVisitor<'a> { + fn push_function_scope(&mut self, start: Option) { + if let Some(start) = start { + if let Some(&scope) = self.scope_info.node_to_scope.get(&start) { + self.function_stack.push(scope); + } + } + } + + fn pop_function_scope(&mut self, start: Option) { + if start + .and_then(|s| self.scope_info.node_to_scope.get(&s)) + .is_some() + { + self.function_stack.pop(); + } + } + + fn handle_reassignment_identifier(&mut self, name: &str, current_scope: ScopeId) { + if let Some(binding_id) = self.scope_info.get_binding(current_scope, name) { + let info = self.binding_info.entry(binding_id).or_default(); + info.reassigned = true; + if let Some(&fn_scope) = self.function_stack.last() { + let binding = &self.scope_info.bindings[binding_id.0 as usize]; + if is_captured_by_function(self.scope_info, binding.scope, fn_scope) { + info.reassigned_by_inner_fn = true; + } + } + } + } +} + +impl<'ast> Visitor<'ast> for ContextIdentifierVisitor<'_> { + fn enter_function_declaration(&mut self, node: &'ast FunctionDeclaration, _: &[ScopeId]) { + self.push_function_scope(node.base.start); + } + + fn leave_function_declaration(&mut self, node: &'ast FunctionDeclaration, _: &[ScopeId]) { + self.pop_function_scope(node.base.start); + } + + fn enter_function_expression(&mut self, node: &'ast FunctionExpression, _: &[ScopeId]) { + 
self.push_function_scope(node.base.start); + } + + fn leave_function_expression(&mut self, node: &'ast FunctionExpression, _: &[ScopeId]) { + self.pop_function_scope(node.base.start); + } + + fn enter_arrow_function_expression( + &mut self, + node: &'ast ArrowFunctionExpression, + _: &[ScopeId], + ) { + self.push_function_scope(node.base.start); + } + + fn leave_arrow_function_expression( + &mut self, + node: &'ast ArrowFunctionExpression, + _: &[ScopeId], + ) { + self.pop_function_scope(node.base.start); + } + + fn enter_object_method(&mut self, node: &'ast ObjectMethod, _: &[ScopeId]) { + self.push_function_scope(node.base.start); + } + + fn leave_object_method(&mut self, node: &'ast ObjectMethod, _: &[ScopeId]) { + self.pop_function_scope(node.base.start); + } + + fn enter_identifier(&mut self, node: &'ast Identifier, _scope_stack: &[ScopeId]) { + let start = match node.base.start { + Some(s) => s, + None => return, + }; + // Only process identifiers that resolve to a binding (referenced or + // declaration) + let binding_id = match self.scope_info.reference_to_binding.get(&start) { + Some(&id) => id, + None => return, + }; + // If not inside a nested function, nothing to track + let &fn_scope = match self.function_stack.last() { + Some(s) => s, + None => return, + }; + let binding = &self.scope_info.bindings[binding_id.0 as usize]; + if is_captured_by_function(self.scope_info, binding.scope, fn_scope) { + let info = self.binding_info.entry(binding_id).or_default(); + info.referenced_by_inner_fn = true; + } + } + + fn enter_assignment_expression( + &mut self, + node: &'ast AssignmentExpression, + scope_stack: &[ScopeId], + ) { + let current_scope = scope_stack + .last() + .copied() + .unwrap_or(self.scope_info.program_scope); + walk_lval_for_reassignment(self, &node.left, current_scope); + } + + fn enter_update_expression(&mut self, node: &'ast UpdateExpression, scope_stack: &[ScopeId]) { + if let Expression::Identifier(ident) = node.argument.as_ref() { + let 
current_scope = scope_stack + .last() + .copied() + .unwrap_or(self.scope_info.program_scope); + self.handle_reassignment_identifier(&ident.name, current_scope); + } + } +} + +/// Recursively walk an LVal pattern to find all reassignment target +/// identifiers. +fn walk_lval_for_reassignment( + visitor: &mut ContextIdentifierVisitor<'_>, + pattern: &PatternLike, + current_scope: ScopeId, +) { + match pattern { + PatternLike::Identifier(ident) => { + visitor.handle_reassignment_identifier(&ident.name, current_scope); + } + PatternLike::ArrayPattern(pat) => { + for element in &pat.elements { + if let Some(el) = element { + walk_lval_for_reassignment(visitor, el, current_scope); + } + } + } + PatternLike::ObjectPattern(pat) => { + for prop in &pat.properties { + match prop { + ObjectPatternProperty::ObjectProperty(p) => { + walk_lval_for_reassignment(visitor, &p.value, current_scope); + } + ObjectPatternProperty::RestElement(p) => { + walk_lval_for_reassignment(visitor, &p.argument, current_scope); + } + } + } + } + PatternLike::AssignmentPattern(pat) => { + walk_lval_for_reassignment(visitor, &pat.left, current_scope); + } + PatternLike::RestElement(pat) => { + walk_lval_for_reassignment(visitor, &pat.argument, current_scope); + } + PatternLike::MemberExpression(_) => { + // Interior mutability - not a variable reassignment + } + } +} + +/// Check if a binding declared at `binding_scope` is captured by a function at +/// `function_scope`. Returns true if the binding is declared above the function +/// (in the parent scope or higher). 
+fn is_captured_by_function( + scope_info: &ScopeInfo, + binding_scope: ScopeId, + function_scope: ScopeId, +) -> bool { + let fn_parent = match scope_info.scopes[function_scope.0 as usize].parent { + Some(p) => p, + None => return false, + }; + if binding_scope == fn_parent { + return true; + } + // Walk up from fn_parent to see if binding_scope is an ancestor + let mut current = scope_info.scopes[fn_parent.0 as usize].parent; + while let Some(scope_id) = current { + if scope_id == binding_scope { + return true; + } + current = scope_info.scopes[scope_id.0 as usize].parent; + } + false +} + +/// Find context identifiers for a function: variables that are captured across +/// function boundaries and need StoreContext/LoadContext semantics. +/// +/// A binding is a context identifier if: +/// - It is reassigned from inside a nested function (`reassignedByInnerFn`), OR +/// - It is reassigned AND referenced from inside a nested function (`reassigned +/// && referencedByInnerFn`) +/// +/// This is the Rust equivalent of the TypeScript `FindContextIdentifiers` pass. 
+pub fn find_context_identifiers( + func: &FunctionNode<'_>, + scope_info: &ScopeInfo, +) -> HashSet { + let func_start = match func { + FunctionNode::FunctionDeclaration(d) => d.base.start.unwrap_or(0), + FunctionNode::FunctionExpression(e) => e.base.start.unwrap_or(0), + FunctionNode::ArrowFunctionExpression(a) => a.base.start.unwrap_or(0), + }; + let func_scope = scope_info + .node_to_scope + .get(&func_start) + .copied() + .unwrap_or(scope_info.program_scope); + + let mut visitor = ContextIdentifierVisitor { + scope_info, + function_stack: Vec::new(), + binding_info: HashMap::new(), + }; + let mut walker = AstWalker::with_initial_scope(scope_info, func_scope); + + // Walk params and body (like Babel's func.traverse()) + match func { + FunctionNode::FunctionDeclaration(d) => { + for param in &d.params { + walker.walk_pattern(&mut visitor, param); + } + walker.walk_block_statement(&mut visitor, &d.body); + } + FunctionNode::FunctionExpression(e) => { + for param in &e.params { + walker.walk_pattern(&mut visitor, param); + } + walker.walk_block_statement(&mut visitor, &e.body); + } + FunctionNode::ArrowFunctionExpression(a) => { + for param in &a.params { + walker.walk_pattern(&mut visitor, param); + } + match a.body.as_ref() { + ArrowFunctionBody::BlockStatement(block) => { + walker.walk_block_statement(&mut visitor, block); + } + ArrowFunctionBody::Expression(expr) => { + walker.walk_expression(&mut visitor, expr); + } + } + } + } + + // Collect results + visitor + .binding_info + .into_iter() + .filter(|(_, info)| { + info.reassigned_by_inner_fn || (info.reassigned && info.referenced_by_inner_fn) + }) + .map(|(id, _)| id) + .collect() +} diff --git a/crates/react_compiler_lowering/src/hir_builder.rs b/crates/react_compiler_lowering/src/hir_builder.rs new file mode 100644 index 000000000000..8e29cb158c4e --- /dev/null +++ b/crates/react_compiler_lowering/src/hir_builder.rs @@ -0,0 +1,1367 @@ +use indexmap::{IndexMap, IndexSet}; +use 
react_compiler_ast::scope::{BindingId, ImportBindingKind, ScopeId, ScopeInfo}; +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, CompilerErrorDetail, ErrorCategory, +}; +use react_compiler_hir::{ + environment::Environment, + visitors::{each_terminal_successor, terminal_fallthrough}, + *, +}; + +use crate::identifier_loc_index::IdentifierLocIndex; + +// --------------------------------------------------------------------------- +// Reserved word check (matches TS isReservedWord) +// --------------------------------------------------------------------------- + +fn is_reserved_word(s: &str) -> bool { + matches!( + s, + "break" + | "case" + | "catch" + | "continue" + | "debugger" + | "default" + | "do" + | "else" + | "finally" + | "for" + | "function" + | "if" + | "in" + | "instanceof" + | "new" + | "return" + | "switch" + | "this" + | "throw" + | "try" + | "typeof" + | "var" + | "void" + | "while" + | "with" + | "class" + | "const" + | "enum" + | "export" + | "extends" + | "import" + | "super" + | "implements" + | "interface" + | "let" + | "package" + | "private" + | "protected" + | "public" + | "static" + | "yield" + | "null" + | "true" + | "false" + | "delete" + ) +} + +// --------------------------------------------------------------------------- +// Scope types for tracking break/continue targets +// --------------------------------------------------------------------------- + +enum Scope { + Loop { + label: Option, + continue_block: BlockId, + break_block: BlockId, + }, + Label { + label: String, + break_block: BlockId, + }, + Switch { + label: Option, + break_block: BlockId, + }, +} + +impl Scope { + fn label(&self) -> Option<&str> { + match self { + Scope::Loop { label, .. } => label.as_deref(), + Scope::Label { label, .. } => Some(label.as_str()), + Scope::Switch { label, .. } => label.as_deref(), + } + } + + fn break_block(&self) -> BlockId { + match self { + Scope::Loop { break_block, .. 
} => *break_block, + Scope::Label { break_block, .. } => *break_block, + Scope::Switch { break_block, .. } => *break_block, + } + } +} + +// --------------------------------------------------------------------------- +// WipBlock: a block under construction that does not yet have a terminal +// --------------------------------------------------------------------------- + +pub struct WipBlock { + pub id: BlockId, + pub instructions: Vec, + pub kind: BlockKind, +} + +fn new_block(id: BlockId, kind: BlockKind) -> WipBlock { + WipBlock { + id, + kind, + instructions: Vec::new(), + } +} + +// --------------------------------------------------------------------------- +// HirBuilder: helper struct for constructing a CFG +// --------------------------------------------------------------------------- + +pub struct HirBuilder<'a> { + completed: IndexMap, + current: WipBlock, + entry: BlockId, + scopes: Vec, + /// Context identifiers: variables captured from an outer scope. + /// Maps the outer scope's BindingId to the source location where it was + /// referenced. + context: IndexMap>, + /// Resolved bindings: maps a BindingId to the HIR IdentifierId created for + /// it. + bindings: IndexMap, + /// Names already used by bindings, for collision avoidance. + /// Maps name string -> how many times it has been used (for appending _0, + /// _1, ...). + used_names: IndexMap, + env: &'a mut Environment, + scope_info: &'a ScopeInfo, + exception_handler_stack: Vec, + /// Flat instruction table being built up. + instruction_table: Vec, + /// Traversal context: counts the number of `fbt` tag parents + /// of the current babel node. + pub fbt_depth: u32, + /// The scope of the function being compiled (for context identifier + /// checks). + function_scope: ScopeId, + /// The scope of the outermost component/hook function (for + /// gather_captured_context). 
+ component_scope: ScopeId, + /// Set of BindingIds for variables declared in scopes between + /// component_scope and any inner function scope, that are referenced + /// from an inner function scope. These need StoreContext/LoadContext + /// instead of StoreLocal/LoadLocal. + context_identifiers: std::collections::HashSet, + /// Index mapping identifier byte offsets to source locations and JSX + /// status. + identifier_locs: &'a IdentifierLocIndex, +} + +impl<'a> HirBuilder<'a> { + // ----------------------------------------------------------------------- + // M2: Core methods + // ----------------------------------------------------------------------- + + /// Create a new HirBuilder. + /// + /// - `env`: the shared environment (counters, arenas, error accumulator) + /// - `scope_info`: the scope information from the AST + /// - `function_scope`: the ScopeId of the function being compiled + /// - `bindings`: optional pre-existing bindings (e.g., from a parent + /// function) + /// - `context`: optional pre-existing captured context map + /// - `entry_block_kind`: the kind of the entry block (defaults to `Block`) + pub fn new( + env: &'a mut Environment, + scope_info: &'a ScopeInfo, + function_scope: ScopeId, + component_scope: ScopeId, + context_identifiers: std::collections::HashSet, + bindings: Option>, + context: Option>>, + entry_block_kind: Option, + used_names: Option>, + identifier_locs: &'a IdentifierLocIndex, + ) -> Self { + let entry = env.next_block_id(); + let kind = entry_block_kind.unwrap_or(BlockKind::Block); + HirBuilder { + completed: IndexMap::new(), + current: new_block(entry, kind), + entry, + scopes: Vec::new(), + context: context.unwrap_or_default(), + bindings: bindings.unwrap_or_default(), + used_names: used_names.unwrap_or_default(), + env, + scope_info, + exception_handler_stack: Vec::new(), + instruction_table: Vec::new(), + fbt_depth: 0, + function_scope, + component_scope, + context_identifiers, + identifier_locs, + } + } + + /// 
Access the environment. + pub fn environment(&self) -> &Environment { + self.env + } + + /// Access the environment mutably. + pub fn environment_mut(&mut self) -> &mut Environment { + self.env + } + + /// Create a new unique TypeVar type, allocated from the environment's type + /// arena so that TypeIds are consistent with identifier type slots. + pub fn make_type(&mut self) -> Type { + let type_id = self.env.make_type(); + Type::TypeVar { id: type_id } + } + + /// Access the scope info. + pub fn scope_info(&self) -> &ScopeInfo { + self.scope_info + } + + /// Look up the source location of an identifier by its byte offset. + pub fn get_identifier_loc(&self, offset: u32) -> Option { + self.identifier_locs + .get(&offset) + .map(|entry| entry.loc.clone()) + } + + /// Check whether a byte offset corresponds to a JSXIdentifier node. + pub fn is_jsx_identifier(&self, offset: u32) -> bool { + self.identifier_locs + .get(&offset) + .is_some_and(|entry| entry.is_jsx) + } + + /// Access the function scope (the scope of the function being compiled). + pub fn function_scope(&self) -> ScopeId { + self.function_scope + } + + /// Access the component scope. + pub fn component_scope(&self) -> ScopeId { + self.component_scope + } + + /// Access the context map. + pub fn context(&self) -> &IndexMap> { + &self.context + } + + /// Access the pre-computed context identifiers set. + pub fn context_identifiers(&self) -> &std::collections::HashSet { + &self.context_identifiers + } + + /// Add a binding to the context identifiers set (used by hoisting). + pub fn add_context_identifier(&mut self, binding_id: BindingId) { + self.context_identifiers.insert(binding_id); + } + + /// Access scope_info and environment mutably at the same time. + /// This is safe because they are disjoint fields, but Rust's borrow checker + /// can't prove this through method calls alone. 
+ pub fn scope_info_and_env_mut(&mut self) -> (&ScopeInfo, &mut Environment) { + (self.scope_info, self.env) + } + + /// Access the identifier location index. + /// Returns the 'a reference to avoid conflicts with mutable borrows on + /// self. + pub fn identifier_locs(&self) -> &'a IdentifierLocIndex { + self.identifier_locs + } + + /// Access the bindings map. + pub fn bindings(&self) -> &IndexMap { + &self.bindings + } + + /// Access the used names map. + pub fn used_names(&self) -> &IndexMap { + &self.used_names + } + + /// Merge used names from a child builder back into this builder. + /// This ensures name deduplication works across function scopes. + pub fn merge_used_names(&mut self, child_used_names: IndexMap) { + for (name, binding_id) in child_used_names { + self.used_names.entry(name).or_insert(binding_id); + } + } + + /// Merge bindings (binding_id -> IdentifierId) from a child builder back + /// into this builder. This matches TS behavior where parent and child + /// share the same #bindings map by reference, so bindings resolved by + /// the child are automatically visible to the parent. + pub fn merge_bindings(&mut self, child_bindings: IndexMap) { + for (binding_id, identifier_id) in child_bindings { + self.bindings.entry(binding_id).or_insert(identifier_id); + } + } + + /// Push an instruction onto the current block. + /// + /// Adds the instruction to the flat instruction table and records + /// its InstructionId in the current block's instruction list. + /// + /// If an exception handler is active, also emits a MaybeThrow terminal + /// after the instruction to model potential control flow to the handler, + /// then continues in a new block. 
+ pub fn push(&mut self, instruction: Instruction) { + let loc = instruction.loc.clone(); + let instr_id = InstructionId(self.instruction_table.len() as u32); + self.instruction_table.push(instruction); + self.current.instructions.push(instr_id); + + if let Some(&handler) = self.exception_handler_stack.last() { + let continuation = self.reserve(self.current_block_kind()); + self.terminate_with_continuation( + Terminal::MaybeThrow { + continuation: continuation.id, + handler: Some(handler), + id: EvaluationOrder(0), + loc, + effects: None, + }, + continuation, + ); + } + } + + /// Terminate the current block with the given terminal and start a new + /// block. + /// + /// If `next_block_kind` is `Some`, a new current block is created with that + /// kind. Returns the BlockId of the completed block. + pub fn terminate(&mut self, terminal: Terminal, next_block_kind: Option) -> BlockId { + // The placeholder block created here (BlockId(u32::MAX)) is only used when + // next_block_kind is None, meaning this is the final terminate() call. + // It will never be read or completed because build() consumes self + // immediately after, and no further operations should occur on the builder. + let wip = std::mem::replace( + &mut self.current, + new_block(BlockId(u32::MAX), BlockKind::Block), + ); + let block_id = wip.id; + + self.completed.insert( + block_id, + BasicBlock { + kind: wip.kind, + id: block_id, + instructions: wip.instructions, + terminal, + preds: IndexSet::new(), + phis: Vec::new(), + }, + ); + + if let Some(kind) = next_block_kind { + let next_id = self.env.next_block_id(); + self.current = new_block(next_id, kind); + } + block_id + } + + /// Terminate the current block with the given terminal, and set + /// a previously reserved block as the new current block. 
+ pub fn terminate_with_continuation(&mut self, terminal: Terminal, continuation: WipBlock) { + let wip = std::mem::replace(&mut self.current, continuation); + let block_id = wip.id; + self.completed.insert( + block_id, + BasicBlock { + kind: wip.kind, + id: block_id, + instructions: wip.instructions, + terminal, + preds: IndexSet::new(), + phis: Vec::new(), + }, + ); + } + + /// Reserve a new block so it can be referenced before construction. + /// Use `terminate_with_continuation()` to make it current, or `complete()` + /// to save it directly. + pub fn reserve(&mut self, kind: BlockKind) -> WipBlock { + let id = self.env.next_block_id(); + new_block(id, kind) + } + + /// Save a previously reserved block as completed with the given terminal. + pub fn complete(&mut self, block: WipBlock, terminal: Terminal) { + let block_id = block.id; + self.completed.insert( + block_id, + BasicBlock { + kind: block.kind, + id: block_id, + instructions: block.instructions, + terminal, + preds: IndexSet::new(), + phis: Vec::new(), + }, + ); + } + + /// Sets the given wip block as current, executes the closure to populate + /// it and obtain its terminal, then completes the block and restores the + /// previous current block. + pub fn enter_reserved(&mut self, wip: WipBlock, f: impl FnOnce(&mut Self) -> Terminal) { + let prev = std::mem::replace(&mut self.current, wip); + let terminal = f(self); + let completed_wip = std::mem::replace(&mut self.current, prev); + self.completed.insert( + completed_wip.id, + BasicBlock { + kind: completed_wip.kind, + id: completed_wip.id, + instructions: completed_wip.instructions, + terminal, + preds: IndexSet::new(), + phis: Vec::new(), + }, + ); + } + + /// Like `enter_reserved`, but the closure returns a `Result`. 
+ pub fn try_enter_reserved( + &mut self, + wip: WipBlock, + f: impl FnOnce(&mut Self) -> Result, + ) -> Result<(), CompilerDiagnostic> { + let prev = std::mem::replace(&mut self.current, wip); + let terminal = f(self)?; + let completed_wip = std::mem::replace(&mut self.current, prev); + self.completed.insert( + completed_wip.id, + BasicBlock { + kind: completed_wip.kind, + id: completed_wip.id, + instructions: completed_wip.instructions, + terminal, + preds: IndexSet::new(), + phis: Vec::new(), + }, + ); + Ok(()) + } + + /// Create a new block, set it as current, run the closure to populate it + /// and obtain its terminal, complete the block, and restore the previous + /// current block. Returns the new block's BlockId. + pub fn enter( + &mut self, + kind: BlockKind, + f: impl FnOnce(&mut Self, BlockId) -> Terminal, + ) -> BlockId { + let wip = self.reserve(kind); + let wip_id = wip.id; + self.enter_reserved(wip, |this| f(this, wip_id)); + wip_id + } + + /// Like `enter`, but the closure returns a `Result`. + pub fn try_enter( + &mut self, + kind: BlockKind, + f: impl FnOnce(&mut Self, BlockId) -> Result, + ) -> Result { + let wip = self.reserve(kind); + let wip_id = wip.id; + self.try_enter_reserved(wip, |this| f(this, wip_id))?; + Ok(wip_id) + } + + /// Push an exception handler, run the closure, then pop the handler. + pub fn enter_try_catch(&mut self, handler: BlockId, f: impl FnOnce(&mut Self)) { + self.exception_handler_stack.push(handler); + f(self); + self.exception_handler_stack.pop(); + } + + /// Like `enter_try_catch`, but the closure returns a `Result`. + pub fn try_enter_try_catch( + &mut self, + handler: BlockId, + f: impl FnOnce(&mut Self) -> Result<(), CompilerDiagnostic>, + ) -> Result<(), CompilerDiagnostic> { + self.exception_handler_stack.push(handler); + let result = f(self); + self.exception_handler_stack.pop(); + result + } + + /// Return the top of the exception handler stack, or None. 
+ pub fn resolve_throw_handler(&self) -> Option { + self.exception_handler_stack.last().copied() + } + + /// Push a Loop scope, run the closure, pop and verify. + pub fn loop_scope( + &mut self, + label: Option, + continue_block: BlockId, + break_block: BlockId, + f: impl FnOnce(&mut Self) -> Result, + ) -> Result { + self.scopes.push(Scope::Loop { + label: label.clone(), + continue_block, + break_block, + }); + let value = f(self)?; + let last = self + .scopes + .pop() + .expect("Mismatched loop scope: stack empty"); + match &last { + Scope::Loop { + label: l, + continue_block: c, + break_block: b, + } => { + assert!( + *l == label && *c == continue_block && *b == break_block, + "Mismatched loop scope" + ); + } + _ => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Mismatched loop scope: expected Loop, got other", + None, + )) + } + } + Ok(value) + } + + /// Push a Label scope, run the closure, pop and verify. + pub fn label_scope( + &mut self, + label: String, + break_block: BlockId, + f: impl FnOnce(&mut Self) -> Result, + ) -> Result { + self.scopes.push(Scope::Label { + label: label.clone(), + break_block, + }); + let value = f(self)?; + let last = self + .scopes + .pop() + .expect("Mismatched label scope: stack empty"); + match &last { + Scope::Label { + label: l, + break_block: b, + } => { + assert!(*l == label && *b == break_block, "Mismatched label scope"); + } + _ => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Mismatched label scope: expected Label, got other", + None, + )) + } + } + Ok(value) + } + + /// Push a Switch scope, run the closure, pop and verify. 
+ pub fn switch_scope( + &mut self, + label: Option, + break_block: BlockId, + f: impl FnOnce(&mut Self) -> Result, + ) -> Result { + self.scopes.push(Scope::Switch { + label: label.clone(), + break_block, + }); + let value = f(self)?; + let last = self + .scopes + .pop() + .expect("Mismatched switch scope: stack empty"); + match &last { + Scope::Switch { + label: l, + break_block: b, + } => { + assert!(*l == label && *b == break_block, "Mismatched switch scope"); + } + _ => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Mismatched switch scope: expected Switch, got other", + None, + )) + } + } + Ok(value) + } + + /// Look up the break target for the given label (or the innermost + /// loop/switch if label is None). + pub fn lookup_break(&self, label: Option<&str>) -> Result { + for scope in self.scopes.iter().rev() { + match scope { + Scope::Loop { .. } | Scope::Switch { .. } if label.is_none() => { + return Ok(scope.break_block()); + } + _ if label.is_some() && scope.label() == label => { + return Ok(scope.break_block()); + } + _ => continue, + } + } + Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected a loop or switch to be in scope for break", + None, + )) + } + + /// Look up the continue target for the given label (or the innermost + /// loop if label is None). Only loops support continue. + pub fn lookup_continue(&self, label: Option<&str>) -> Result { + for scope in self.scopes.iter().rev() { + match scope { + Scope::Loop { + label: scope_label, + continue_block, + .. 
+ } => { + if label.is_none() || label == scope_label.as_deref() { + return Ok(*continue_block); + } + } + _ => { + if label.is_some() && scope.label() == label { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Continue may only refer to a labeled loop", + None, + )); + } + } + } + } + Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected a loop to be in scope for continue", + None, + )) + } + + /// Create a temporary identifier with a fresh id, returning its + /// IdentifierId. + pub fn make_temporary(&mut self, loc: Option) -> IdentifierId { + let id = self.env.next_identifier_id(); + // Update the loc on the allocated identifier + self.env.identifiers[id.0 as usize].loc = loc; + id + } + + /// Set the source location for an identifier. + pub fn set_identifier_loc(&mut self, id: IdentifierId, loc: Option) { + self.env.identifiers[id.0 as usize].loc = loc; + } + + /// Record an error on the environment. + /// Returns `Err` for Invariant errors (matching TS throw behavior). + pub fn record_error(&mut self, error: CompilerErrorDetail) -> Result<(), CompilerError> { + self.env.record_error(error) + } + + /// Record a diagnostic on the environment. + pub fn record_diagnostic(&mut self, diagnostic: CompilerDiagnostic) { + self.env.record_diagnostic(diagnostic); + } + + /// Check if a name has a local binding (non-module-level). + /// This is used for checking if fbt/fbs JSX tags are local bindings + /// (which is not supported). Unlike resolve_identifier, this doesn't + /// require a source position. + pub fn has_local_binding(&self, name: &str) -> bool { + // Check used_names to see if this name has been bound locally + if let Some(&binding_id) = self.used_names.get(name) { + // Check that the binding is NOT in the program scope (i.e., it's local) + let binding = &self.scope_info.bindings[binding_id.0 as usize]; + return binding.scope != self.scope_info.program_scope; + } + false + } + + /// Return the kind of the current block. 
+ pub fn current_block_kind(&self) -> BlockKind { + self.current.kind + } + + /// Construct the final HIR and instruction table from the completed blocks. + /// + /// Performs these post-build passes: + /// 1. Reverse-postorder sort + unreachable block removal + /// 2. Check for unreachable blocks containing FunctionExpression + /// instructions + /// 3. Remove unreachable for-loop updates + /// 4. Remove dead do-while statements + /// 5. Remove unnecessary try-catch + /// 6. Number all instructions and terminals + /// 7. Mark predecessor blocks + pub fn build( + mut self, + ) -> Result< + ( + HIR, + Vec, + IndexMap, + IndexMap, + ), + CompilerError, + > { + let mut hir = HIR { + blocks: std::mem::take(&mut self.completed), + entry: self.entry, + }; + + let mut instructions = std::mem::take(&mut self.instruction_table); + + let rpo_blocks = get_reverse_postordered_blocks(&hir, &instructions); + + // Check for unreachable blocks that contain FunctionExpression instructions. + // These could contain hoisted declarations that we can't safely remove. + for (id, block) in &hir.blocks { + if !rpo_blocks.contains_key(id) { + let has_function_expr = block.instructions.iter().any(|&instr_id| { + matches!( + instructions[instr_id.0 as usize].value, + InstructionValue::FunctionExpression { .. 
} + ) + }); + if has_function_expr { + let loc = block + .instructions + .first() + .and_then(|&i| instructions[i.0 as usize].loc.clone()) + .or_else(|| block.terminal.loc().copied()); + self.env.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "Support functions with unreachable code that may contain hoisted \ + declarations" + .to_string(), + description: None, + loc, + suggestions: None, + })?; + } + } + } + + hir.blocks = rpo_blocks; + + remove_unreachable_for_updates(&mut hir); + remove_dead_do_while_statements(&mut hir); + remove_unnecessary_try_catch(&mut hir); + mark_instruction_ids(&mut hir, &mut instructions); + mark_predecessors(&mut hir); + + let used_names = self.used_names; + let bindings = self.bindings; + Ok((hir, instructions, used_names, bindings)) + } + + // ----------------------------------------------------------------------- + // M3: Binding resolution methods + // ----------------------------------------------------------------------- + + /// Map a BindingId to an HIR IdentifierId. + /// + /// On first encounter, creates a new Identifier with the given name and a + /// fresh id. On subsequent encounters, returns the cached IdentifierId. + /// Handles name collisions by appending `_0`, `_1`, etc. + /// + /// Records errors for variables named 'fbt' or 'this'. + pub fn resolve_binding( + &mut self, + name: &str, + binding_id: BindingId, + ) -> Result { + self.resolve_binding_with_loc(name, binding_id, None) + } + + /// Map a BindingId to an HIR IdentifierId, with an optional source + /// location. + pub fn resolve_binding_with_loc( + &mut self, + name: &str, + binding_id: BindingId, + loc: Option, + ) -> Result { + // Check for unsupported names BEFORE the cache check. + // In TS, resolveBinding records fbt errors when node.name === 'fbt'. After a + // name collision causes a rename (e.g., "fbt" -> "fbt_0"), TS's + // scope.rename changes the AST node's name, preventing subsequent fbt + // error recording. 
We simulate this by checking whether the + // resolved name for this binding is still "fbt" (not renamed to "fbt_0" etc.). + if name == "fbt" { + // Check if this binding was previously resolved to a renamed version + let should_record_fbt_error = + if let Some(&identifier_id) = self.bindings.get(&binding_id) { + // Already resolved - check if the resolved name is still "fbt" + match &self.env.identifiers[identifier_id.0 as usize].name { + Some(IdentifierName::Named(resolved_name)) => resolved_name == "fbt", + _ => false, + } + } else { + // First resolution - always record + true + }; + if should_record_fbt_error { + let error_loc = self.scope_info.bindings[binding_id.0 as usize] + .declaration_start + .and_then(|start| self.get_identifier_loc(start)) + .or_else(|| loc.clone()); + self.env.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "Support local variables named `fbt`".to_string(), + description: Some( + "Local variables named `fbt` may conflict with the fbt plugin and are not \ + yet supported" + .to_string(), + ), + loc: error_loc, + suggestions: None, + })?; + } + } + + // If we've already resolved this binding, return the cached IdentifierId + if let Some(&identifier_id) = self.bindings.get(&binding_id) { + return Ok(identifier_id); + } + + if is_reserved_word(name) { + // Match TS behavior: makeIdentifierName throws for reserved words, + // which propagates as a CompileUnexpectedThrow + CompileError. + // Note: this is normally caught earlier in scope.ts, but kept as a safety net. 
+ self.env.record_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Syntax, + "Expected a non-reserved identifier name", + Some(format!( + "`{}` is a reserved word in JavaScript and cannot be used as an \ + identifier name", + name + )), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: None, // GeneratedSource in TS + message: Some("reserved word".to_string()), + identifier_name: None, + }), + ); + } + + // Find a unique name: start with the original name, then try name_0, name_1, + // ... + let mut candidate = name.to_string(); + let mut index = 0u32; + loop { + if let Some(&existing_binding_id) = self.used_names.get(&candidate) { + if existing_binding_id == binding_id { + // Same binding, use this name + break; + } + // Name collision with a different binding, try the next suffix + candidate = format!("{}_{}", name, index); + index += 1; + } else { + // Name is available + break; + } + } + + // Record rename if the candidate differs from the original name + if candidate != name { + let binding = &self.scope_info.bindings[binding_id.0 as usize]; + if let Some(decl_start) = binding.declaration_start { + self.env + .renames + .push(react_compiler_hir::environment::BindingRename { + original: name.to_string(), + renamed: candidate.clone(), + declaration_start: decl_start, + }); + } + } + + // Allocate identifier in the arena + let id = self.env.next_identifier_id(); + // Update the name and loc on the allocated identifier + self.env.identifiers[id.0 as usize].name = Some(IdentifierName::Named(candidate.clone())); + // Prefer the binding's declaration loc over the reference loc. + // This matches TS behavior where Babel's resolveBinding returns the + // binding identifier's original loc (the declaration site). 
+ let binding = &self.scope_info.bindings[binding_id.0 as usize]; + let decl_loc = binding + .declaration_start + .and_then(|start| self.get_identifier_loc(start)); + if let Some(ref dl) = decl_loc { + self.env.identifiers[id.0 as usize].loc = Some(dl.clone()); + } else if let Some(ref loc) = loc { + self.env.identifiers[id.0 as usize].loc = Some(loc.clone()); + } + + self.used_names.insert(candidate, binding_id); + self.bindings.insert(binding_id, id); + Ok(id) + } + + /// Set the loc on an identifier to the declaration-site loc. + /// This overrides any previously-set loc (which may have come from a + /// reference site). + pub fn set_identifier_declaration_loc( + &mut self, + id: IdentifierId, + loc: &Option, + ) { + if let Some(loc_val) = loc { + self.env.identifiers[id.0 as usize].loc = Some(loc_val.clone()); + } + } + + /// Resolve an identifier reference to a VariableBinding. + /// + /// Uses ScopeInfo to determine whether the reference is: + /// - Global (no binding found) + /// - ImportDefault, ImportSpecifier, ImportNamespace (program-scope import + /// binding) + /// - ModuleLocal (program-scope non-import binding) + /// - Identifier (local binding, resolved via resolve_binding) + pub fn resolve_identifier( + &mut self, + name: &str, + start_offset: u32, + loc: Option, + ) -> Result { + let binding_data = self.scope_info.resolve_reference(start_offset); + + match binding_data { + None => { + // No binding found: this is a global + Ok(VariableBinding::Global { + name: name.to_string(), + }) + } + Some(binding) => { + // Treat type-only declarations as globals so the compiler + // doesn't try to create/initialize HIR bindings for them. + // TSEnumDeclaration is included because enums inside function + // bodies are lowered as UnsupportedNode and their binding + // is never initialized in HIR. 
+ if matches!( + binding.declaration_type.as_str(), + "TSTypeAliasDeclaration" + | "TSInterfaceDeclaration" + | "TSEnumDeclaration" + | "TSModuleDeclaration" + ) { + return Ok(VariableBinding::Global { + name: name.to_string(), + }); + } + if binding.scope == self.scope_info.program_scope { + // Module-level binding: check import info + Ok(match &binding.import { + Some(import_info) => match import_info.kind { + ImportBindingKind::Default => VariableBinding::ImportDefault { + name: name.to_string(), + module: import_info.source.clone(), + }, + ImportBindingKind::Named => VariableBinding::ImportSpecifier { + name: name.to_string(), + module: import_info.source.clone(), + imported: import_info + .imported + .clone() + .unwrap_or_else(|| name.to_string()), + }, + ImportBindingKind::Namespace => VariableBinding::ImportNamespace { + name: name.to_string(), + module: import_info.source.clone(), + }, + }, + None => VariableBinding::ModuleLocal { + name: name.to_string(), + }, + }) + } else { + // Local binding: resolve via resolve_binding. + // When the resolved binding's name doesn't match the identifier + // being resolved, fall back to a name-based lookup. This handles + // cases like component-syntax where Flow transforms create multiple + // params with the same start position (e.g., both _$$empty_props_placeholder$$ + // and ref have start=106 after the Flow component transform). + let resolved_binding = if binding.name != name { + self.scope_info + .resolve_reference_by_name(name, start_offset) + .unwrap_or(binding) + } else { + binding + }; + let binding_id = resolved_binding.id; + let binding_kind = crate::convert_binding_kind(&resolved_binding.kind); + let identifier_id = self.resolve_binding_with_loc(name, binding_id, loc)?; + Ok(VariableBinding::Identifier { + identifier: identifier_id, + binding_kind, + }) + } + } + } + } + + /// Check if an identifier reference resolves to a context identifier. 
+ /// + /// A context identifier is a variable declared in an ancestor scope of the + /// current function's scope, but NOT in the program scope itself and NOT + /// in the function's own scope. These are "captured" variables from an + /// enclosing function. + pub fn is_context_identifier(&self, _name: &str, start_offset: u32) -> bool { + let binding = self.scope_info.resolve_reference(start_offset); + + match binding { + None => false, + Some(binding_data) => { + // If in program scope, it's a module-level binding, not context + if binding_data.scope == self.scope_info.program_scope { + return false; + } + + // Check if this binding is in the pre-computed context identifiers set. + self.context_identifiers.contains(&binding_data.id) + } + } + } +} + +// --------------------------------------------------------------------------- +// Post-build helper functions +// --------------------------------------------------------------------------- + +/// Compute a reverse-postorder of blocks reachable from the entry. +/// +/// Visits successors in reverse order so that when the postorder list is +/// reversed, sibling edges appear in program order. +/// +/// Blocks not reachable through successors are removed. Blocks that are +/// only reachable as fallthroughs (not through real successor edges) are +/// replaced with empty blocks that have an Unreachable terminal. 
+pub fn get_reverse_postordered_blocks( + hir: &HIR, + _instructions: &[Instruction], +) -> IndexMap { + let mut visited: IndexSet = IndexSet::new(); + let mut used: IndexSet = IndexSet::new(); + let mut used_fallthroughs: IndexSet = IndexSet::new(); + let mut postorder: Vec = Vec::new(); + + fn visit( + hir: &HIR, + block_id: BlockId, + is_used: bool, + visited: &mut IndexSet, + used: &mut IndexSet, + used_fallthroughs: &mut IndexSet, + postorder: &mut Vec, + ) { + let was_used = used.contains(&block_id); + let was_visited = visited.contains(&block_id); + visited.insert(block_id); + if is_used { + used.insert(block_id); + } + if was_visited && (was_used || !is_used) { + return; + } + + let block = hir + .blocks + .get(&block_id) + .unwrap_or_else(|| panic!("[HIRBuilder] expected block {:?} to exist", block_id)); + + // Visit successors in reverse order so that when we reverse the + // postorder list, sibling edges come out in program order. + let mut successors = each_terminal_successor(&block.terminal); + successors.reverse(); + + let fallthrough = terminal_fallthrough(&block.terminal); + + // Visit fallthrough first (marking as not-yet-used) to ensure its + // block ID is emitted in the correct position. 
+ if let Some(ft) = fallthrough { + if is_used { + used_fallthroughs.insert(ft); + } + visit(hir, ft, false, visited, used, used_fallthroughs, postorder); + } + for successor in successors { + visit( + hir, + successor, + is_used, + visited, + used, + used_fallthroughs, + postorder, + ); + } + + if !was_visited { + postorder.push(block_id); + } + } + + visit( + hir, + hir.entry, + true, + &mut visited, + &mut used, + &mut used_fallthroughs, + &mut postorder, + ); + + let mut blocks = IndexMap::new(); + for block_id in postorder.into_iter().rev() { + let block = hir.blocks.get(&block_id).unwrap(); + if used.contains(&block_id) { + blocks.insert(block_id, block.clone()); + } else if used_fallthroughs.contains(&block_id) { + blocks.insert( + block_id, + BasicBlock { + kind: block.kind, + id: block_id, + instructions: Vec::new(), + terminal: Terminal::Unreachable { + id: block.terminal.evaluation_order(), + loc: block.terminal.loc().copied(), + }, + preds: block.preds.clone(), + phis: Vec::new(), + }, + ); + } + // otherwise this block is unreachable and is dropped + } + + blocks +} + +/// For each block with a `For` terminal whose update block is not in the +/// blocks map, set update to None. +pub fn remove_unreachable_for_updates(hir: &mut HIR) { + let block_ids: IndexSet = hir.blocks.keys().copied().collect(); + for block in hir.blocks.values_mut() { + if let Terminal::For { update, .. } = &mut block.terminal { + if let Some(update_id) = *update { + if !block_ids.contains(&update_id) { + *update = None; + } + } + } + } +} + +/// For each block with a `DoWhile` terminal whose test block is not in +/// the blocks map, replace the terminal with a Goto to the loop block. +pub fn remove_dead_do_while_statements(hir: &mut HIR) { + let block_ids: IndexSet = hir.blocks.keys().copied().collect(); + for block in hir.blocks.values_mut() { + let should_replace = if let Terminal::DoWhile { test, .. 
} = &block.terminal { + !block_ids.contains(test) + } else { + false + }; + if should_replace { + if let Terminal::DoWhile { + loop_block, + id, + loc, + .. + } = std::mem::replace( + &mut block.terminal, + Terminal::Unreachable { + id: EvaluationOrder(0), + loc: None, + }, + ) { + block.terminal = Terminal::Goto { + block: loop_block, + variant: GotoVariant::Break, + id, + loc, + }; + } + } + } +} + +/// For each block with a `Try` terminal whose handler block is not in +/// the blocks map, replace the terminal with a Goto to the try block. +/// +/// Also cleans up the fallthrough block's predecessors if the handler +/// was the only path to it. +pub fn remove_unnecessary_try_catch(hir: &mut HIR) { + let block_ids: IndexSet = hir.blocks.keys().copied().collect(); + + // Collect the blocks that need replacement and their associated data + let replacements: Vec<(BlockId, BlockId, BlockId, BlockId, Option)> = hir + .blocks + .iter() + .filter_map(|(&block_id, block)| { + if let Terminal::Try { + block: try_block, + handler, + fallthrough, + loc, + .. + } = &block.terminal + { + if !block_ids.contains(handler) { + return Some((block_id, *try_block, *handler, *fallthrough, loc.clone())); + } + } + None + }) + .collect(); + + for (block_id, try_block, handler_id, fallthrough_id, loc) in replacements { + // Replace the terminal + if let Some(block) = hir.blocks.get_mut(&block_id) { + block.terminal = Terminal::Goto { + block: try_block, + id: EvaluationOrder(0), + loc, + variant: GotoVariant::Break, + }; + } + + // Clean up fallthrough predecessor info + if let Some(fallthrough) = hir.blocks.get_mut(&fallthrough_id) { + if fallthrough.preds.len() == 1 && fallthrough.preds.contains(&handler_id) { + // The handler was the only predecessor: remove the fallthrough block + hir.blocks.shift_remove(&fallthrough_id); + } else { + fallthrough.preds.shift_remove(&handler_id); + } + } + } +} + +/// Sequentially number all instructions and terminals starting from 1. 
+pub fn mark_instruction_ids(hir: &mut HIR, instructions: &mut [Instruction]) { + let mut order: u32 = 0; + for block in hir.blocks.values_mut() { + for &instr_id in &block.instructions { + order += 1; + instructions[instr_id.0 as usize].id = EvaluationOrder(order); + } + order += 1; + block.terminal.set_evaluation_order(EvaluationOrder(order)); + } +} + +/// DFS from entry, for each successor add the predecessor's id to +/// the successor's preds set. +/// +/// Note: This only visits direct successors (via `each_terminal_successor`), +/// not fallthrough blocks. Fallthrough blocks are reached indirectly via +/// Goto terminals from within branching blocks, matching the TypeScript +/// `markPredecessors` behavior. +pub fn mark_predecessors(hir: &mut HIR) { + // Clear all preds first + for block in hir.blocks.values_mut() { + block.preds.clear(); + } + + let mut visited: IndexSet = IndexSet::new(); + + fn visit( + hir: &mut HIR, + block_id: BlockId, + prev_block_id: Option, + visited: &mut IndexSet, + ) { + // Add predecessor + if let Some(prev_id) = prev_block_id { + if let Some(block) = hir.blocks.get_mut(&block_id) { + block.preds.insert(prev_id); + } else { + return; + } + } + + if visited.contains(&block_id) { + return; + } + visited.insert(block_id); + + // Get successors before mutating + let successors = if let Some(block) = hir.blocks.get(&block_id) { + each_terminal_successor(&block.terminal) + } else { + return; + }; + + for successor in successors { + visit(hir, successor, Some(block_id), visited); + } + } + + visit(hir, hir.entry, None, &mut visited); +} + +// --------------------------------------------------------------------------- +// Public helper functions +// --------------------------------------------------------------------------- + +/// Create a temporary Place with a fresh identifier allocated in the arena. 
+pub fn create_temporary_place(env: &mut Environment, loc: Option) -> Place { + let id = env.next_identifier_id(); + // Update the loc on the allocated identifier + env.identifiers[id.0 as usize].loc = loc; + Place { + identifier: id, + reactive: false, + effect: Effect::Unknown, + loc: None, + } +} diff --git a/crates/react_compiler_lowering/src/identifier_loc_index.rs b/crates/react_compiler_lowering/src/identifier_loc_index.rs new file mode 100644 index 000000000000..c926dfb9b010 --- /dev/null +++ b/crates/react_compiler_lowering/src/identifier_loc_index.rs @@ -0,0 +1,196 @@ +//! Builds an index mapping identifier byte offsets to source locations. +//! +//! Walks the function's AST to collect `(start, SourceLocation, is_jsx)` for +//! every Identifier and JSXIdentifier node. This replaces the `referenceLocs` +//! and `jsxReferencePositions` fields that were previously serialized from JS. + +use std::collections::HashMap; + +use react_compiler_ast::{ + expressions::*, + jsx::{JSXIdentifier, JSXOpeningElement}, + scope::{ScopeId, ScopeInfo}, + statements::FunctionDeclaration, + visitor::{AstWalker, Visitor}, +}; +use react_compiler_hir::SourceLocation; + +use crate::FunctionNode; + +/// Source location and whether the identifier is a JSXIdentifier. +pub struct IdentifierLocEntry { + pub loc: SourceLocation, + pub is_jsx: bool, + /// For JSX identifiers that are the root name of a JSXOpeningElement, + /// stores the JSXOpeningElement's loc (which spans the full tag). + /// This matches the TS behavior where `handleMaybeDependency` receives + /// the JSXOpeningElement path and uses `path.node.loc`. + pub opening_element_loc: Option, + /// True if this identifier is the name of a function/class declaration + /// (not an expression reference). Used by `gather_captured_context` to + /// skip non-expression positions, matching the TS behavior where the + /// Expression visitor doesn't visit declaration names. 
+ pub is_declaration_name: bool, +} + +/// Index mapping byte offset → (SourceLocation, is_jsx) for all Identifier +/// and JSXIdentifier nodes in a function's AST. +pub type IdentifierLocIndex = HashMap; + +struct IdentifierLocVisitor { + index: IdentifierLocIndex, + /// Tracks the current JSXOpeningElement's loc while walking its name. + current_opening_element_loc: Option, +} + +fn convert_loc(loc: &react_compiler_ast::common::SourceLocation) -> SourceLocation { + SourceLocation { + start: react_compiler_hir::Position { + line: loc.start.line, + column: loc.start.column, + index: loc.start.index, + }, + end: react_compiler_hir::Position { + line: loc.end.line, + column: loc.end.column, + index: loc.end.index, + }, + } +} + +impl IdentifierLocVisitor { + fn insert_identifier(&mut self, node: &Identifier, is_declaration_name: bool) { + if let (Some(start), Some(loc)) = (node.base.start, &node.base.loc) { + self.index.insert( + start, + IdentifierLocEntry { + loc: convert_loc(loc), + is_jsx: false, + opening_element_loc: None, + is_declaration_name, + }, + ); + } + } +} + +impl<'ast> Visitor<'ast> for IdentifierLocVisitor { + fn enter_identifier(&mut self, node: &'ast Identifier, _scope_stack: &[ScopeId]) { + self.insert_identifier(node, false); + } + + fn enter_jsx_identifier(&mut self, node: &'ast JSXIdentifier, _scope_stack: &[ScopeId]) { + if let (Some(start), Some(loc)) = (node.base.start, &node.base.loc) { + self.index.insert( + start, + IdentifierLocEntry { + loc: convert_loc(loc), + is_jsx: true, + opening_element_loc: self.current_opening_element_loc.clone(), + is_declaration_name: false, + }, + ); + } + } + + fn enter_jsx_opening_element( + &mut self, + node: &'ast JSXOpeningElement, + _scope_stack: &[ScopeId], + ) { + self.current_opening_element_loc = node.base.loc.as_ref().map(|loc| convert_loc(loc)); + } + + fn leave_jsx_opening_element( + &mut self, + _node: &'ast JSXOpeningElement, + _scope_stack: &[ScopeId], + ) { + self.current_opening_element_loc 
= None; + } + + // Visit function/class declaration and expression name identifiers, + // which are not walked by the generic walker (to avoid affecting + // other Visitor consumers like find_context_identifiers). + fn enter_function_declaration( + &mut self, + node: &'ast FunctionDeclaration, + _scope_stack: &[ScopeId], + ) { + if let Some(id) = &node.id { + self.insert_identifier(id, true); + } + } + + fn enter_function_expression( + &mut self, + node: &'ast FunctionExpression, + _scope_stack: &[ScopeId], + ) { + if let Some(id) = &node.id { + self.insert_identifier(id, true); + } + } +} + +/// Build an index of all Identifier and JSXIdentifier positions in a function's +/// AST. +pub fn build_identifier_loc_index( + func: &FunctionNode<'_>, + scope_info: &ScopeInfo, +) -> IdentifierLocIndex { + let func_start = match func { + FunctionNode::FunctionDeclaration(d) => d.base.start.unwrap_or(0), + FunctionNode::FunctionExpression(e) => e.base.start.unwrap_or(0), + FunctionNode::ArrowFunctionExpression(a) => a.base.start.unwrap_or(0), + }; + let func_scope = scope_info + .node_to_scope + .get(&func_start) + .copied() + .unwrap_or(scope_info.program_scope); + + let mut visitor = IdentifierLocVisitor { + index: HashMap::new(), + current_opening_element_loc: None, + }; + let mut walker = AstWalker::with_initial_scope(scope_info, func_scope); + + // Visit the top-level function's own name identifier (if any), + // since the walker only walks params + body, not the function node itself. 
+ match func { + FunctionNode::FunctionDeclaration(d) => { + if let Some(id) = &d.id { + visitor.enter_identifier(id, &[]); + } + for param in &d.params { + walker.walk_pattern(&mut visitor, param); + } + walker.walk_block_statement(&mut visitor, &d.body); + } + FunctionNode::FunctionExpression(e) => { + if let Some(id) = &e.id { + visitor.enter_identifier(id, &[]); + } + for param in &e.params { + walker.walk_pattern(&mut visitor, param); + } + walker.walk_block_statement(&mut visitor, &e.body); + } + FunctionNode::ArrowFunctionExpression(a) => { + for param in &a.params { + walker.walk_pattern(&mut visitor, param); + } + match a.body.as_ref() { + ArrowFunctionBody::BlockStatement(block) => { + walker.walk_block_statement(&mut visitor, block); + } + ArrowFunctionBody::Expression(expr) => { + walker.walk_expression(&mut visitor, expr); + } + } + } + } + + visitor.index +} diff --git a/crates/react_compiler_lowering/src/lib.rs b/crates/react_compiler_lowering/src/lib.rs new file mode 100644 index 000000000000..33bbb9ff1570 --- /dev/null +++ b/crates/react_compiler_lowering/src/lib.rs @@ -0,0 +1,46 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. +#![allow(clippy::all)] + +pub mod build_hir; +pub mod find_context_identifiers; +pub mod hir_builder; +pub mod identifier_loc_index; + +use react_compiler_ast::{ + expressions::{ArrowFunctionExpression, FunctionExpression}, + statements::FunctionDeclaration, +}; +use react_compiler_hir::BindingKind; + +/// Convert AST binding kind to HIR binding kind. 
+pub fn convert_binding_kind(kind: &react_compiler_ast::scope::BindingKind) -> BindingKind { + match kind { + react_compiler_ast::scope::BindingKind::Var => BindingKind::Var, + react_compiler_ast::scope::BindingKind::Let => BindingKind::Let, + react_compiler_ast::scope::BindingKind::Const => BindingKind::Const, + react_compiler_ast::scope::BindingKind::Param => BindingKind::Param, + react_compiler_ast::scope::BindingKind::Module => BindingKind::Module, + react_compiler_ast::scope::BindingKind::Hoisted => BindingKind::Hoisted, + react_compiler_ast::scope::BindingKind::Local => BindingKind::Local, + react_compiler_ast::scope::BindingKind::Unknown => BindingKind::Unknown, + } +} + +/// Represents a reference to a function AST node for lowering. +/// Analogous to TS's `NodePath` / `BabelFn`. +pub enum FunctionNode<'a> { + FunctionDeclaration(&'a FunctionDeclaration), + FunctionExpression(&'a FunctionExpression), + ArrowFunctionExpression(&'a ArrowFunctionExpression), +} + +// The main lower() function - delegates to build_hir +pub use build_hir::lower; +// Re-export post-build helper functions used by optimization passes +pub use hir_builder::{ + create_temporary_place, get_reverse_postordered_blocks, mark_instruction_ids, + mark_predecessors, remove_dead_do_while_statements, remove_unnecessary_try_catch, + remove_unreachable_for_updates, +}; +pub use react_compiler_hir::visitors::{each_terminal_successor, terminal_fallthrough}; diff --git a/crates/react_compiler_optimization/Cargo.toml b/crates/react_compiler_optimization/Cargo.toml new file mode 100644 index 000000000000..3050549f42e2 --- /dev/null +++ b/crates/react_compiler_optimization/Cargo.toml @@ -0,0 +1,14 @@ +[package] +description = "Vendored React Compiler optimization passes from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_optimization" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +react_compiler_diagnostics = { 
path = "../react_compiler_diagnostics" } +react_compiler_hir = { path = "../react_compiler_hir" } +react_compiler_lowering = { path = "../react_compiler_lowering" } +react_compiler_ssa = { path = "../react_compiler_ssa" } +indexmap = { workspace = true } diff --git a/crates/react_compiler_optimization/src/constant_propagation.rs b/crates/react_compiler_optimization/src/constant_propagation.rs new file mode 100644 index 000000000000..8963e85c8355 --- /dev/null +++ b/crates/react_compiler_optimization/src/constant_propagation.rs @@ -0,0 +1,1137 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Constant propagation/folding pass. +//! +//! Applies Sparse Conditional Constant Propagation to the given function. +//! We use abstract interpretation to record known constant values for +//! identifiers, with lack of a value indicating that the identifier does not +//! have a known constant value. +//! +//! Instructions which can be compile-time evaluated *and* whose operands are +//! known constants are replaced with the resulting constant value. +//! +//! This pass also exploits SSA form, tracking constant values of local +//! variables. For example, in `let x = 4; let y = x + 1` we know that `x = 4` +//! in the binary expression and can replace it with `Constant 5`. +//! +//! This pass also visits conditionals (currently only IfTerminal) and can prune +//! unreachable branches when the condition is a known truthy/falsey constant. +//! The pass uses fixpoint iteration, looping until no additional updates can be +//! performed. +//! +//! Analogous to TS `Optimization/ConstantPropagation.ts`. 
+ +use std::collections::HashMap; + +use react_compiler_hir::{ + environment::Environment, BinaryOperator, BlockKind, FloatValue, FunctionId, GotoVariant, + HirFunction, IdentifierId, InstructionValue, NonLocalBinding, Phi, Place, PrimitiveValue, + PropertyLiteral, SourceLocation, Terminal, UnaryOperator, UpdateOperator, +}; +use react_compiler_lowering::{ + get_reverse_postordered_blocks, mark_instruction_ids, mark_predecessors, + remove_dead_do_while_statements, remove_unnecessary_try_catch, remove_unreachable_for_updates, +}; +use react_compiler_ssa::enter_ssa::placeholder_function; + +use crate::merge_consecutive_blocks::merge_consecutive_blocks; + +// ============================================================================= +// Constant type — mirrors TS `type Constant = Primitive | LoadGlobal` +// The loc is preserved so that when we replace an instruction value with the +// constant, we use the loc from the original definition site (matching TS). +// ============================================================================= + +#[derive(Debug, Clone)] +enum Constant { + Primitive { + value: PrimitiveValue, + loc: Option, + }, + LoadGlobal { + binding: NonLocalBinding, + loc: Option, + }, +} + +impl Constant { + fn into_instruction_value(self) -> InstructionValue { + match self { + Constant::Primitive { value, loc } => InstructionValue::Primitive { value, loc }, + Constant::LoadGlobal { binding, loc } => InstructionValue::LoadGlobal { binding, loc }, + } + } +} + +/// Map of known constant values. Uses HashMap (not IndexMap) since iteration +/// order does not affect correctness — this map is only used for lookups. 
+type Constants = HashMap; + +// ============================================================================= +// Public entry point +// ============================================================================= + +pub fn constant_propagation(func: &mut HirFunction, env: &mut Environment) { + let mut constants: Constants = HashMap::new(); + constant_propagation_impl(func, env, &mut constants); +} + +fn constant_propagation_impl( + func: &mut HirFunction, + env: &mut Environment, + constants: &mut Constants, +) { + loop { + let have_terminals_changed = apply_constant_propagation(func, env, constants); + if !have_terminals_changed { + break; + } + /* + * If terminals have changed then blocks may have become newly unreachable. + * Re-run minification of the graph (incl reordering instruction ids) + */ + func.body.blocks = get_reverse_postordered_blocks(&func.body, &func.instructions); + remove_unreachable_for_updates(&mut func.body); + remove_dead_do_while_statements(&mut func.body); + remove_unnecessary_try_catch(&mut func.body); + mark_instruction_ids(&mut func.body, &mut func.instructions); + mark_predecessors(&mut func.body); + + // Now that predecessors are updated, prune phi operands that can never be + // reached + for (_block_id, block) in func.body.blocks.iter_mut() { + for phi in &mut block.phis { + phi.operands + .retain(|pred, _operand| block.preds.contains(pred)); + } + } + + /* + * By removing some phi operands, there may be phis that were not previously + * redundant but now are + */ + react_compiler_ssa::eliminate_redundant_phi(func, env); + + /* + * Finally, merge together any blocks that are now guaranteed to execute + * consecutively + */ + merge_consecutive_blocks(func, &mut env.functions); + + // TODO: port assertConsistentIdentifiers(fn) and + // assertTerminalSuccessorsExist(fn) from TS HIR validation. + // These are debug assertions that verify structural invariants + // after the CFG cleanup helpers run. 
+ } +} + +fn apply_constant_propagation( + func: &mut HirFunction, + env: &mut Environment, + constants: &mut Constants, +) -> bool { + let mut has_changes = false; + + let block_ids: Vec<_> = func.body.blocks.keys().copied().collect(); + for block_id in block_ids { + let block = &func.body.blocks[&block_id]; + + // Initialize phi values if all operands have the same known constant value + let phi_updates: Vec<(IdentifierId, Constant)> = block + .phis + .iter() + .filter_map(|phi| { + let value = evaluate_phi(phi, constants)?; + Some((phi.place.identifier, value)) + }) + .collect(); + for (id, value) in phi_updates { + constants.insert(id, value); + } + + let block = &func.body.blocks[&block_id]; + let instr_ids = block.instructions.clone(); + let block_kind = block.kind; + let instr_count = instr_ids.len(); + + for (i, instr_id) in instr_ids.iter().enumerate() { + if block_kind == BlockKind::Sequence && i == instr_count - 1 { + /* + * evaluating the last value of a value block can break order of evaluation, + * skip these instructions + */ + continue; + } + let result = evaluate_instruction(constants, func, env, *instr_id); + if let Some(value) = result { + let lvalue_id = func.instructions[instr_id.0 as usize].lvalue.identifier; + constants.insert(lvalue_id, value); + } + } + + let block = &func.body.blocks[&block_id]; + match &block.terminal { + Terminal::If { + test, + consequent, + alternate, + id, + loc, + .. + } => { + let test_value = read(constants, test); + if let Some(Constant::Primitive { + value: ref prim, .. + }) = test_value + { + has_changes = true; + let target_block_id = if is_truthy(prim) { + *consequent + } else { + *alternate + }; + let terminal = Terminal::Goto { + variant: GotoVariant::Break, + block: target_block_id, + id: *id, + loc: *loc, + }; + func.body.blocks.get_mut(&block_id).unwrap().terminal = terminal; + } + } + Terminal::Unsupported { .. } + | Terminal::Unreachable { .. } + | Terminal::Throw { .. } + | Terminal::Return { .. 
} + | Terminal::Goto { .. } + | Terminal::Branch { .. } + | Terminal::Switch { .. } + | Terminal::DoWhile { .. } + | Terminal::While { .. } + | Terminal::For { .. } + | Terminal::ForOf { .. } + | Terminal::ForIn { .. } + | Terminal::Logical { .. } + | Terminal::Ternary { .. } + | Terminal::Optional { .. } + | Terminal::Label { .. } + | Terminal::Sequence { .. } + | Terminal::MaybeThrow { .. } + | Terminal::Try { .. } + | Terminal::Scope { .. } + | Terminal::PrunedScope { .. } => { + // no-op + } + } + } + + has_changes +} + +// ============================================================================= +// Phi evaluation +// ============================================================================= + +fn evaluate_phi(phi: &Phi, constants: &Constants) -> Option { + let mut value: Option = None; + for (_pred, operand) in &phi.operands { + let operand_value = constants.get(&operand.identifier)?; + + match &value { + None => { + // first iteration of the loop + value = Some(operand_value.clone()); + continue; + } + Some(current) => match (current, operand_value) { + (Constant::Primitive { value: a, .. }, Constant::Primitive { value: b, .. }) => { + // Use JS strict equality semantics: NaN !== NaN + if !js_strict_equal(a, b) { + return None; + } + } + ( + Constant::LoadGlobal { binding: a, .. }, + Constant::LoadGlobal { binding: b, .. }, + ) => { + // different global values, can't constant propagate + if a.name() != b.name() { + return None; + } + } + // found different kinds of constants, can't constant propagate + (Constant::Primitive { .. }, Constant::LoadGlobal { .. }) + | (Constant::LoadGlobal { .. }, Constant::Primitive { .. 
}) => { + return None; + } + }, + } + } + value +} + +// ============================================================================= +// Instruction evaluation +// ============================================================================= + +fn evaluate_instruction( + constants: &mut Constants, + func: &mut HirFunction, + env: &mut Environment, + instr_id: react_compiler_hir::InstructionId, +) -> Option { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::Primitive { value, loc } => Some(Constant::Primitive { + value: value.clone(), + loc: *loc, + }), + InstructionValue::LoadGlobal { binding, loc } => Some(Constant::LoadGlobal { + binding: binding.clone(), + loc: *loc, + }), + InstructionValue::ComputedLoad { + object, + property, + loc, + } => { + let prop_value = read(constants, property); + if let Some(Constant::Primitive { + value: ref prim, .. + }) = prop_value + { + match prim { + PrimitiveValue::String(s) if is_valid_identifier(s) => { + let object = object.clone(); + let loc = *loc; + let new_property = PropertyLiteral::String(s.clone()); + func.instructions[instr_id.0 as usize].value = + InstructionValue::PropertyLoad { + object, + property: new_property, + loc, + }; + } + PrimitiveValue::Number(n) => { + let object = object.clone(); + let loc = *loc; + let new_property = PropertyLiteral::Number(*n); + func.instructions[instr_id.0 as usize].value = + InstructionValue::PropertyLoad { + object, + property: new_property, + loc, + }; + } + PrimitiveValue::Null + | PrimitiveValue::Undefined + | PrimitiveValue::Boolean(_) + | PrimitiveValue::String(_) => {} + } + } + None + } + InstructionValue::ComputedStore { + object, + property, + value, + loc, + } => { + let prop_value = read(constants, property); + if let Some(Constant::Primitive { + value: ref prim, .. 
+ }) = prop_value + { + match prim { + PrimitiveValue::String(s) if is_valid_identifier(s) => { + let object = object.clone(); + let store_value = value.clone(); + let loc = *loc; + let new_property = PropertyLiteral::String(s.clone()); + func.instructions[instr_id.0 as usize].value = + InstructionValue::PropertyStore { + object, + property: new_property, + value: store_value, + loc, + }; + } + PrimitiveValue::Number(n) => { + let object = object.clone(); + let store_value = value.clone(); + let loc = *loc; + let new_property = PropertyLiteral::Number(*n); + func.instructions[instr_id.0 as usize].value = + InstructionValue::PropertyStore { + object, + property: new_property, + value: store_value, + loc, + }; + } + PrimitiveValue::Null + | PrimitiveValue::Undefined + | PrimitiveValue::Boolean(_) + | PrimitiveValue::String(_) => {} + } + } + None + } + InstructionValue::PostfixUpdate { + lvalue, + operation, + value, + loc, + } => { + let previous = read(constants, value); + if let Some(Constant::Primitive { + value: PrimitiveValue::Number(n), + loc: prev_loc, + }) = previous + { + let prev_val = n.value(); + let next_val = match operation { + UpdateOperator::Increment => prev_val + 1.0, + UpdateOperator::Decrement => prev_val - 1.0, + }; + // Store the updated value for the lvalue + let lvalue_id = lvalue.identifier; + constants.insert( + lvalue_id, + Constant::Primitive { + value: PrimitiveValue::Number(FloatValue::new(next_val)), + loc: *loc, + }, + ); + // But return the value prior to the update (preserving its original loc) + return Some(Constant::Primitive { + value: PrimitiveValue::Number(n), + loc: prev_loc, + }); + } + None + } + InstructionValue::PrefixUpdate { + lvalue, + operation, + value, + loc, + } => { + let previous = read(constants, value); + if let Some(Constant::Primitive { + value: PrimitiveValue::Number(n), + .. 
+ }) = previous + { + let prev_val = n.value(); + let next_val = match operation { + UpdateOperator::Increment => prev_val + 1.0, + UpdateOperator::Decrement => prev_val - 1.0, + }; + let result = Constant::Primitive { + value: PrimitiveValue::Number(FloatValue::new(next_val)), + loc: *loc, + }; + // Store and return the updated value + let lvalue_id = lvalue.identifier; + constants.insert(lvalue_id, result.clone()); + return Some(result); + } + None + } + InstructionValue::UnaryExpression { + operator, + value, + loc, + } => match operator { + UnaryOperator::Not => { + let operand = read(constants, value); + if let Some(Constant::Primitive { + value: ref prim, .. + }) = operand + { + let negated = !is_truthy(prim); + let loc = *loc; + let result = Constant::Primitive { + value: PrimitiveValue::Boolean(negated), + loc, + }; + func.instructions[instr_id.0 as usize].value = InstructionValue::Primitive { + value: PrimitiveValue::Boolean(negated), + loc, + }; + return Some(result); + } + None + } + UnaryOperator::Minus => { + let operand = read(constants, value); + if let Some(Constant::Primitive { + value: PrimitiveValue::Number(n), + .. + }) = operand + { + let negated = n.value() * -1.0; + let loc = *loc; + let result = Constant::Primitive { + value: PrimitiveValue::Number(FloatValue::new(negated)), + loc, + }; + func.instructions[instr_id.0 as usize].value = InstructionValue::Primitive { + value: PrimitiveValue::Number(FloatValue::new(negated)), + loc, + }; + return Some(result); + } + None + } + UnaryOperator::Plus + | UnaryOperator::BitwiseNot + | UnaryOperator::TypeOf + | UnaryOperator::Void => None, + }, + InstructionValue::BinaryExpression { + operator, + left, + right, + loc, + } => { + let lhs_value = read(constants, left); + let rhs_value = read(constants, right); + if let ( + Some(Constant::Primitive { value: lhs, .. }), + Some(Constant::Primitive { value: rhs, .. 
}), + ) = (&lhs_value, &rhs_value) + { + let result = evaluate_binary_op(*operator, lhs, rhs); + if let Some(ref prim) = result { + let loc = *loc; + func.instructions[instr_id.0 as usize].value = InstructionValue::Primitive { + value: prim.clone(), + loc, + }; + return Some(Constant::Primitive { + value: prim.clone(), + loc, + }); + } + } + None + } + InstructionValue::PropertyLoad { + object, + property, + loc, + } => { + let object_value = read(constants, object); + if let Some(Constant::Primitive { + value: PrimitiveValue::String(ref s), + .. + }) = object_value + { + if let PropertyLiteral::String(prop_name) = property { + if prop_name == "length" { + // Use UTF-16 code unit count to match JS .length semantics + let len = s.encode_utf16().count() as f64; + let loc = *loc; + let result = Constant::Primitive { + value: PrimitiveValue::Number(FloatValue::new(len)), + loc, + }; + func.instructions[instr_id.0 as usize].value = + InstructionValue::Primitive { + value: PrimitiveValue::Number(FloatValue::new(len)), + loc, + }; + return Some(result); + } + } + } + None + } + InstructionValue::TemplateLiteral { + subexprs, + quasis, + loc, + } => { + if subexprs.is_empty() { + // No subexpressions: join all cooked quasis + let mut result_string = String::new(); + for q in quasis { + match &q.cooked { + Some(cooked) => result_string.push_str(cooked), + None => return None, + } + } + let loc = *loc; + let result = Constant::Primitive { + value: PrimitiveValue::String(result_string.clone()), + loc, + }; + func.instructions[instr_id.0 as usize].value = InstructionValue::Primitive { + value: PrimitiveValue::String(result_string), + loc, + }; + return Some(result); + } + + if subexprs.len() != quasis.len() - 1 { + return None; + } + + if quasis.iter().any(|q| q.cooked.is_none()) { + return None; + } + + let mut quasi_index = 0usize; + let mut result_string = quasis[quasi_index].cooked.as_ref().unwrap().clone(); + quasi_index += 1; + + for sub_expr in subexprs { + let 
sub_expr_value = read(constants, sub_expr); + let sub_prim = match sub_expr_value { + Some(Constant::Primitive { ref value, .. }) => value, + _ => return None, + }; + + let expression_str = match sub_prim { + PrimitiveValue::Null => "null".to_string(), + PrimitiveValue::Boolean(b) => b.to_string(), + PrimitiveValue::Number(n) => js_number_to_string(n.value()), + PrimitiveValue::String(s) => s.clone(), + // TS rejects undefined subexpression values + PrimitiveValue::Undefined => return None, + }; + + let suffix = match &quasis[quasi_index].cooked { + Some(s) => s.clone(), + None => return None, + }; + quasi_index += 1; + + result_string.push_str(&expression_str); + result_string.push_str(&suffix); + } + + let loc = *loc; + let result = Constant::Primitive { + value: PrimitiveValue::String(result_string.clone()), + loc, + }; + func.instructions[instr_id.0 as usize].value = InstructionValue::Primitive { + value: PrimitiveValue::String(result_string), + loc, + }; + Some(result) + } + InstructionValue::LoadLocal { place, .. } => { + let place_value = read(constants, place); + if let Some(ref constant) = place_value { + // Replace the LoadLocal with the constant value (including the constant's + // original loc) + func.instructions[instr_id.0 as usize].value = + constant.clone().into_instruction_value(); + } + place_value + } + InstructionValue::StoreLocal { lvalue, value, .. } => { + let place_value = read(constants, value); + if let Some(ref constant) = place_value { + let lvalue_id = lvalue.place.identifier; + constants.insert(lvalue_id, constant.clone()); + } + place_value + } + InstructionValue::FunctionExpression { lowered_func, .. } => { + let func_id = lowered_func.func; + process_inner_function(func_id, env, constants); + None + } + InstructionValue::ObjectMethod { lowered_func, .. } => { + let func_id = lowered_func.func; + process_inner_function(func_id, env, constants); + None + } + InstructionValue::StartMemoize { deps, .. 
} => { + if let Some(deps) = deps { + // Two-phase: collect which deps are constant, then mutate + let const_dep_indices: Vec = deps + .iter() + .enumerate() + .filter_map(|(i, dep)| { + if let react_compiler_hir::ManualMemoDependencyRoot::NamedLocal { + value, + .. + } = &dep.root + { + let pv = read(constants, value); + if matches!(pv, Some(Constant::Primitive { .. })) { + return Some(i); + } + } + None + }) + .collect(); + for idx in const_dep_indices { + if let InstructionValue::StartMemoize { + deps: Some(ref mut deps), + .. + } = func.instructions[instr_id.0 as usize].value + { + if let react_compiler_hir::ManualMemoDependencyRoot::NamedLocal { + constant, + .. + } = &mut deps[idx].root + { + *constant = true; + } + } + } + } + None + } + // All other instruction kinds: no constant folding + InstructionValue::LoadContext { .. } + | InstructionValue::DeclareLocal { .. } + | InstructionValue::DeclareContext { .. } + | InstructionValue::StoreContext { .. } + | InstructionValue::Destructure { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::NewExpression { .. } + | InstructionValue::CallExpression { .. } + | InstructionValue::MethodCall { .. } + | InstructionValue::TypeCastExpression { .. } + | InstructionValue::JsxExpression { .. } + | InstructionValue::ObjectExpression { .. } + | InstructionValue::ArrayExpression { .. } + | InstructionValue::JsxFragment { .. } + | InstructionValue::RegExpLiteral { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::PropertyStore { .. } + | InstructionValue::PropertyDelete { .. } + | InstructionValue::ComputedDelete { .. } + | InstructionValue::StoreGlobal { .. } + | InstructionValue::TaggedTemplateExpression { .. } + | InstructionValue::Await { .. } + | InstructionValue::GetIterator { .. } + | InstructionValue::IteratorNext { .. } + | InstructionValue::NextPropertyOf { .. } + | InstructionValue::Debugger { .. } + | InstructionValue::FinishMemoize { .. } + | InstructionValue::UnsupportedNode { .. 
} => None, + } +} + +// ============================================================================= +// Inner function processing +// ============================================================================= + +fn process_inner_function(func_id: FunctionId, env: &mut Environment, constants: &mut Constants) { + let mut inner = std::mem::replace( + &mut env.functions[func_id.0 as usize], + placeholder_function(), + ); + constant_propagation_impl(&mut inner, env, constants); + env.functions[func_id.0 as usize] = inner; +} + +// ============================================================================= +// Helper: read constant for a place +// ============================================================================= + +fn read(constants: &Constants, place: &Place) -> Option { + constants.get(&place.identifier).cloned() +} + +// ============================================================================= +// Helper: is_valid_identifier +// ============================================================================= + +/// Check if a string is a valid JavaScript identifier. +/// Supports Unicode identifier characters per ECMAScript spec (ID_Start / +/// ID_Continue). Rejects JS reserved words (matching Babel's +/// `isValidIdentifier` default behavior). +fn is_valid_identifier(s: &str) -> bool { + if s.is_empty() { + return false; + } + let mut chars = s.chars(); + match chars.next() { + Some(c) if is_id_start(c) => {} + _ => return false, + } + if !chars.all(is_id_continue) { + return false; + } + !is_reserved_word(s) +} + +/// JS reserved words that cannot be used as identifiers. +/// Includes keywords, future reserved words, and strict mode reserved words. 
+fn is_reserved_word(s: &str) -> bool { + matches!( + s, + "break" + | "case" + | "catch" + | "continue" + | "debugger" + | "default" + | "do" + | "else" + | "finally" + | "for" + | "function" + | "if" + | "in" + | "instanceof" + | "new" + | "return" + | "switch" + | "this" + | "throw" + | "try" + | "typeof" + | "var" + | "void" + | "while" + | "with" + | "class" + | "const" + | "enum" + | "export" + | "extends" + | "import" + | "super" + | "implements" + | "interface" + | "let" + | "package" + | "private" + | "protected" + | "public" + | "static" + | "yield" + | "await" + | "delete" + | "null" + | "true" + | "false" + ) +} + +/// Check if a character is valid as the start of a JS identifier (ID_Start + _ +/// + $). +fn is_id_start(c: char) -> bool { + c == '_' || c == '$' || c.is_alphabetic() +} + +/// Check if a character is valid as a continuation of a JS identifier +/// (ID_Continue + $ + \u200C + \u200D). +fn is_id_continue(c: char) -> bool { + c == '$' + || c == '_' + || c.is_alphanumeric() + || c == '\u{200C}' // ZWNJ + || c == '\u{200D}' // ZWJ +} + +// ============================================================================= +// Helper: is_truthy for PrimitiveValue +// ============================================================================= + +fn is_truthy(value: &PrimitiveValue) -> bool { + match value { + PrimitiveValue::Null => false, + PrimitiveValue::Undefined => false, + PrimitiveValue::Boolean(b) => *b, + PrimitiveValue::Number(n) => { + let v = n.value(); + v != 0.0 && !v.is_nan() + } + PrimitiveValue::String(s) => !s.is_empty(), + } +} + +// ============================================================================= +// Binary operation evaluation +// ============================================================================= + +fn evaluate_binary_op( + operator: BinaryOperator, + lhs: &PrimitiveValue, + rhs: &PrimitiveValue, +) -> Option { + match operator { + BinaryOperator::Add => match (lhs, rhs) { + (PrimitiveValue::Number(l), 
PrimitiveValue::Number(r)) => Some(PrimitiveValue::Number( + FloatValue::new(l.value() + r.value()), + )), + (PrimitiveValue::String(l), PrimitiveValue::String(r)) => { + let mut s = l.clone(); + s.push_str(r); + Some(PrimitiveValue::String(s)) + } + _ => None, + }, + BinaryOperator::Subtract => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => Some(PrimitiveValue::Number( + FloatValue::new(l.value() - r.value()), + )), + _ => None, + }, + BinaryOperator::Multiply => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => Some(PrimitiveValue::Number( + FloatValue::new(l.value() * r.value()), + )), + _ => None, + }, + BinaryOperator::Divide => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => Some(PrimitiveValue::Number( + FloatValue::new(l.value() / r.value()), + )), + _ => None, + }, + BinaryOperator::Modulo => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => Some(PrimitiveValue::Number( + FloatValue::new(l.value() % r.value()), + )), + _ => None, + }, + BinaryOperator::Exponent => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => Some(PrimitiveValue::Number( + FloatValue::new(l.value().powf(r.value())), + )), + _ => None, + }, + BinaryOperator::BitwiseOr => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + let result = js_to_int32(l.value()) | js_to_int32(r.value()); + Some(PrimitiveValue::Number(FloatValue::new(result as f64))) + } + _ => None, + }, + BinaryOperator::BitwiseAnd => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + let result = js_to_int32(l.value()) & js_to_int32(r.value()); + Some(PrimitiveValue::Number(FloatValue::new(result as f64))) + } + _ => None, + }, + BinaryOperator::BitwiseXor => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + let result = js_to_int32(l.value()) ^ js_to_int32(r.value()); + 
Some(PrimitiveValue::Number(FloatValue::new(result as f64))) + } + _ => None, + }, + BinaryOperator::ShiftLeft => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + let result = js_to_int32(l.value()) << (js_to_uint32(r.value()) & 0x1f); + Some(PrimitiveValue::Number(FloatValue::new(result as f64))) + } + _ => None, + }, + BinaryOperator::ShiftRight => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + let result = js_to_int32(l.value()) >> (js_to_uint32(r.value()) & 0x1f); + Some(PrimitiveValue::Number(FloatValue::new(result as f64))) + } + _ => None, + }, + BinaryOperator::UnsignedShiftRight => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + let result = js_to_uint32(l.value()) >> (js_to_uint32(r.value()) & 0x1f); + Some(PrimitiveValue::Number(FloatValue::new(result as f64))) + } + _ => None, + }, + BinaryOperator::LessThan => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + Some(PrimitiveValue::Boolean(l.value() < r.value())) + } + _ => None, + }, + BinaryOperator::LessEqual => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + Some(PrimitiveValue::Boolean(l.value() <= r.value())) + } + _ => None, + }, + BinaryOperator::GreaterThan => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + Some(PrimitiveValue::Boolean(l.value() > r.value())) + } + _ => None, + }, + BinaryOperator::GreaterEqual => match (lhs, rhs) { + (PrimitiveValue::Number(l), PrimitiveValue::Number(r)) => { + Some(PrimitiveValue::Boolean(l.value() >= r.value())) + } + _ => None, + }, + BinaryOperator::StrictEqual => Some(PrimitiveValue::Boolean(js_strict_equal(lhs, rhs))), + BinaryOperator::StrictNotEqual => Some(PrimitiveValue::Boolean(!js_strict_equal(lhs, rhs))), + BinaryOperator::Equal => Some(PrimitiveValue::Boolean(js_abstract_equal(lhs, rhs))), + BinaryOperator::NotEqual => 
Some(PrimitiveValue::Boolean(!js_abstract_equal(lhs, rhs))), + BinaryOperator::In | BinaryOperator::InstanceOf => None, + } +} + +// ============================================================================= +// JavaScript equality semantics +// ============================================================================= + +fn js_strict_equal(lhs: &PrimitiveValue, rhs: &PrimitiveValue) -> bool { + match (lhs, rhs) { + (PrimitiveValue::Null, PrimitiveValue::Null) => true, + (PrimitiveValue::Undefined, PrimitiveValue::Undefined) => true, + (PrimitiveValue::Boolean(a), PrimitiveValue::Boolean(b)) => a == b, + (PrimitiveValue::Number(a), PrimitiveValue::Number(b)) => { + let av = a.value(); + let bv = b.value(); + // NaN !== NaN in JS + if av.is_nan() || bv.is_nan() { + return false; + } + av == bv + } + (PrimitiveValue::String(a), PrimitiveValue::String(b)) => a == b, + // Different types => false + _ => false, + } +} + +/// Convert a string to a number using JS `ToNumber` semantics. +/// In JS: `""` → 0, `" "` → 0, `" 42 "` → 42, `"0x1A"` → 26, `"Infinity"` → +/// Infinity. 
+fn js_to_number(s: &str) -> f64 { + let trimmed = s.trim(); + if trimmed.is_empty() { + return 0.0; + } + if trimmed == "Infinity" || trimmed == "+Infinity" { + return f64::INFINITY; + } + if trimmed == "-Infinity" { + return f64::NEG_INFINITY; + } + // Handle hex literals (0x/0X) + if trimmed.starts_with("0x") || trimmed.starts_with("0X") { + return match u64::from_str_radix(&trimmed[2..], 16) { + Ok(v) => v as f64, + Err(_) => f64::NAN, + }; + } + // Handle octal literals (0o/0O) + if trimmed.starts_with("0o") || trimmed.starts_with("0O") { + return match u64::from_str_radix(&trimmed[2..], 8) { + Ok(v) => v as f64, + Err(_) => f64::NAN, + }; + } + // Handle binary literals (0b/0B) + if trimmed.starts_with("0b") || trimmed.starts_with("0B") { + return match u64::from_str_radix(&trimmed[2..], 2) { + Ok(v) => v as f64, + Err(_) => f64::NAN, + }; + } + trimmed.parse::().unwrap_or(f64::NAN) +} + +fn js_abstract_equal(lhs: &PrimitiveValue, rhs: &PrimitiveValue) -> bool { + match (lhs, rhs) { + (PrimitiveValue::Null, PrimitiveValue::Null) => true, + (PrimitiveValue::Undefined, PrimitiveValue::Undefined) => true, + (PrimitiveValue::Null, PrimitiveValue::Undefined) + | (PrimitiveValue::Undefined, PrimitiveValue::Null) => true, + (PrimitiveValue::Boolean(a), PrimitiveValue::Boolean(b)) => a == b, + (PrimitiveValue::Number(a), PrimitiveValue::Number(b)) => { + let av = a.value(); + let bv = b.value(); + if av.is_nan() || bv.is_nan() { + return false; + } + av == bv + } + (PrimitiveValue::String(a), PrimitiveValue::String(b)) => a == b, + // Cross-type coercions for primitives + (PrimitiveValue::Number(n), PrimitiveValue::String(s)) + | (PrimitiveValue::String(s), PrimitiveValue::Number(n)) => { + // String is coerced to number using JS ToNumber semantics + let sv = js_to_number(s); + let nv = n.value(); + if nv.is_nan() || sv.is_nan() { + false + } else { + nv == sv + } + } + (PrimitiveValue::Boolean(b), other) => { + let num = if *b { 1.0 } else { 0.0 }; + 
js_abstract_equal(&PrimitiveValue::Number(FloatValue::new(num)), other) + } + (other, PrimitiveValue::Boolean(b)) => { + let num = if *b { 1.0 } else { 0.0 }; + js_abstract_equal(other, &PrimitiveValue::Number(FloatValue::new(num))) + } + // null/undefined vs number/string => false + _ => false, + } +} + +// ============================================================================= +// JavaScript Number.toString() approximation +// ============================================================================= + +/// ECMAScript ToInt32: convert f64 to i32 with modular (wrapping) semantics. +fn js_to_int32(n: f64) -> i32 { + if n.is_nan() || n.is_infinite() || n == 0.0 { + return 0; + } + // Truncate, then wrap to 32 bits + let int64 = (n.trunc() as i64) & 0xffffffff; + // Reinterpret as signed i32 + if int64 >= 0x80000000 { + (int64 as u32) as i32 + } else { + int64 as i32 + } +} + +/// ECMAScript ToUint32: convert f64 to u32 with modular (wrapping) semantics. +fn js_to_uint32(n: f64) -> u32 { + js_to_int32(n) as u32 +} + +/// Approximate ECMAScript Number::toString(). Handles special values and +/// tries to match JS formatting for common cases. Uses Rust's default +/// float formatting which may diverge from JS for exotic values +/// (e.g., very large/small numbers near the exponential notation threshold). 
+fn js_number_to_string(n: f64) -> String { + if n.is_nan() { + return "NaN".to_string(); + } + if n.is_infinite() { + return if n > 0.0 { + "Infinity".to_string() + } else { + "-Infinity".to_string() + }; + } + if n == 0.0 { + return "0".to_string(); + } + // For integers that fit, use integer formatting (no decimal point) + if n.fract() == 0.0 && n.abs() < (i64::MAX as f64) { + return format!("{}", n as i64); + } + // Default: use Rust's float formatting + // This may diverge from JS for edge cases around exponential notation + // thresholds + format!("{}", n) +} diff --git a/crates/react_compiler_optimization/src/dead_code_elimination.rs b/crates/react_compiler_optimization/src/dead_code_elimination.rs new file mode 100644 index 000000000000..1b660bacfe5e --- /dev/null +++ b/crates/react_compiler_optimization/src/dead_code_elimination.rs @@ -0,0 +1,427 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Dead code elimination pass. +//! +//! Eliminates instructions whose values are unused, reducing generated code +//! size. Performs mark-and-sweep analysis to identify and remove dead code +//! while preserving side effects and program semantics. +//! +//! Ported from TypeScript `src/Optimization/DeadCodeElimination.ts`. + +use std::collections::HashSet; + +use react_compiler_hir::{ + environment::{Environment, OutputMode}, + object_shape::HookKind, + visitors, ArrayPatternElement, BlockId, BlockKind, HirFunction, IdentifierId, InstructionKind, + InstructionValue, ObjectPropertyOrSpread, Pattern, +}; + +/// Implements dead-code elimination, eliminating instructions whose values are +/// unused. +/// +/// Note that unreachable blocks are already pruned during HIR construction. +/// +/// Corresponds to TS `deadCodeElimination(fn: HIRFunction): void`. 
+pub fn dead_code_elimination(func: &mut HirFunction, env: &Environment) { + // Phase 1: Find/mark all referenced identifiers + let state = find_referenced_identifiers(func, env); + + // Phase 2: Prune / sweep unreferenced identifiers and instructions + // Collect instructions to rewrite (two-phase: collect then apply to avoid + // borrow conflicts) + let mut instructions_to_rewrite: Vec = Vec::new(); + + for (_block_id, block) in &mut func.body.blocks { + // Remove unused phi nodes + block + .phis + .retain(|phi| is_id_or_name_used(&state, &env.identifiers, phi.place.identifier)); + + // Remove instructions with unused lvalues + block.instructions.retain(|instr_id| { + let instr = &func.instructions[instr_id.0 as usize]; + is_id_or_name_used(&state, &env.identifiers, instr.lvalue.identifier) + }); + + // Collect instructions that need rewriting (not the block value) + let retained_count = block.instructions.len(); + for i in 0..retained_count { + let is_block_value = block.kind != BlockKind::Block && i == retained_count - 1; + if !is_block_value { + instructions_to_rewrite.push(block.instructions[i]); + } + } + } + + // Apply rewrites + for instr_id in instructions_to_rewrite { + rewrite_instruction(func, instr_id, &state, env); + } + + // Remove unused context variables + func.context + .retain(|ctx_var| is_id_or_name_used(&state, &env.identifiers, ctx_var.identifier)); +} + +/// State for tracking referenced identifiers during mark phase. +struct State { + /// SSA-specific usages (by IdentifierId) + identifiers: HashSet, + /// Named variable usages (any version) + named: HashSet, +} + +impl State { + fn new() -> Self { + State { + identifiers: HashSet::new(), + named: HashSet::new(), + } + } + + fn count(&self) -> usize { + self.identifiers.len() + } +} + +/// Mark an identifier as being referenced (not dead code). 
+fn reference( + state: &mut State, + identifiers: &[react_compiler_hir::Identifier], + identifier_id: IdentifierId, +) { + state.identifiers.insert(identifier_id); + let ident = &identifiers[identifier_id.0 as usize]; + if let Some(ref name) = ident.name { + state.named.insert(name.value().to_string()); + } +} + +/// Check if any version of the given identifier is used somewhere. +/// Checks both the specific SSA id and (for named identifiers) any usage of +/// that name. +fn is_id_or_name_used( + state: &State, + identifiers: &[react_compiler_hir::Identifier], + identifier_id: IdentifierId, +) -> bool { + if state.identifiers.contains(&identifier_id) { + return true; + } + let ident = &identifiers[identifier_id.0 as usize]; + if let Some(ref name) = ident.name { + state.named.contains(name.value()) + } else { + false + } +} + +/// Check if this specific SSA id is used. +fn is_id_used(state: &State, identifier_id: IdentifierId) -> bool { + state.identifiers.contains(&identifier_id) +} + +/// Phase 1: Find all referenced identifiers via fixed-point iteration. 
+fn find_referenced_identifiers(func: &HirFunction, env: &Environment) -> State { + let has_loop = has_back_edge(func); + // Collect block ids in reverse order (postorder - successors before + // predecessors) + let reversed_block_ids: Vec = func.body.blocks.keys().rev().copied().collect(); + + let mut state = State::new(); + let mut size; + + loop { + size = state.count(); + + for &block_id in &reversed_block_ids { + let block = &func.body.blocks[&block_id]; + + // Mark terminal operands + for place in visitors::each_terminal_operand(&block.terminal) { + reference(&mut state, &env.identifiers, place.identifier); + } + + // Process instructions in reverse order + let instr_count = block.instructions.len(); + for i in (0..instr_count).rev() { + let instr_id = block.instructions[i]; + let instr = &func.instructions[instr_id.0 as usize]; + + let is_block_value = block.kind != BlockKind::Block && i == instr_count - 1; + + if is_block_value { + // Last instr of a value block is never eligible for pruning + reference(&mut state, &env.identifiers, instr.lvalue.identifier); + for place in visitors::each_instruction_value_operand(&instr.value, env) { + reference(&mut state, &env.identifiers, place.identifier); + } + } else if is_id_or_name_used(&state, &env.identifiers, instr.lvalue.identifier) + || !pruneable_value(&instr.value, &state, env) + { + reference(&mut state, &env.identifiers, instr.lvalue.identifier); + + if let InstructionValue::StoreLocal { lvalue, value, .. 
} = &instr.value { + // If this is a Let/Const declaration, mark the initializer as referenced + // only if the SSA'd lval is also referenced + if lvalue.kind == InstructionKind::Reassign + || is_id_used(&state, lvalue.place.identifier) + { + reference(&mut state, &env.identifiers, value.identifier); + } + } else { + for place in visitors::each_instruction_value_operand(&instr.value, env) { + reference(&mut state, &env.identifiers, place.identifier); + } + } + } + } + + // Mark phi operands if phi result is used + for phi in &block.phis { + if is_id_or_name_used(&state, &env.identifiers, phi.place.identifier) { + for (_pred, operand) in &phi.operands { + reference(&mut state, &env.identifiers, operand.identifier); + } + } + } + } + + if !(state.count() > size && has_loop) { + break; + } + } + + state +} + +/// Rewrite a retained instruction (destructuring cleanup, StoreLocal -> +/// DeclareLocal). +fn rewrite_instruction( + func: &mut HirFunction, + instr_id: react_compiler_hir::InstructionId, + state: &State, + env: &Environment, +) { + let instr = &mut func.instructions[instr_id.0 as usize]; + + match &mut instr.value { + InstructionValue::Destructure { lvalue, .. 
} => { + match &mut lvalue.pattern { + Pattern::Array(arr) => { + // For arrays, replace unused items with holes, truncate trailing holes + let mut last_entry_index = 0; + for i in 0..arr.items.len() { + match &arr.items[i] { + ArrayPatternElement::Place(p) => { + if !is_id_or_name_used(state, &env.identifiers, p.identifier) { + arr.items[i] = ArrayPatternElement::Hole; + } else { + last_entry_index = i; + } + } + ArrayPatternElement::Spread(s) => { + if !is_id_or_name_used(state, &env.identifiers, s.place.identifier) + { + arr.items[i] = ArrayPatternElement::Hole; + } else { + last_entry_index = i; + } + } + ArrayPatternElement::Hole => {} + } + } + arr.items.truncate(last_entry_index + 1); + } + Pattern::Object(obj) => { + // For objects, prune unused properties if rest element is unused or absent + let mut next_properties: Option> = None; + for prop in &obj.properties { + match prop { + ObjectPropertyOrSpread::Property(p) => { + if is_id_or_name_used(state, &env.identifiers, p.place.identifier) { + next_properties + .get_or_insert_with(Vec::new) + .push(prop.clone()); + } + } + ObjectPropertyOrSpread::Spread(s) => { + if is_id_or_name_used(state, &env.identifiers, s.place.identifier) { + // Rest element is used, can't prune anything + next_properties = None; + break; + } + } + } + } + if let Some(props) = next_properties { + obj.properties = props; + } + } + } + } + InstructionValue::StoreLocal { + lvalue, + type_annotation, + loc, + .. + } => { + if lvalue.kind != InstructionKind::Reassign + && !is_id_used(state, lvalue.place.identifier) + { + // This is a const/let declaration where the variable is accessed later, + // but where the value is always overwritten before being read. + // Rewrite to DeclareLocal so the initializer value can be DCE'd. 
+ let new_lvalue = lvalue.clone(); + let new_type_annotation = type_annotation.clone(); + let new_loc = *loc; + instr.value = InstructionValue::DeclareLocal { + lvalue: new_lvalue, + type_annotation: new_type_annotation, + loc: new_loc, + }; + } + } + _ => {} + } +} + +/// Returns true if it is safe to prune an instruction with the given value. +fn pruneable_value(value: &InstructionValue, state: &State, env: &Environment) -> bool { + match value { + InstructionValue::DeclareLocal { lvalue, .. } => { + // Declarations are pruneable only if the named variable is never read later + !is_id_or_name_used(state, &env.identifiers, lvalue.place.identifier) + } + InstructionValue::StoreLocal { lvalue, .. } => { + if lvalue.kind == InstructionKind::Reassign { + // Reassignments can be pruned if the specific instance being assigned is never + // read + !is_id_used(state, lvalue.place.identifier) + } else { + // Declarations are pruneable only if the named variable is never read later + !is_id_or_name_used(state, &env.identifiers, lvalue.place.identifier) + } + } + InstructionValue::Destructure { lvalue, .. } => { + let mut is_id_or_name_used_flag = false; + let mut is_id_used_flag = false; + for place in visitors::each_pattern_operand(&lvalue.pattern) { + if is_id_used(state, place.identifier) { + is_id_or_name_used_flag = true; + is_id_used_flag = true; + } else if is_id_or_name_used(state, &env.identifiers, place.identifier) { + is_id_or_name_used_flag = true; + } + } + if lvalue.kind == InstructionKind::Reassign { + !is_id_used_flag + } else { + !is_id_or_name_used_flag + } + } + InstructionValue::PostfixUpdate { lvalue, .. } + | InstructionValue::PrefixUpdate { lvalue, .. } => { + // Updates are pruneable if the specific instance being assigned is never read + !is_id_used(state, lvalue.identifier) + } + InstructionValue::Debugger { .. } => { + // explicitly retain debugger statements + false + } + InstructionValue::CallExpression { callee, .. 
} => { + if env.output_mode == OutputMode::Ssr { + let callee_ty = + &env.types[env.identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if let Some(hook_kind) = env.get_hook_kind_for_type(callee_ty).ok().flatten() { + match hook_kind { + HookKind::UseState | HookKind::UseReducer | HookKind::UseRef => { + return true; + } + _ => {} + } + } + } + false + } + InstructionValue::MethodCall { property, .. } => { + if env.output_mode == OutputMode::Ssr { + let callee_ty = + &env.types[env.identifiers[property.identifier.0 as usize].type_.0 as usize]; + if let Some(hook_kind) = env.get_hook_kind_for_type(callee_ty).ok().flatten() { + match hook_kind { + HookKind::UseState | HookKind::UseReducer | HookKind::UseRef => { + return true; + } + _ => {} + } + } + } + false + } + InstructionValue::Await { .. } + | InstructionValue::ComputedDelete { .. } + | InstructionValue::ComputedStore { .. } + | InstructionValue::PropertyDelete { .. } + | InstructionValue::PropertyStore { .. } + | InstructionValue::StoreGlobal { .. } => { + // Mutating instructions are not safe to prune + false + } + InstructionValue::NewExpression { .. } + | InstructionValue::UnsupportedNode { .. } + | InstructionValue::TaggedTemplateExpression { .. } => { + // Potentially safe to prune, but we conservatively keep them + false + } + InstructionValue::GetIterator { .. } + | InstructionValue::NextPropertyOf { .. } + | InstructionValue::IteratorNext { .. } => { + // Iterator operations are always used downstream + false + } + InstructionValue::LoadContext { .. } + | InstructionValue::DeclareContext { .. } + | InstructionValue::StoreContext { .. } => false, + InstructionValue::StartMemoize { .. } | InstructionValue::FinishMemoize { .. } => false, + InstructionValue::RegExpLiteral { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::ArrayExpression { .. } + | InstructionValue::BinaryExpression { .. } + | InstructionValue::ComputedLoad { .. 
} + | InstructionValue::ObjectMethod { .. } + | InstructionValue::FunctionExpression { .. } + | InstructionValue::LoadLocal { .. } + | InstructionValue::JsxExpression { .. } + | InstructionValue::JsxFragment { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::ObjectExpression { .. } + | InstructionValue::Primitive { .. } + | InstructionValue::PropertyLoad { .. } + | InstructionValue::TemplateLiteral { .. } + | InstructionValue::TypeCastExpression { .. } + | InstructionValue::UnaryExpression { .. } => { + // Definitely safe to prune since they are read-only + true + } + } +} + +/// Check if the CFG has any back edges (indicating loops). +fn has_back_edge(func: &HirFunction) -> bool { + let mut visited: HashSet = HashSet::new(); + for (block_id, block) in &func.body.blocks { + for pred_id in &block.preds { + if !visited.contains(pred_id) { + return true; + } + } + visited.insert(*block_id); + } + false +} diff --git a/crates/react_compiler_optimization/src/drop_manual_memoization.rs b/crates/react_compiler_optimization/src/drop_manual_memoization.rs new file mode 100644 index 000000000000..bfbe18944c5b --- /dev/null +++ b/crates/react_compiler_optimization/src/drop_manual_memoization.rs @@ -0,0 +1,719 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Removes manual memoization using `useMemo` and `useCallback` APIs. +//! +//! For useMemo: replaces `Call useMemo(fn, deps)` with `Call fn()` +//! For useCallback: replaces `Call useCallback(fn, deps)` with `LoadLocal fn` +//! +//! When validation flags are set, inserts `StartMemoize`/`FinishMemoize` +//! markers. +//! +//! Analogous to TS `Inference/DropManualMemoization.ts`. 
+ +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, ArrayElement, DependencyPathEntry, Effect, EvaluationOrder, + HirFunction, IdentifierId, IdentifierName, Instruction, InstructionId, InstructionValue, + ManualMemoDependency, ManualMemoDependencyRoot, Place, PlaceOrSpread, PropertyLiteral, + SourceLocation, +}; +use react_compiler_lowering::{create_temporary_place, mark_instruction_ids}; + +// ============================================================================= +// Types +// ============================================================================= + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ManualMemoKind { + UseMemo, + UseCallback, +} + +#[derive(Debug, Clone)] +struct ManualMemoCallee { + kind: ManualMemoKind, + /// InstructionId of the LoadGlobal or PropertyLoad that loaded the callee. + load_instr_id: InstructionId, +} + +struct IdentifierSidemap { + /// Maps identifier id -> InstructionId of FunctionExpression instructions + functions: HashSet, + /// Maps identifier id -> ManualMemoCallee for useMemo/useCallback callees + manual_memos: HashMap, + /// Set of identifier ids that loaded 'React' global + react: HashSet, + /// Maps identifier id -> deps list info for array expressions + maybe_deps_lists: HashMap, + /// Maps identifier id -> ManualMemoDependency for dependency tracking + maybe_deps: HashMap, + /// Set of identifier ids that are results of optional chains + optionals: HashSet, +} + +#[derive(Debug, Clone)] +struct MaybeDepsListInfo { + loc: Option, + deps: Vec, +} + +struct ExtractedMemoArgs { + fn_place: Place, + deps_list: Option>, + deps_loc: Option, +} + +// ============================================================================= +// Main pass +// ============================================================================= + +/// Drop manual memoization (useMemo/useCallback 
calls), replacing them +/// with direct invocations/references. +pub fn drop_manual_memoization( + func: &mut HirFunction, + env: &mut Environment, +) -> Result<(), CompilerDiagnostic> { + let is_validation_enabled = env.validate_preserve_existing_memoization_guarantees + || env.validate_no_set_state_in_render + || env.enable_preserve_existing_memoization_guarantees; + + let optionals = find_optional_places(func)?; + let mut sidemap = IdentifierSidemap { + functions: HashSet::new(), + manual_memos: HashMap::new(), + react: HashSet::new(), + maybe_deps: HashMap::new(), + maybe_deps_lists: HashMap::new(), + optionals, + }; + let mut next_manual_memo_id: u32 = 0; + + // Phase 1: + // - Overwrite manual memoization CallExpression/MethodCall + // - (if validation is enabled) collect manual memoization markers + // + // queued_inserts maps InstructionId -> new Instruction to insert after that + // instruction + let mut queued_inserts: HashMap = HashMap::new(); + + // Collect all block instruction lists up front to avoid borrowing func + // immutably while needing to mutate it + let all_block_instructions: Vec> = func + .body + .blocks + .values() + .map(|block| block.instructions.clone()) + .collect(); + + for block_instructions in &all_block_instructions { + for &instr_id in block_instructions { + let instr = &func.instructions[instr_id.0 as usize]; + + // Extract the identifier we need to look up, and whether it's a call/method + let lookup_id = match &instr.value { + InstructionValue::CallExpression { callee, .. } => Some(callee.identifier), + InstructionValue::MethodCall { property, .. 
} => Some(property.identifier), + _ => None, + }; + + let manual_memo = lookup_id.and_then(|id| sidemap.manual_memos.get(&id).cloned()); + + if let Some(manual_memo) = manual_memo { + process_manual_memo_call( + func, + env, + instr_id, + &manual_memo, + &mut sidemap, + is_validation_enabled, + &mut next_manual_memo_id, + &mut queued_inserts, + ); + } else { + collect_temporaries(func, env, instr_id, &mut sidemap); + } + } + } + + // Phase 2: Insert manual memoization markers as needed + if !queued_inserts.is_empty() { + let mut has_changes = false; + for block in func.body.blocks.values_mut() { + let mut next_instructions: Option> = None; + for i in 0..block.instructions.len() { + let instr_id = block.instructions[i]; + if let Some(insert_instr) = queued_inserts.remove(&instr_id) { + if next_instructions.is_none() { + next_instructions = Some(block.instructions[..i].to_vec()); + } + let ni = next_instructions.as_mut().unwrap(); + ni.push(instr_id); + // Add the new instruction to the flat table and get its InstructionId + let new_instr_id = InstructionId(func.instructions.len() as u32); + func.instructions.push(insert_instr); + ni.push(new_instr_id); + } else if let Some(ni) = next_instructions.as_mut() { + ni.push(instr_id); + } + } + if let Some(ni) = next_instructions { + block.instructions = ni; + has_changes = true; + } + } + + if has_changes { + mark_instruction_ids(&mut func.body, &mut func.instructions); + } + } + + Ok(()) +} + +// ============================================================================= +// Phase 1 helpers +// ============================================================================= + +#[allow(clippy::too_many_arguments)] +fn process_manual_memo_call( + func: &mut HirFunction, + env: &mut Environment, + instr_id: InstructionId, + manual_memo: &ManualMemoCallee, + sidemap: &mut IdentifierSidemap, + is_validation_enabled: bool, + next_manual_memo_id: &mut u32, + queued_inserts: &mut HashMap, +) { + let instr = 
&func.instructions[instr_id.0 as usize]; + + let memo_details = extract_manual_memoization_args(instr, manual_memo.kind, sidemap, env); + + let Some(memo_details) = memo_details else { + return; + }; + + let ExtractedMemoArgs { + fn_place, + deps_list, + deps_loc, + } = memo_details; + + let loc = func.instructions[instr_id.0 as usize].value.loc().cloned(); + + // Replace the instruction value with the memoization replacement + let replacement = get_manual_memoization_replacement(&fn_place, loc.clone(), manual_memo.kind); + func.instructions[instr_id.0 as usize].value = replacement; + + if is_validation_enabled { + // Bail out when we encounter manual memoization without inline function + // expressions + if !sidemap.functions.contains(&fn_place.identifier) { + let mut diag = CompilerDiagnostic::new( + ErrorCategory::UseMemo, + "Expected the first argument to be an inline function expression", + Some("Expected the first argument to be an inline function expression".to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: fn_place.loc.clone(), + message: Some( + "Expected the first argument to be an inline function expression".to_string(), + ), + identifier_name: None, + }); + // Match TS behavior: suggestions is [] (empty array), not null + diag.suggestions = Some(vec![]); + env.record_diagnostic(diag); + return; + } + + let memo_decl: Place = if manual_memo.kind == ManualMemoKind::UseMemo { + func.instructions[instr_id.0 as usize].lvalue.clone() + } else { + Place { + identifier: fn_place.identifier, + effect: Effect::Unknown, + reactive: false, + loc: fn_place.loc.clone(), + } + }; + + let manual_memo_id = *next_manual_memo_id; + *next_manual_memo_id += 1; + + let (start_marker, finish_marker) = make_manual_memoization_markers( + &fn_place, + env, + deps_list, + deps_loc, + &memo_decl, + manual_memo_id, + ); + + queued_inserts.insert(manual_memo.load_instr_id, start_marker); + queued_inserts.insert(instr_id, finish_marker); + } +} + +fn 
collect_temporaries( + func: &HirFunction, + env: &Environment, + instr_id: InstructionId, + sidemap: &mut IdentifierSidemap, +) { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + + match &instr.value { + InstructionValue::FunctionExpression { .. } => { + sidemap.functions.insert(lvalue_id); + } + InstructionValue::LoadGlobal { binding, .. } => { + let name = binding.name(); + // DIVERGENCE: The TS version uses `env.getGlobalDeclaration()` + + // `getHookKindForType()` to resolve the binding through the type system + // and determine if it's useMemo/useCallback. Since the type/globals system + // is not yet ported, we match on the binding name directly. This means: + // - Custom hooks aliased to useMemo/useCallback won't be detected + // - Re-exports or renamed imports won't be detected + // - The behavior is equivalent for direct `useMemo`/`useCallback` imports and + // `React.useMemo`/`React.useCallback` member accesses (handled below) + // TODO: Use getGlobalDeclaration + getHookKindForType once the type system is + // ported. + if name == "useMemo" { + sidemap.manual_memos.insert( + lvalue_id, + ManualMemoCallee { + kind: ManualMemoKind::UseMemo, + load_instr_id: instr_id, + }, + ); + } else if name == "useCallback" { + sidemap.manual_memos.insert( + lvalue_id, + ManualMemoCallee { + kind: ManualMemoKind::UseCallback, + load_instr_id: instr_id, + }, + ); + } else if name == "React" { + sidemap.react.insert(lvalue_id); + } + } + InstructionValue::PropertyLoad { + object, property, .. 
+ } => { + if sidemap.react.contains(&object.identifier) { + if let PropertyLiteral::String(prop_name) = property { + if prop_name == "useMemo" { + sidemap.manual_memos.insert( + lvalue_id, + ManualMemoCallee { + kind: ManualMemoKind::UseMemo, + load_instr_id: instr_id, + }, + ); + } else if prop_name == "useCallback" { + sidemap.manual_memos.insert( + lvalue_id, + ManualMemoCallee { + kind: ManualMemoKind::UseCallback, + load_instr_id: instr_id, + }, + ); + } + } + } + } + InstructionValue::ArrayExpression { elements, .. } => { + // Check if all elements are Identifier (Place) - no spreads or holes + let all_places: Option> = elements + .iter() + .map(|e| match e { + ArrayElement::Place(p) => Some(p.clone()), + _ => None, + }) + .collect(); + + if let Some(deps) = all_places { + sidemap.maybe_deps_lists.insert( + lvalue_id, + MaybeDepsListInfo { + loc: instr.value.loc().cloned(), + deps, + }, + ); + } + } + _ => {} + } + + let is_optional = sidemap.optionals.contains(&lvalue_id); + let maybe_dep = + collect_maybe_memo_dependencies(&instr.value, &sidemap.maybe_deps, is_optional, env); + if let Some(dep) = maybe_dep { + // For StoreLocal, also insert under the StoreLocal's lvalue place identifier, + // matching the TS behavior where collectMaybeMemoDependencies inserts into + // maybeDeps directly for StoreLocal's target variable. + if let InstructionValue::StoreLocal { lvalue, .. } = &instr.value { + sidemap + .maybe_deps + .insert(lvalue.place.identifier, dep.clone()); + } + sidemap.maybe_deps.insert(lvalue_id, dep); + } +} + +// ============================================================================= +// collectMaybeMemoDependencies +// ============================================================================= + +/// Collect loads from named variables and property reads into `maybe_deps`. +/// Returns the variable + property reads represented by the instruction value. 
+pub fn collect_maybe_memo_dependencies( + value: &InstructionValue, + maybe_deps: &HashMap, + optional: bool, + env: &Environment, +) -> Option { + match value { + InstructionValue::LoadGlobal { binding, loc, .. } => Some(ManualMemoDependency { + root: ManualMemoDependencyRoot::Global { + identifier_name: binding.name().to_string(), + }, + path: vec![], + loc: loc.clone(), + }), + InstructionValue::PropertyLoad { + object, + property, + loc, + .. + } => { + if let Some(object_dep) = maybe_deps.get(&object.identifier) { + Some(ManualMemoDependency { + root: object_dep.root.clone(), + path: { + let mut path = object_dep.path.clone(); + path.push(DependencyPathEntry { + property: property.clone(), + optional, + loc: loc.clone(), + }); + path + }, + loc: loc.clone(), + }) + } else { + None + } + } + InstructionValue::LoadLocal { place, .. } | InstructionValue::LoadContext { place, .. } => { + if let Some(source) = maybe_deps.get(&place.identifier) { + Some(source.clone()) + } else if matches!( + &env.identifiers[place.identifier.0 as usize].name, + Some(IdentifierName::Named(_)) + ) { + Some(ManualMemoDependency { + root: ManualMemoDependencyRoot::NamedLocal { + value: place.clone(), + constant: false, + }, + path: vec![], + loc: place.loc.clone(), + }) + } else { + None + } + } + InstructionValue::StoreLocal { + lvalue, value: val, .. + } => { + // Value blocks rely on StoreLocal to populate their return value. + // We need to track these as optional property chains are valid in + // source depslists + let lvalue_id = lvalue.place.identifier; + let rvalue_id = val.identifier; + if let Some(aliased) = maybe_deps.get(&rvalue_id) { + let lvalue_name = &env.identifiers[lvalue_id.0 as usize].name; + if !matches!(lvalue_name, Some(IdentifierName::Named(_))) { + // Note: we can't insert into maybe_deps here since we only have + // a shared reference. The caller handles insertion. 
+ return Some(aliased.clone()); + } + } + None + } + _ => None, + } +} + +// ============================================================================= +// Replacement helpers +// ============================================================================= + +fn get_manual_memoization_replacement( + fn_place: &Place, + loc: Option, + kind: ManualMemoKind, +) -> InstructionValue { + if kind == ManualMemoKind::UseMemo { + // Replace with Call fn() - invoke the memo function directly + InstructionValue::CallExpression { + callee: fn_place.clone(), + args: vec![], + loc, + } + } else { + // Replace with LoadLocal fn - just reference the function + InstructionValue::LoadLocal { + place: Place { + identifier: fn_place.identifier, + effect: Effect::Unknown, + reactive: false, + loc: loc.clone(), + }, + loc, + } + } +} + +fn make_manual_memoization_markers( + fn_expr: &Place, + env: &mut Environment, + deps_list: Option>, + deps_loc: Option, + memo_decl: &Place, + manual_memo_id: u32, +) -> (Instruction, Instruction) { + let start = Instruction { + id: EvaluationOrder(0), + lvalue: create_temporary_place(env, fn_expr.loc.clone()), + value: InstructionValue::StartMemoize { + manual_memo_id, + deps: deps_list, + deps_loc: Some(deps_loc), + has_invalid_deps: false, + loc: fn_expr.loc.clone(), + }, + loc: fn_expr.loc.clone(), + effects: None, + }; + let finish = Instruction { + id: EvaluationOrder(0), + lvalue: create_temporary_place(env, fn_expr.loc.clone()), + value: InstructionValue::FinishMemoize { + manual_memo_id, + decl: memo_decl.clone(), + pruned: false, + loc: fn_expr.loc.clone(), + }, + loc: fn_expr.loc.clone(), + effects: None, + }; + (start, finish) +} + +fn extract_manual_memoization_args( + instr: &Instruction, + kind: ManualMemoKind, + sidemap: &IdentifierSidemap, + env: &mut Environment, +) -> Option { + let args: &[PlaceOrSpread] = match &instr.value { + InstructionValue::CallExpression { args, .. } => args, + InstructionValue::MethodCall { args, .. 
} => args, + _ => return None, + }; + + let kind_name = match kind { + ManualMemoKind::UseMemo => "useMemo", + ManualMemoKind::UseCallback => "useCallback", + }; + + // Get the first arg (fn) + let fn_place = match args.first() { + Some(PlaceOrSpread::Place(p)) => p.clone(), + _ => { + let loc = instr.value.loc().cloned(); + env.record_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::UseMemo, + format!("Expected a callback function to be passed to {kind_name}"), + Some(if kind == ManualMemoKind::UseCallback { + "The first argument to useCallback() must be a function to cache" + .to_string() + } else { + "The first argument to useMemo() must be a function that calculates a \ + result to cache" + .to_string() + }), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc, + message: Some(if kind == ManualMemoKind::UseCallback { + "Expected a callback function".to_string() + } else { + "Expected a memoization function".to_string() + }), + identifier_name: None, + }), + ); + return None; + } + }; + + // Get the second arg (deps list), if present + let deps_list_place = args.get(1); + if deps_list_place.is_none() { + return Some(ExtractedMemoArgs { + fn_place, + deps_list: None, + deps_loc: None, + }); + } + + let deps_list_id = match deps_list_place { + Some(PlaceOrSpread::Place(p)) => Some(p.identifier), + _ => None, + }; + + let maybe_deps_list = deps_list_id.and_then(|id| sidemap.maybe_deps_lists.get(&id)); + + if maybe_deps_list.is_none() { + let loc = match deps_list_place { + Some(PlaceOrSpread::Place(p)) => p.loc.clone(), + _ => instr.loc.clone(), + }; + env.record_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::UseMemo, + format!("Expected the dependency list for {kind_name} to be an array literal"), + Some(format!( + "Expected the dependency list for {kind_name} to be an array literal" + )), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc, + message: Some(format!( + "Expected the dependency list for {kind_name} to be an array 
literal" + )), + identifier_name: None, + }), + ); + return None; + } + + let deps_info = maybe_deps_list.unwrap(); + let mut deps_list: Vec = Vec::new(); + for dep in &deps_info.deps { + let maybe_dep = sidemap.maybe_deps.get(&dep.identifier); + if let Some(d) = maybe_dep { + deps_list.push(d.clone()); + } else { + env.record_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::UseMemo, + "Expected the dependency list to be an array of simple expressions (e.g. `x`, \ + `x.y.z`, `x?.y?.z`)", + Some( + "Expected the dependency list to be an array of simple expressions (e.g. \ + `x`, `x.y.z`, `x?.y?.z`)" + .to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: dep.loc.clone(), + message: Some( + "Expected the dependency list to be an array of simple expressions (e.g. \ + `x`, `x.y.z`, `x?.y?.z`)" + .to_string(), + ), + identifier_name: None, + }), + ); + } + } + + Some(ExtractedMemoArgs { + fn_place, + deps_list: Some(deps_list), + deps_loc: deps_info.loc.clone(), + }) +} + +// ============================================================================= +// findOptionalPlaces +// ============================================================================= + +fn find_optional_places(func: &HirFunction) -> Result, CompilerDiagnostic> { + use react_compiler_hir::Terminal; + + let mut optionals = HashSet::new(); + for block in func.body.blocks.values() { + if let Terminal::Optional { + optional: true, + test, + fallthrough, + .. + } = &block.terminal + { + let optional_fallthrough = *fallthrough; + let mut test_block_id = *test; + loop { + let test_block = &func.body.blocks[&test_block_id]; + match &test_block.terminal { + Terminal::Branch { + consequent, + fallthrough, + .. 
+ } => { + if *fallthrough == optional_fallthrough { + // Found it + let consequent_block = &func.body.blocks[consequent]; + if let Some(&last_instr_id) = consequent_block.instructions.last() { + let last_instr = &func.instructions[last_instr_id.0 as usize]; + if let InstructionValue::StoreLocal { value, .. } = + &last_instr.value + { + optionals.insert(value.identifier); + } + } + break; + } else { + test_block_id = *fallthrough; + } + } + Terminal::Optional { fallthrough, .. } + | Terminal::Logical { fallthrough, .. } + | Terminal::Sequence { fallthrough, .. } + | Terminal::Ternary { fallthrough, .. } => { + test_block_id = *fallthrough; + } + Terminal::MaybeThrow { continuation, .. } => { + test_block_id = *continuation; + } + other => { + // Invariant: unexpected terminal in optional + // In TS this throws CompilerError.invariant + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "Unexpected terminal kind in optional: {:?}", + std::mem::discriminant(other) + ), + None, + )); + } + } + } + } + } + Ok(optionals) +} diff --git a/crates/react_compiler_optimization/src/inline_iifes.rs b/crates/react_compiler_optimization/src/inline_iifes.rs new file mode 100644 index 000000000000..fe3531f9c135 --- /dev/null +++ b/crates/react_compiler_optimization/src/inline_iifes.rs @@ -0,0 +1,417 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Inlines immediately invoked function expressions (IIFEs) to allow more +//! fine-grained memoization of the values they produce. +//! +//! Example: +//! ```text +//! const x = (() => { +//! const x = []; +//! x.push(foo()); +//! return x; +//! })(); +//! +//! => +//! +//! bb0: +//! // placeholder for the result, all return statements will assign here +//! let t0; +//! // Label allows using a goto (break) to exit out of the body +//! 
Label block=bb1 fallthrough=bb2 +//! bb1: +//! // code within the function expression +//! const x0 = []; +//! x0.push(foo()); +//! // return is replaced by assignment to the result variable... +//! t0 = x0; +//! // ...and a goto to the code after the function expression invocation +//! Goto bb2 +//! bb2: +//! // code after the IIFE call +//! const x = t0; +//! ``` +//! +//! If the inlined function has only one return, we avoid the labeled block +//! and fully inline the code. The original return is replaced with an +//! assignment to the IIFE's call expression lvalue. +//! +//! Analogous to TS `Inference/InlineImmediatelyInvokedFunctionExpressions.ts`. + +use std::collections::{HashMap, HashSet}; + +use react_compiler_hir::{ + environment::Environment, visitors, BasicBlock, BlockId, BlockKind, EvaluationOrder, + FunctionId, GotoVariant, HirFunction, IdentifierId, IdentifierName, Instruction, InstructionId, + InstructionKind, InstructionValue, LValue, Place, Terminal, GENERATED_SOURCE, +}; +use react_compiler_lowering::{ + create_temporary_place, get_reverse_postordered_blocks, mark_instruction_ids, mark_predecessors, +}; + +use crate::merge_consecutive_blocks::merge_consecutive_blocks; + +/// Inline immediately invoked function expressions into the enclosing +/// function's control flow graph. +pub fn inline_immediately_invoked_function_expressions( + func: &mut HirFunction, + env: &mut Environment, +) { + // Track all function expressions that are assigned to a temporary + let mut functions: HashMap = HashMap::new(); + // Functions that are inlined (by identifier id of the callee) + let mut inlined_functions: HashSet = HashSet::new(); + + // Iterate the *existing* blocks from the outer component to find IIFEs + // and inline them. During iteration we will modify `func` (by inlining the CFG + // of IIFEs) so we explicitly copy references to just the original + // function's block IDs first. 
As blocks are split to make room for IIFE calls, + // the split portions of the blocks will be added to this queue. + let mut queue: Vec = func.body.blocks.keys().copied().collect(); + let mut queue_idx = 0; + + 'queue: while queue_idx < queue.len() { + let block_id = queue[queue_idx]; + queue_idx += 1; + + let block = match func.body.blocks.get(&block_id) { + Some(b) => b, + None => continue, + }; + + // We can't handle labels inside expressions yet, so we don't inline IIFEs + // if they are in an expression block. + if !is_statement_block_kind(block.kind) { + continue; + } + + let num_instructions = block.instructions.len(); + for ii in 0..num_instructions { + let instr_id = func.body.blocks[&block_id].instructions[ii]; + let instr = &func.instructions[instr_id.0 as usize]; + + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } => { + let identifier_id = instr.lvalue.identifier; + if env.identifiers[identifier_id.0 as usize].name.is_none() { + functions.insert(identifier_id, lowered_func.func); + } + continue; + } + InstructionValue::CallExpression { callee, args, .. 
} => { + if !args.is_empty() { + // We don't support inlining when there are arguments + continue; + } + + let callee_id = callee.identifier; + let inner_func_id = match functions.get(&callee_id) { + Some(id) => *id, + None => continue, // Not invoking a local function expression + }; + + let inner_func = &env.functions[inner_func_id.0 as usize]; + if !inner_func.params.is_empty() || inner_func.is_async || inner_func.generator + { + // Can't inline functions with params, or async/generator functions + continue; + } + + // We know this function is used for an IIFE and can prune it later + inlined_functions.insert(callee_id); + + // Capture the lvalue from the call instruction + let call_lvalue = func.instructions[instr_id.0 as usize].lvalue.clone(); + let block_terminal_id = func.body.blocks[&block_id].terminal.evaluation_order(); + let block_terminal_loc = func.body.blocks[&block_id].terminal.loc().cloned(); + let block_kind = func.body.blocks[&block_id].kind; + + // Create a new block which will contain code following the IIFE call + let continuation_block_id = env.next_block_id(); + let continuation_instructions: Vec = + func.body.blocks[&block_id].instructions[ii + 1..].to_vec(); + let continuation_terminal = func.body.blocks[&block_id].terminal.clone(); + let continuation_block = BasicBlock { + id: continuation_block_id, + instructions: continuation_instructions, + kind: block_kind, + phis: Vec::new(), + preds: indexmap::IndexSet::new(), + terminal: continuation_terminal, + }; + func.body + .blocks + .insert(continuation_block_id, continuation_block); + + // Trim the original block to contain instructions up to (but not including) + // the IIFE + func.body + .blocks + .get_mut(&block_id) + .unwrap() + .instructions + .truncate(ii); + + let has_single_return = + has_single_exit_return_terminal(&env.functions[inner_func_id.0 as usize]); + let inner_entry = env.functions[inner_func_id.0 as usize].body.entry; + + if has_single_return { + // Single-return path: 
simple goto replacement + func.body.blocks.get_mut(&block_id).unwrap().terminal = Terminal::Goto { + block: inner_entry, + id: block_terminal_id, + loc: block_terminal_loc, + variant: GotoVariant::Break, + }; + + // Take blocks and instructions from inner function + let inner_func = &mut env.functions[inner_func_id.0 as usize]; + let inner_blocks: Vec<(BlockId, BasicBlock)> = + inner_func.body.blocks.drain(..).collect(); + let inner_instructions: Vec = + inner_func.instructions.drain(..).collect(); + + // Append inner instructions first, then remap block instruction IDs + let instr_offset = func.instructions.len() as u32; + func.instructions.extend(inner_instructions); + + for (_, mut inner_block) in inner_blocks { + // Remap instruction IDs in the block + for iid in &mut inner_block.instructions { + *iid = InstructionId(iid.0 + instr_offset); + } + inner_block.preds.clear(); + + if let Terminal::Return { + value, + id: ret_id, + loc: ret_loc, + .. + } = &inner_block.terminal + { + // Replace return with LoadLocal + goto + let load_instr = Instruction { + id: EvaluationOrder(0), + loc: ret_loc.clone(), + lvalue: call_lvalue.clone(), + value: InstructionValue::LoadLocal { + place: value.clone(), + loc: ret_loc.clone(), + }, + effects: None, + }; + let load_instr_id = InstructionId(func.instructions.len() as u32); + func.instructions.push(load_instr); + inner_block.instructions.push(load_instr_id); + + let ret_id = *ret_id; + let ret_loc = ret_loc.clone(); + inner_block.terminal = Terminal::Goto { + block: continuation_block_id, + id: ret_id, + loc: ret_loc, + variant: GotoVariant::Break, + }; + } + + func.body.blocks.insert(inner_block.id, inner_block); + } + } else { + // Multi-return path: uses LabelTerminal + let result = call_lvalue.clone(); + + // Set block terminal to Label + func.body.blocks.get_mut(&block_id).unwrap().terminal = Terminal::Label { + block: inner_entry, + id: EvaluationOrder(0), + fallthrough: continuation_block_id, + loc: block_terminal_loc, 
+ }; + + // Declare the IIFE temporary + declare_temporary(env, func, block_id, &result); + + // Promote the temporary with a name as we require this to persist + let identifier_id = result.identifier; + if env.identifiers[identifier_id.0 as usize].name.is_none() { + promote_temporary(env, identifier_id); + } + + // Take blocks and instructions from inner function + let inner_func = &mut env.functions[inner_func_id.0 as usize]; + let inner_blocks: Vec<(BlockId, BasicBlock)> = + inner_func.body.blocks.drain(..).collect(); + let inner_instructions: Vec = + inner_func.instructions.drain(..).collect(); + + // Append inner instructions first, then remap block instruction IDs + let instr_offset = func.instructions.len() as u32; + func.instructions.extend(inner_instructions); + + for (_, mut inner_block) in inner_blocks { + for iid in &mut inner_block.instructions { + *iid = InstructionId(iid.0 + instr_offset); + } + inner_block.preds.clear(); + + // Rewrite return terminals to StoreLocal + goto + if matches!(inner_block.terminal, Terminal::Return { .. }) { + rewrite_block( + env, + &mut func.instructions, + &mut inner_block, + continuation_block_id, + &result, + ); + } + + func.body.blocks.insert(inner_block.id, inner_block); + } + } + + // Ensure we visit the continuation block, since there may have been + // sequential IIFEs that need to be visited. 
+ queue.push(continuation_block_id); + continue 'queue; + } + _ => { + // Any other use of a function expression means it isn't an IIFE + for id in visitors::each_instruction_value_operand_ids(&instr.value, env) { + functions.remove(&id); + } + } + } + } + } + + if !inlined_functions.is_empty() { + // Remove instructions that define lambdas which we inlined + for block in func.body.blocks.values_mut() { + block.instructions.retain(|instr_id| { + let instr = &func.instructions[instr_id.0 as usize]; + !inlined_functions.contains(&instr.lvalue.identifier) + }); + } + + // If terminals have changed then blocks may have become newly unreachable. + // Re-run minification of the graph (incl reordering instruction ids). + func.body.blocks = get_reverse_postordered_blocks(&func.body, &func.instructions); + mark_instruction_ids(&mut func.body, &mut func.instructions); + mark_predecessors(&mut func.body); + merge_consecutive_blocks(func, &mut env.functions); + } +} + +/// Returns true for "block" and "catch" block kinds which correspond to +/// statements in the source. +fn is_statement_block_kind(kind: BlockKind) -> bool { + matches!(kind, BlockKind::Block | BlockKind::Catch) +} + +/// Returns true if the function has a single exit terminal (throw/return) which +/// is a return. +fn has_single_exit_return_terminal(func: &HirFunction) -> bool { + let mut has_return = false; + let mut exit_count = 0; + for block in func.body.blocks.values() { + match &block.terminal { + Terminal::Return { .. } => { + has_return = true; + exit_count += 1; + } + Terminal::Throw { .. 
} => { + exit_count += 1; + } + _ => {} + } + } + exit_count == 1 && has_return +} + +/// Rewrites the block so that all `return` terminals are replaced: +/// * Add a StoreLocal = +/// * Replace the terminal with a Goto to +fn rewrite_block( + env: &mut Environment, + instructions: &mut Vec, + block: &mut BasicBlock, + return_target: BlockId, + return_value: &Place, +) { + if let Terminal::Return { + value, + loc: ret_loc, + .. + } = &block.terminal + { + let store_lvalue = create_temporary_place(env, ret_loc.clone()); + let store_instr = Instruction { + id: EvaluationOrder(0), + loc: ret_loc.clone(), + lvalue: store_lvalue, + value: InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Reassign, + place: return_value.clone(), + }, + value: value.clone(), + type_annotation: None, + loc: ret_loc.clone(), + }, + effects: None, + }; + let store_instr_id = InstructionId(instructions.len() as u32); + instructions.push(store_instr); + block.instructions.push(store_instr_id); + + let ret_loc = ret_loc.clone(); + block.terminal = Terminal::Goto { + block: return_target, + id: EvaluationOrder(0), + variant: GotoVariant::Break, + loc: ret_loc, + }; + } +} + +/// Emits a DeclareLocal instruction for the result temporary. 
+fn declare_temporary(
+    env: &mut Environment,
+    func: &mut HirFunction,
+    block_id: BlockId,
+    result: &Place,
+) {
+    // Fresh temporary that serves as the lvalue of the DeclareLocal
+    // instruction itself (distinct from `result`, the place being declared).
+    let temp_place = create_temporary_place(env, result.loc.clone());
+    let declared = LValue {
+        place: result.clone(),
+        kind: InstructionKind::Let,
+    };
+    let value = InstructionValue::DeclareLocal {
+        lvalue: declared,
+        type_annotation: None,
+        loc: result.loc.clone(),
+    };
+
+    // The instruction is appended to the function's instruction table, so its
+    // id is the table length before the push.
+    let new_id = InstructionId(func.instructions.len() as u32);
+    func.instructions.push(Instruction {
+        id: EvaluationOrder(0),
+        loc: GENERATED_SOURCE,
+        lvalue: temp_place,
+        value,
+        effects: None,
+    });
+
+    // Schedule the new instruction at the end of the requested block.
+    let target = func.body.blocks.get_mut(&block_id).unwrap();
+    target.instructions.push(new_id);
+}
+
+/// Gives a temporary identifier a promoted name of the form `#t{declaration id}`.
+fn promote_temporary(env: &mut Environment, identifier_id: IdentifierId) {
+    let ident = &mut env.identifiers[identifier_id.0 as usize];
+    let promoted = format!("#t{}", ident.declaration_id.0);
+    ident.name = Some(IdentifierName::Promoted(promoted));
+}
diff --git a/crates/react_compiler_optimization/src/lib.rs b/crates/react_compiler_optimization/src/lib.rs
new file mode 100644
index 000000000000..0fd63584ac31
--- /dev/null
+++ b/crates/react_compiler_optimization/src/lib.rs
@@ -0,0 +1,28 @@
+// Vendored from facebook/react#36173. Keep upstream style intact to reduce
+// merge drift.
+#![allow(clippy::all)] + +pub mod constant_propagation; +pub mod dead_code_elimination; +pub mod drop_manual_memoization; +pub mod inline_iifes; +pub mod merge_consecutive_blocks; +pub mod name_anonymous_functions; +pub mod optimize_for_ssr; +pub mod optimize_props_method_calls; +pub mod outline_functions; +pub mod outline_jsx; +pub mod prune_maybe_throws; +pub mod prune_unused_labels_hir; + +pub use constant_propagation::constant_propagation; +pub use dead_code_elimination::dead_code_elimination; +pub use drop_manual_memoization::drop_manual_memoization; +pub use inline_iifes::inline_immediately_invoked_function_expressions; +pub use name_anonymous_functions::name_anonymous_functions; +pub use optimize_for_ssr::optimize_for_ssr; +pub use optimize_props_method_calls::optimize_props_method_calls; +pub use outline_functions::outline_functions; +pub use outline_jsx::outline_jsx; +pub use prune_maybe_throws::prune_maybe_throws; +pub use prune_unused_labels_hir::prune_unused_labels_hir; diff --git a/crates/react_compiler_optimization/src/merge_consecutive_blocks.rs b/crates/react_compiler_optimization/src/merge_consecutive_blocks.rs new file mode 100644 index 000000000000..820971283409 --- /dev/null +++ b/crates/react_compiler_optimization/src/merge_consecutive_blocks.rs @@ -0,0 +1,214 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Merges sequences of blocks that will always execute consecutively — +//! i.e., where the predecessor always transfers control to the successor +//! (ends in a goto) and where the predecessor is the only predecessor +//! for that successor (no other way to reach the successor). +//! +//! Value/loop blocks are left alone because they cannot be merged without +//! breaking the structure of the high-level terminals that reference them. +//! +//! Analogous to TS `HIR/MergeConsecutiveBlocks.ts`. 
+
+use std::collections::{HashMap, HashSet};
+
+use react_compiler_hir::{
+    visitors, AliasingEffect, BlockId, BlockKind, Effect, HirFunction, Instruction, InstructionId,
+    InstructionValue, Place, Terminal, GENERATED_SOURCE,
+};
+use react_compiler_lowering::mark_predecessors;
+use react_compiler_ssa::enter_ssa::placeholder_function;
+
+/// Merge consecutive blocks in the function's CFG, including inner functions.
+///
+/// `functions` is the arena of lowered inner functions: every
+/// `FunctionExpression` / `ObjectMethod` instruction found in `func` is looked
+/// up there by index and processed recursively before `func` itself.
+///
+/// # Panics
+///
+/// Panics if a block selected for merging carries a phi whose operand count is
+/// not exactly one.
+pub fn merge_consecutive_blocks(func: &mut HirFunction, functions: &mut [HirFunction]) {
+    // Collect inner function IDs for recursive processing
+    let inner_func_ids: Vec<usize> = func
+        .body
+        .blocks
+        .values()
+        .flat_map(|block| block.instructions.iter())
+        .filter_map(|instr_id| {
+            let instr = &func.instructions[instr_id.0 as usize];
+            match &instr.value {
+                InstructionValue::FunctionExpression { lowered_func, .. }
+                | InstructionValue::ObjectMethod { lowered_func, .. } => {
+                    Some(lowered_func.func.0 as usize)
+                }
+                _ => None,
+            }
+        })
+        .collect();
+
+    // Recursively merge consecutive blocks in inner functions
+    for func_id in inner_func_ids {
+        // Use std::mem::replace to temporarily take the inner function out,
+        // process it, then put it back (standard borrow checker workaround)
+        let mut inner_func = std::mem::replace(&mut functions[func_id], placeholder_function());
+        merge_consecutive_blocks(&mut inner_func, functions);
+        functions[func_id] = inner_func;
+    }
+
+    // Build fallthrough set
+    let mut fallthrough_blocks: HashSet<BlockId> = HashSet::new();
+    for block in func.body.blocks.values() {
+        if let Some(ft) = visitors::terminal_fallthrough(&block.terminal) {
+            fallthrough_blocks.insert(ft);
+        }
+    }
+
+    let mut merged = MergedBlocks::new();
+
+    // Collect block IDs for iteration (since we modify during iteration)
+    let block_ids: Vec<BlockId> = func.body.blocks.keys().copied().collect();
+
+    for block_id in &block_ids {
+        let block = match func.body.blocks.get(block_id) {
+            Some(b) => b,
+            None => continue, // already removed
+        };
+
+        if block.preds.len() != 1
+            || block.kind != BlockKind::Block
+            || fallthrough_blocks.contains(block_id)
+        {
+            continue;
+        }
+
+        // The recorded predecessor may itself have been merged away earlier in
+        // this loop, so resolve it to the block it now lives in.
+        let original_pred_id = *block.preds.iter().next().unwrap();
+        let pred_id = merged.get(original_pred_id);
+
+        // Check predecessor exists and ends in goto with block kind
+        let pred_is_mergeable = func
+            .body
+            .blocks
+            .get(&pred_id)
+            .map(|p| matches!(p.terminal, Terminal::Goto { .. }) && p.kind == BlockKind::Block)
+            .unwrap_or(false);
+
+        if !pred_is_mergeable {
+            continue;
+        }
+
+        // Get evaluation order from predecessor's terminal (for phi instructions)
+        let eval_order = func.body.blocks[&pred_id].terminal.evaluation_order();
+
+        // Collect phi data from the block being merged
+        let phis: Vec<_> = block
+            .phis
+            .iter()
+            .map(|phi| {
+                assert_eq!(
+                    phi.operands.len(),
+                    1,
+                    "Found a block with a single predecessor but where a phi has multiple ({}) \
+                     operands",
+                    phi.operands.len()
+                );
+                let operand = phi.operands.values().next().unwrap().clone();
+                (phi.place.identifier, operand)
+            })
+            .collect();
+        let block_instr_ids = block.instructions.clone();
+        let block_terminal = block.terminal.clone();
+
+        // Create phi instructions and add to instruction table
+        let mut new_instr_ids = Vec::new();
+        for (identifier, operand) in phis {
+            let lvalue = Place {
+                identifier,
+                effect: Effect::ConditionallyMutate,
+                reactive: false,
+                loc: GENERATED_SOURCE,
+            };
+            let instr = Instruction {
+                id: eval_order,
+                lvalue: lvalue.clone(),
+                value: InstructionValue::LoadLocal {
+                    place: operand.clone(),
+                    loc: GENERATED_SOURCE,
+                },
+                loc: GENERATED_SOURCE,
+                effects: Some(vec![AliasingEffect::Alias {
+                    from: operand,
+                    into: lvalue,
+                }]),
+            };
+            let instr_id = InstructionId(func.instructions.len() as u32);
+            func.instructions.push(instr);
+            new_instr_ids.push(instr_id);
+        }
+
+        // Apply merge to predecessor
+        let pred = func.body.blocks.get_mut(&pred_id).unwrap();
+        pred.instructions.extend(new_instr_ids);
+        pred.instructions.extend(block_instr_ids);
+        pred.terminal = block_terminal;
+
+        // Record merge and remove block
+        merged.merge(*block_id, pred_id);
+        func.body.blocks.shift_remove(block_id);
+    }
+
+    // Update phi operands for merged blocks
+    for block in func.body.blocks.values_mut() {
+        for phi in &mut block.phis {
+            let updates: Vec<_> = phi
+                .operands
+                .iter()
+                .filter_map(|(pred_id, operand)| {
+                    let mapped = merged.get(*pred_id);
+                    if mapped != *pred_id {
+                        Some((*pred_id, mapped, operand.clone()))
+                    } else {
+                        None
+                    }
+                })
+                .collect();
+            for (old_id, new_id, operand) in updates {
+                phi.operands.shift_remove(&old_id);
+                phi.operands.insert(new_id, operand);
+            }
+        }
+    }
+
+    mark_predecessors(&mut func.body);
+
+    // Update terminal successors (including fallthroughs) for merged blocks
+    for block in func.body.blocks.values_mut() {
+        visitors::map_terminal_successors(&mut block.terminal, &mut |block_id| {
+            merged.get(block_id)
+        });
+    }
+}
+
+/// Tracks which blocks have been merged and into which target.
+struct MergedBlocks {
+    /// Maps a removed block to the block it was merged into.
+    map: HashMap<BlockId, BlockId>,
+}
+
+impl MergedBlocks {
+    /// Creates a tracker with no merges recorded.
+    fn new() -> Self {
+        Self {
+            map: HashMap::new(),
+        }
+    }
+
+    /// Record that `block` was merged into `into`.
+    fn merge(&mut self, block: BlockId, into: BlockId) {
+        // Resolve `into` first so the stored target is already final.
+        let target = self.get(into);
+        self.map.insert(block, target);
+    }
+
+    /// Get the id of the block that `block` has been merged into.
+    /// Transitive: if A merged into B which merged into C, get(A) returns C.
+    /// Returns `block` itself when it was never merged.
+    fn get(&self, block: BlockId) -> BlockId {
+        let mut current = block;
+        while let Some(&target) = self.map.get(&current) {
+            current = target;
+        }
+        current
+    }
+}
diff --git a/crates/react_compiler_optimization/src/name_anonymous_functions.rs b/crates/react_compiler_optimization/src/name_anonymous_functions.rs
new file mode 100644
index 000000000000..79a7427915cf
--- /dev/null
+++ b/crates/react_compiler_optimization/src/name_anonymous_functions.rs
@@ -0,0 +1,317 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Port of NameAnonymousFunctions from TypeScript. +//! +//! Generates descriptive names for anonymous function expressions based on +//! how they are used (assigned to variables, passed as arguments to +//! hooks/functions, used as JSX props, etc.). These names appear in React +//! DevTools and error stacks. +//! +//! Conditional on `env.config.enable_name_anonymous_functions`. + +use std::collections::HashMap; + +use react_compiler_hir::{ + environment::Environment, object_shape::HookKind, FunctionId, HirFunction, IdentifierId, + IdentifierName, Instruction, InstructionValue, JsxAttribute, JsxTag, PlaceOrSpread, +}; + +/// Assign generated names to anonymous function expressions. +/// +/// Ported from TS `nameAnonymousFunctions` in +/// `Transform/NameAnonymousFunctions.ts`. +pub fn name_anonymous_functions(func: &mut HirFunction, env: &mut Environment) { + let fn_id = match &func.id { + Some(id) => id.clone(), + None => return, + }; + + let nodes = name_anonymous_functions_impl(func, env); + + fn visit(node: &Node, prefix: &str, updates: &mut Vec<(FunctionId, String)>) { + if node.generated_name.is_some() && node.existing_name_hint.is_none() { + // Only add the prefix to anonymous functions regardless of nesting depth + let name = format!("{}{}]", prefix, node.generated_name.as_ref().unwrap()); + updates.push((node.function_id, name)); + } + // Whether or not we generated a name for the function at this node, + // traverse into its nested functions to assign them names + let fallback; + let label = if let Some(ref gen_name) = node.generated_name { + gen_name.as_str() + } else if let Some(ref existing) = node.fn_name { + existing.as_str() + } else { + fallback = ""; + fallback + }; + let next_prefix = format!("{}{} > ", prefix, label); + for inner in &node.inner { + visit(inner, &next_prefix, updates); + } + } + + let mut 
updates: Vec<(FunctionId, String)> = Vec::new(); + let prefix = format!("{}[", fn_id); + for node in &nodes { + visit(node, &prefix, &mut updates); + } + + if updates.is_empty() { + return; + } + let update_map: HashMap = + updates.iter().map(|(fid, name)| (*fid, name)).collect(); + + // Apply name updates to the inner HirFunction in the arena + for (function_id, name) in &updates { + env.functions[function_id.0 as usize].name_hint = Some(name.clone()); + } + + // Update name_hint on FunctionExpression instruction values in the outer + // function + apply_name_hints_to_instructions(&mut func.instructions, &update_map); + + // Update name_hint on FunctionExpression instruction values in all arena + // functions + for i in 0..env.functions.len() { + // We need to temporarily take the instructions to avoid borrow issues + let mut instructions = std::mem::take(&mut env.functions[i].instructions); + apply_name_hints_to_instructions(&mut instructions, &update_map); + env.functions[i].instructions = instructions; + } +} + +/// Apply name hints to FunctionExpression instruction values. +fn apply_name_hints_to_instructions( + instructions: &mut [Instruction], + update_map: &HashMap, +) { + for instr in instructions.iter_mut() { + if let InstructionValue::FunctionExpression { + lowered_func, + name_hint, + .. 
+ } = &mut instr.value + { + if let Some(new_name) = update_map.get(&lowered_func.func) { + *name_hint = Some((*new_name).clone()); + } + } + } +} + +struct Node { + /// The FunctionId for the inner function (via lowered_func.func) + function_id: FunctionId, + /// The generated name for this anonymous function (set based on usage + /// context) + generated_name: Option, + /// The existing `name` on the FunctionExpression (non-anonymous functions + /// have this) + fn_name: Option, + /// Whether the inner HirFunction already has a name_hint + existing_name_hint: Option, + /// Nested function nodes + inner: Vec, +} + +fn name_anonymous_functions_impl(func: &HirFunction, env: &Environment) -> Vec { + // Functions that we track to generate names for + let mut functions: HashMap = HashMap::new(); + // Tracks temporaries that read from variables/globals/properties + let mut names: HashMap = HashMap::new(); + // Tracks all function nodes + let mut nodes: Vec = Vec::new(); + + for block in func.body.blocks.values() { + for instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + match &instr.value { + InstructionValue::LoadGlobal { binding, .. } => { + names.insert(lvalue_id, binding.name().to_string()); + } + InstructionValue::LoadContext { place, .. } + | InstructionValue::LoadLocal { place, .. } => { + let ident = &env.identifiers[place.identifier.0 as usize]; + if let Some(IdentifierName::Named(ref name)) = ident.name { + names.insert(lvalue_id, name.clone()); + } + // If the loaded place was tracked as a function, propagate + if let Some(&node_idx) = functions.get(&place.identifier) { + functions.insert(lvalue_id, node_idx); + } + } + InstructionValue::PropertyLoad { + object, property, .. 
+ } => { + if let Some(object_name) = names.get(&object.identifier) { + names.insert(lvalue_id, format!("{}.{}", object_name, property)); + } + } + InstructionValue::FunctionExpression { + name, lowered_func, .. + } => { + let inner_func = &env.functions[lowered_func.func.0 as usize]; + let inner = name_anonymous_functions_impl(inner_func, env); + let node = Node { + function_id: lowered_func.func, + generated_name: None, + fn_name: name.clone(), + existing_name_hint: inner_func.name_hint.clone(), + inner, + }; + let idx = nodes.len(); + nodes.push(node); + if name.is_none() { + // Only generate names for anonymous functions + functions.insert(lvalue_id, idx); + } + } + InstructionValue::StoreContext { + lvalue: store_lvalue, + value, + .. + } + | InstructionValue::StoreLocal { + lvalue: store_lvalue, + value, + .. + } => { + if let Some(&node_idx) = functions.get(&value.identifier) { + let node = &mut nodes[node_idx]; + let var_ident = &env.identifiers[store_lvalue.place.identifier.0 as usize]; + if node.generated_name.is_none() { + if let Some(IdentifierName::Named(ref var_name)) = var_ident.name { + node.generated_name = Some(var_name.clone()); + functions.remove(&value.identifier); + } + } + } + } + InstructionValue::CallExpression { callee, args, .. } => { + handle_call( + env, + func, + callee.identifier, + args, + &mut functions, + &names, + &mut nodes, + ); + } + InstructionValue::MethodCall { property, args, .. } => { + handle_call( + env, + func, + property.identifier, + args, + &mut functions, + &names, + &mut nodes, + ); + } + InstructionValue::JsxExpression { tag, props, .. } => { + for attr in props { + match attr { + JsxAttribute::SpreadAttribute { .. 
} => continue, + JsxAttribute::Attribute { + name: attr_name, + place, + } => { + if let Some(&node_idx) = functions.get(&place.identifier) { + let node = &mut nodes[node_idx]; + if node.generated_name.is_none() { + let element_name = match tag { + JsxTag::Builtin(builtin) => Some(builtin.name.clone()), + JsxTag::Place(tag_place) => { + names.get(&tag_place.identifier).cloned() + } + }; + let prop_name = match element_name { + None => attr_name.clone(), + Some(ref el_name) => { + format!("<{}>.{}", el_name, attr_name) + } + }; + node.generated_name = Some(prop_name); + functions.remove(&place.identifier); + } + } + } + } + } + } + _ => {} + } + } + } + + nodes +} + +/// Handle CallExpression / MethodCall to generate names for function arguments. +fn handle_call( + env: &Environment, + _func: &HirFunction, + callee_id: IdentifierId, + args: &[PlaceOrSpread], + functions: &mut HashMap, + names: &HashMap, + nodes: &mut Vec, +) { + let callee_ident = &env.identifiers[callee_id.0 as usize]; + let callee_ty = &env.types[callee_ident.type_.0 as usize]; + let hook_kind = env.get_hook_kind_for_type(callee_ty).ok().flatten(); + + let callee_name: String = if let Some(hk) = hook_kind { + if *hk != HookKind::Custom { + hk.to_string() + } else { + names + .get(&callee_id) + .cloned() + .unwrap_or_else(|| "(anonymous)".to_string()) + } + } else { + names + .get(&callee_id) + .cloned() + .unwrap_or_else(|| "(anonymous)".to_string()) + }; + + // Count how many args are tracked functions + let fn_arg_count = args + .iter() + .filter(|arg| { + if let PlaceOrSpread::Place(p) = arg { + functions.contains_key(&p.identifier) + } else { + false + } + }) + .count(); + + for (i, arg) in args.iter().enumerate() { + let place = match arg { + PlaceOrSpread::Spread(_) => continue, + PlaceOrSpread::Place(p) => p, + }; + if let Some(&node_idx) = functions.get(&place.identifier) { + let node = &mut nodes[node_idx]; + if node.generated_name.is_none() { + let generated_name = if fn_arg_count > 1 { 
+ format!("{}(arg{})", callee_name, i) + } else { + format!("{}()", callee_name) + }; + node.generated_name = Some(generated_name); + functions.remove(&place.identifier); + } + } + } +} diff --git a/crates/react_compiler_optimization/src/optimize_for_ssr.rs b/crates/react_compiler_optimization/src/optimize_for_ssr.rs new file mode 100644 index 000000000000..4e812d24725a --- /dev/null +++ b/crates/react_compiler_optimization/src/optimize_for_ssr.rs @@ -0,0 +1,359 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Optimizes the code for running in an SSR environment. +//! +//! Assumes that setState will not be called during render during initial mount, +//! which allows inlining useState/useReducer. +//! +//! Optimizations: +//! - Inline useState/useReducer +//! - Remove effects (useEffect, useLayoutEffect, useInsertionEffect) +//! - Remove event handlers (functions that call setState or startTransition) +//! - Remove known event handler props and ref props from builtin JSX tags +//! - Inline useEffectEvent to its argument +//! +//! Ported from TypeScript `src/Optimization/OptimizeForSSR.ts`. + +use std::collections::HashMap; + +use react_compiler_hir::{ + environment::Environment, + is_set_state_type, is_start_transition_type, + object_shape::HookKind, + visitors::{each_instruction_value_operand, each_terminal_operand}, + ArrayPatternElement, HirFunction, IdentifierId, InstructionValue, PlaceOrSpread, + PrimitiveValue, +}; + +/// Optimizes a function for SSR by inlining state hooks, removing effects, +/// removing event handlers, and stripping known event handler / ref JSX props. +/// +/// Corresponds to TS `optimizeForSSR(fn: HIRFunction): void`. +pub fn optimize_for_ssr(func: &mut HirFunction, env: &Environment) { + // Phase 1: Identify useState/useReducer calls that can be safely inlined. 
+ // + // For useState(initialValue) where initialValue is primitive/object/array, + // store a LoadLocal of the initial value. + // + // For useReducer(reducer, initialArg) store a LoadLocal of initialArg. + // For useReducer(reducer, initialArg, init) store a CallExpression of + // init(initialArg). + // + // Any use of the hook return other than the expected destructuring pattern + // prevents inlining (we delete from inlined_state if we see the identifier used + // as an operand elsewhere). + let mut inlined_state: HashMap = HashMap::new(); + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::Destructure { value, lvalue, .. } => { + if inlined_state.contains_key(&env.identifiers[value.identifier.0 as usize].id) + { + if let react_compiler_hir::Pattern::Array(arr) = &lvalue.pattern { + if !arr.items.is_empty() { + if let ArrayPatternElement::Place(_) = &arr.items[0] { + // Allow destructuring of inlined states + continue; + } + } + } + } + } + InstructionValue::MethodCall { property, args, .. } + | InstructionValue::CallExpression { + callee: property, + args, + .. 
+ } => { + // Determine callee based on instruction kind + let callee_id = property.identifier; + let hook_kind = get_hook_kind(env, callee_id); + match hook_kind { + Some(HookKind::UseReducer) => { + if args.len() == 2 { + if let (PlaceOrSpread::Place(_), PlaceOrSpread::Place(arg)) = + (&args[0], &args[1]) + { + let lvalue_id = + env.identifiers[instr.lvalue.identifier.0 as usize].id; + inlined_state.insert( + lvalue_id, + InlinedStateReplacement::LoadLocal { + place: arg.clone(), + loc: arg.loc, + }, + ); + } + } else if args.len() == 3 { + if let ( + PlaceOrSpread::Place(_), + PlaceOrSpread::Place(arg), + PlaceOrSpread::Place(initializer), + ) = (&args[0], &args[1], &args[2]) + { + let lvalue_id = + env.identifiers[instr.lvalue.identifier.0 as usize].id; + let call_loc = instr.value.loc().copied(); + inlined_state.insert( + lvalue_id, + InlinedStateReplacement::CallExpression { + callee: initializer.clone(), + arg: arg.clone(), + loc: call_loc, + }, + ); + } + } + } + Some(HookKind::UseState) => { + if args.len() == 1 { + if let PlaceOrSpread::Place(arg) = &args[0] { + let arg_type = &env.types[env.identifiers + [arg.identifier.0 as usize] + .type_ + .0 + as usize]; + if react_compiler_hir::is_primitive_type(arg_type) + || react_compiler_hir::is_plain_object_type(arg_type) + || react_compiler_hir::is_array_type(arg_type) + { + let lvalue_id = + env.identifiers[instr.lvalue.identifier.0 as usize].id; + inlined_state.insert( + lvalue_id, + InlinedStateReplacement::LoadLocal { + place: arg.clone(), + loc: arg.loc, + }, + ); + } + } + } + } + _ => {} + } + } + _ => {} + } + + // Any use of useState/useReducer return besides destructuring prevents inlining + if !inlined_state.is_empty() { + let operands = each_instruction_value_operand(&instr.value, env); + for operand in &operands { + let id = env.identifiers[operand.identifier.0 as usize].id; + inlined_state.remove(&id); + } + } + } + if !inlined_state.is_empty() { + let operands = 
each_terminal_operand(&block.terminal); + for operand in &operands { + let id = env.identifiers[operand.identifier.0 as usize].id; + inlined_state.remove(&id); + } + } + } + + // Phase 2: Apply transformations + // + // - Replace FunctionExpression with Primitive(undefined) if it calls + // setState/startTransition + // - Remove known event handler props and ref props from builtin JSX tags + // - Replace Destructure of inlined state with StoreLocal + // - Replace useEffectEvent(fn) with LoadLocal(fn) + // - Replace useEffect/useLayoutEffect/useInsertionEffect with + // Primitive(undefined) + // - Replace useState/useReducer with their inlined replacement + for (_block_id, block) in &mut func.body.blocks { + for &instr_id in &block.instructions { + let instr = &mut func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { + lowered_func, loc, .. + } => { + let inner_func = &env.functions[lowered_func.func.0 as usize]; + if has_known_non_render_call(inner_func, env) { + let loc = *loc; + instr.value = InstructionValue::Primitive { + value: PrimitiveValue::Undefined, + loc, + }; + } + } + InstructionValue::JsxExpression { tag, .. } => { + if let react_compiler_hir::JsxTag::Builtin(builtin) = tag { + // Only optimize non-custom-element builtin tags + if !builtin.name.contains('-') { + let tag_name = builtin.name.clone(); + // Retain only props that are not known event handlers and not "ref" + if let InstructionValue::JsxExpression { props, .. } = &mut instr.value + { + props.retain(|prop| match prop { + react_compiler_hir::JsxAttribute::SpreadAttribute { + .. + } => true, + react_compiler_hir::JsxAttribute::Attribute { + name, .. 
+ } => !is_known_event_handler(&tag_name, name) && name != "ref", + }); + } + } + } + } + InstructionValue::Destructure { value, lvalue, loc } => { + let value_id = env.identifiers[value.identifier.0 as usize].id; + if inlined_state.contains_key(&value_id) { + // Invariant: destructuring pattern must be ArrayPattern with at least one + // Identifier item + if let react_compiler_hir::Pattern::Array(arr) = &lvalue.pattern { + if !arr.items.is_empty() { + if let ArrayPatternElement::Place(first_place) = &arr.items[0] { + let loc = *loc; + let kind = lvalue.kind; + let store = InstructionValue::StoreLocal { + lvalue: react_compiler_hir::LValue { + place: first_place.clone(), + kind, + }, + value: value.clone(), + type_annotation: None, + loc, + }; + instr.value = store; + } + } + } + } + } + InstructionValue::MethodCall { + property, + args, + loc, + .. + } + | InstructionValue::CallExpression { + callee: property, + args, + loc, + .. + } => { + let callee_id = property.identifier; + let hook_kind = get_hook_kind(env, callee_id); + match hook_kind { + Some(HookKind::UseEffectEvent) => { + if args.len() == 1 { + if let PlaceOrSpread::Place(arg) = &args[0] { + let loc = *loc; + instr.value = InstructionValue::LoadLocal { + place: arg.clone(), + loc, + }; + } + } + } + Some( + HookKind::UseEffect + | HookKind::UseLayoutEffect + | HookKind::UseInsertionEffect, + ) => { + let loc = *loc; + instr.value = InstructionValue::Primitive { + value: PrimitiveValue::Undefined, + loc, + }; + } + Some(HookKind::UseReducer | HookKind::UseState) => { + let lvalue_id = env.identifiers[instr.lvalue.identifier.0 as usize].id; + if let Some(replacement) = inlined_state.get(&lvalue_id) { + instr.value = match replacement { + InlinedStateReplacement::LoadLocal { place, loc } => { + InstructionValue::LoadLocal { + place: place.clone(), + loc: *loc, + } + } + InlinedStateReplacement::CallExpression { + callee, + arg, + loc, + } => InstructionValue::CallExpression { + callee: callee.clone(), + 
args: vec![PlaceOrSpread::Place(arg.clone())], + loc: *loc, + }, + }; + } + } + _ => {} + } + } + _ => {} + } + } + } +} + +/// Replacement values for inlined useState/useReducer calls. +#[derive(Debug, Clone)] +enum InlinedStateReplacement { + /// Replace with `LoadLocal { place }` — used for useState and + /// useReducer(reducer, initialArg) + LoadLocal { + place: react_compiler_hir::Place, + loc: Option, + }, + /// Replace with `CallExpression { callee, args: [arg] }` — used for + /// useReducer(reducer, initialArg, init) + CallExpression { + callee: react_compiler_hir::Place, + arg: react_compiler_hir::Place, + loc: Option, + }, +} + +/// Returns true if the function body contains a call to setState or +/// startTransition. This identifies functions that are event handlers and can +/// be replaced with undefined during SSR. +/// +/// Corresponds to TS `hasKnownNonRenderCall(fn: HIRFunction): boolean`. +fn has_known_non_render_call(func: &HirFunction, env: &Environment) -> bool { + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + if let InstructionValue::CallExpression { callee, .. } = &instr.value { + let callee_type = + &env.types[env.identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if is_set_state_type(callee_type) || is_start_transition_type(callee_type) { + return true; + } + } + } + } + false +} + +/// Returns true if the prop name matches the known event handler pattern +/// `on[A-Z]`. +fn is_known_event_handler(_tag: &str, prop: &str) -> bool { + if prop.len() < 3 { + return false; + } + if !prop.starts_with("on") { + return false; + } + let third_char = prop.as_bytes()[2]; + third_char.is_ascii_uppercase() +} + +/// Get the hook kind for an identifier, if its type represents a hook. 
+fn get_hook_kind(env: &Environment, identifier_id: IdentifierId) -> Option { + env.get_hook_kind_for_id(identifier_id) + .ok() + .flatten() + .cloned() +} diff --git a/crates/react_compiler_optimization/src/optimize_props_method_calls.rs b/crates/react_compiler_optimization/src/optimize_props_method_calls.rs new file mode 100644 index 000000000000..1681335a4ba5 --- /dev/null +++ b/crates/react_compiler_optimization/src/optimize_props_method_calls.rs @@ -0,0 +1,55 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Converts `MethodCall` instructions on props objects into `CallExpression` +//! instructions. +//! +//! When the receiver of a method call is typed as the component's props object, +//! we can safely convert the method call `props.foo(args)` into a direct call +//! `foo(args)` using the property as the callee. This simplifies downstream +//! analysis by removing the receiver dependency. +//! +//! Analogous to TS `Optimization/OptimizePropsMethodCalls.ts`. + +use react_compiler_hir::{environment::Environment, is_props_type, HirFunction, InstructionValue}; + +pub fn optimize_props_method_calls(func: &mut HirFunction, env: &Environment) { + for (_block_id, block) in &func.body.blocks { + let instruction_ids: Vec<_> = block.instructions.clone(); + for instr_id in instruction_ids { + let instr = &mut func.instructions[instr_id.0 as usize]; + let should_replace = matches!( + &instr.value, + InstructionValue::MethodCall { receiver, .. } + if { + let identifier = &env.identifiers[receiver.identifier.0 as usize]; + let ty = &env.types[identifier.type_.0 as usize]; + is_props_type(ty) + } + ); + if should_replace { + // Take the old value out, replacing with a temporary. + // The if-let is guaranteed to match since we checked above. 
+ let old = + std::mem::replace(&mut instr.value, InstructionValue::Debugger { loc: None }); + match old { + InstructionValue::MethodCall { + property, + args, + loc, + .. + } => { + instr.value = InstructionValue::CallExpression { + callee: property, + args, + loc, + }; + } + _ => unreachable!(), + } + } + } + } +} diff --git a/crates/react_compiler_optimization/src/outline_functions.rs b/crates/react_compiler_optimization/src/outline_functions.rs new file mode 100644 index 000000000000..813e05b11270 --- /dev/null +++ b/crates/react_compiler_optimization/src/outline_functions.rs @@ -0,0 +1,132 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Port of OutlineFunctions from TypeScript +//! (`Optimization/OutlineFunctions.ts`). +//! +//! Extracts anonymous function expressions that do not close over any local +//! variables into top-level outlined functions. The original instruction is +//! replaced with a `LoadGlobal` referencing the outlined function's generated +//! name. +//! +//! Conditional on `env.config.enable_function_outlining`. + +use std::collections::HashSet; + +use react_compiler_hir::{ + environment::Environment, FunctionId, HirFunction, IdentifierId, InstructionValue, + NonLocalBinding, +}; +use react_compiler_ssa::enter_ssa::placeholder_function; + +/// Outline anonymous function expressions that have no captured context +/// variables. +/// +/// Ported from TS `outlineFunctions` in `Optimization/OutlineFunctions.ts`. +pub fn outline_functions( + func: &mut HirFunction, + env: &mut Environment, + fbt_operands: &HashSet, +) { + // Collect per-instruction actions to maintain depth-first name allocation + // order. 
Each entry: (instr index, function_id to recurse into, + // should_outline) + enum Action { + /// Recurse into an inner function (FunctionExpression or ObjectMethod) + Recurse(FunctionId), + /// Recurse then outline a FunctionExpression + RecurseAndOutline { + instr_idx: usize, + function_id: FunctionId, + }, + } + + let mut actions: Vec = Vec::new(); + + for block in func.body.blocks.values() { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } => { + let inner_func = &env.functions[lowered_func.func.0 as usize]; + + // Check outlining conditions (TS only checks func.id === null, not name): + // 1. No captured context variables + // 2. Anonymous (no explicit id on the inner function) + // 3. Not an fbt operand + if inner_func.context.is_empty() + && inner_func.id.is_none() + && !fbt_operands.contains(&lvalue_id) + { + actions.push(Action::RecurseAndOutline { + instr_idx: instr_id.0 as usize, + function_id: lowered_func.func, + }); + } else { + actions.push(Action::Recurse(lowered_func.func)); + } + } + InstructionValue::ObjectMethod { lowered_func, .. } => { + // Recurse into object methods (but don't outline them) + actions.push(Action::Recurse(lowered_func.func)); + } + _ => {} + } + } + } + + // Process actions sequentially: for each instruction, recurse first + // (depth-first), then generate name and outline. This matches TS ordering + // where inner functions get names allocated before outer ones. 
+ for action in actions { + match action { + Action::Recurse(function_id) => { + let mut inner_func = std::mem::replace( + &mut env.functions[function_id.0 as usize], + placeholder_function(), + ); + outline_functions(&mut inner_func, env, fbt_operands); + env.functions[function_id.0 as usize] = inner_func; + } + Action::RecurseAndOutline { + instr_idx, + function_id, + } => { + // First recurse into the inner function (depth-first) + let mut inner_func = std::mem::replace( + &mut env.functions[function_id.0 as usize], + placeholder_function(), + ); + outline_functions(&mut inner_func, env, fbt_operands); + env.functions[function_id.0 as usize] = inner_func; + + // Then generate the name and outline (after recursion, matching TS order) + let hint: Option = env.functions[function_id.0 as usize] + .id + .clone() + .or_else(|| env.functions[function_id.0 as usize].name_hint.clone()); + let generated_name = env.generate_globally_unique_identifier_name(hint.as_deref()); + + // Set the id on the inner function + env.functions[function_id.0 as usize].id = Some(generated_name.clone()); + + // Outline the function + let outlined_func = env.functions[function_id.0 as usize].clone(); + env.outline_function(outlined_func, None); + + // Replace the instruction value with LoadGlobal + let loc = func.instructions[instr_idx].value.loc().cloned(); + func.instructions[instr_idx].value = InstructionValue::LoadGlobal { + binding: NonLocalBinding::Global { + name: generated_name, + }, + loc, + }; + } + } + } +} diff --git a/crates/react_compiler_optimization/src/outline_jsx.rs b/crates/react_compiler_optimization/src/outline_jsx.rs new file mode 100644 index 000000000000..ff4741196086 --- /dev/null +++ b/crates/react_compiler_optimization/src/outline_jsx.rs @@ -0,0 +1,673 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! 
Port of OutlineJsx from TypeScript. +//! +//! Outlines JSX expressions in callbacks into separate component functions. +//! This pass is conditional on `env.config.enable_jsx_outlining` (defaults to +//! false). + +use std::collections::{HashMap, HashSet}; + +use indexmap::IndexMap; +use react_compiler_hir::{ + environment::Environment, BasicBlock, BlockId, BlockKind, EvaluationOrder, FunctionId, + HirFunction, IdentifierId, IdentifierName, Instruction, InstructionId, InstructionKind, + InstructionValue, JsxAttribute, JsxTag, LValuePattern, NonLocalBinding, ObjectPattern, + ObjectProperty, ObjectPropertyKey, ObjectPropertyOrSpread, ObjectPropertyType, ParamPattern, + Pattern, Place, ReactFunctionType, ReturnVariant, Terminal, HIR, +}; + +/// Outline JSX expressions in inner functions into separate outlined +/// components. +/// +/// Ported from TS `outlineJSX` in `Optimization/OutlineJsx.ts`. +pub fn outline_jsx(func: &mut HirFunction, env: &mut Environment) { + let mut outlined_fns: Vec = Vec::new(); + outline_jsx_impl(func, env, &mut outlined_fns); + + for outlined_fn in outlined_fns { + env.outline_function(outlined_fn, Some(ReactFunctionType::Component)); + } +} + +/// Data about a JSX instruction for outlining +struct JsxInstrInfo { + instr_idx: usize, // index into func.instructions + #[allow(dead_code)] + instr_id: InstructionId, // the InstructionId + lvalue_id: IdentifierId, + eval_order: EvaluationOrder, +} + +struct OutlinedJsxAttribute { + original_name: String, + new_name: String, + place: Place, +} + +struct OutlinedResult { + instrs: Vec, + func: HirFunction, +} + +fn outline_jsx_impl( + func: &mut HirFunction, + env: &mut Environment, + outlined_fns: &mut Vec, +) { + // Collect LoadGlobal instructions (tag -> instr) + let mut globals: HashMap = HashMap::new(); // id -> instr_idx + + // Process each block + let block_ids: Vec = func.body.blocks.keys().copied().collect(); + for block_id in &block_ids { + let block = &func.body.blocks[block_id]; + let 
instr_ids = block.instructions.clone(); + + let mut rewrite_instr: HashMap> = HashMap::new(); + let mut jsx_group: Vec = Vec::new(); + let mut children_ids: HashSet = HashSet::new(); + + // First pass: collect all instruction info without borrowing func mutably + enum InstrAction { + LoadGlobal { + lvalue_id: IdentifierId, + instr_idx: usize, + }, + FunctionExpr { + func_id: FunctionId, + }, + JsxExpr { + lvalue_id: IdentifierId, + instr_idx: usize, + eval_order: EvaluationOrder, + child_ids: Vec, + }, + Other, + } + + let mut actions: Vec = Vec::new(); + for i in (0..instr_ids.len()).rev() { + let iid = instr_ids[i]; + let instr = &func.instructions[iid.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + + match &instr.value { + InstructionValue::LoadGlobal { .. } => { + actions.push(InstrAction::LoadGlobal { + lvalue_id, + instr_idx: iid.0 as usize, + }); + } + InstructionValue::FunctionExpression { lowered_func, .. } => { + actions.push(InstrAction::FunctionExpr { + func_id: lowered_func.func, + }); + } + InstructionValue::JsxExpression { children, .. 
} => { + let child_ids = children + .as_ref() + .map(|kids| kids.iter().map(|c| c.identifier).collect()) + .unwrap_or_default(); + actions.push(InstrAction::JsxExpr { + lvalue_id, + instr_idx: iid.0 as usize, + eval_order: instr.id, + child_ids, + }); + } + _ => { + actions.push(InstrAction::Other); + } + } + } + + // Second pass: process actions + for action in actions { + match action { + InstrAction::LoadGlobal { + lvalue_id, + instr_idx, + } => { + globals.insert(lvalue_id, instr_idx); + } + InstrAction::FunctionExpr { func_id } => { + let mut inner_func = std::mem::replace( + &mut env.functions[func_id.0 as usize], + react_compiler_ssa::enter_ssa::placeholder_function(), + ); + outline_jsx_impl(&mut inner_func, env, outlined_fns); + env.functions[func_id.0 as usize] = inner_func; + } + InstrAction::JsxExpr { + lvalue_id, + instr_idx, + eval_order, + child_ids, + } => { + if !children_ids.contains(&lvalue_id) { + process_and_outline_jsx( + func, + env, + &mut jsx_group, + &globals, + &mut rewrite_instr, + outlined_fns, + ); + jsx_group.clear(); + children_ids.clear(); + } + jsx_group.push(JsxInstrInfo { + instr_idx, + instr_id: InstructionId(instr_idx as u32), + lvalue_id, + eval_order, + }); + for child_id in child_ids { + children_ids.insert(child_id); + } + } + InstrAction::Other => {} + } + } + // Process remaining JSX group after the loop + process_and_outline_jsx( + func, + env, + &mut jsx_group, + &globals, + &mut rewrite_instr, + outlined_fns, + ); + if !rewrite_instr.is_empty() { + let block = func.body.blocks.get_mut(block_id).unwrap(); + let old_instr_ids = block.instructions.clone(); + let mut new_instr_ids = Vec::new(); + for &iid in &old_instr_ids { + let eval_order = func.instructions[iid.0 as usize].id; + if let Some(replacement_instrs) = rewrite_instr.get(&eval_order) { + // Add replacement instructions to the instruction table and reference them + for new_instr in replacement_instrs { + let new_idx = func.instructions.len(); + 
func.instructions.push(new_instr.clone()); + new_instr_ids.push(InstructionId(new_idx as u32)); + } + } else { + new_instr_ids.push(iid); + } + } + let block = func.body.blocks.get_mut(block_id).unwrap(); + block.instructions = new_instr_ids; + + // Run dead code elimination after rewriting + super::dead_code_elimination(func, env); + } + } +} + +fn process_and_outline_jsx( + func: &mut HirFunction, + env: &mut Environment, + jsx_group: &mut Vec, + globals: &HashMap, + rewrite_instr: &mut HashMap>, + outlined_fns: &mut Vec, +) { + if jsx_group.len() <= 1 { + return; + } + // Sort by eval order ascending (TS: sort by a.id - b.id) + jsx_group.sort_by_key(|j| j.eval_order); + + let result = process_jsx_group(func, env, jsx_group, globals); + if let Some(result) = result { + outlined_fns.push(result.func); + // Map from the LAST JSX instruction's eval order to the replacement + // instructions In the TS code, `state.jsx.at(0)` is the first element + // pushed during reverse iteration, which is the last JSX in forward + // block order (highest eval order). After sorting by eval_order + // ascending, that's jsx_group.last(). 
+ let last_eval_order = jsx_group.last().unwrap().eval_order; + rewrite_instr.insert(last_eval_order, result.instrs); + } +} + +fn process_jsx_group( + func: &HirFunction, + env: &mut Environment, + jsx_group: &[JsxInstrInfo], + globals: &HashMap, +) -> Option { + // Only outline in callbacks, not top-level components + if func.fn_type == ReactFunctionType::Component { + return None; + } + + let props = collect_props(func, env, jsx_group)?; + + let outlined_tag = env.generate_globally_unique_identifier_name(None); + let new_instrs = emit_outlined_jsx(func, env, jsx_group, &props, &outlined_tag)?; + let outlined_fn = emit_outlined_fn(func, env, jsx_group, &props, globals)?; + + // Set the outlined function's id + let mut outlined_fn = outlined_fn; + outlined_fn.id = Some(outlined_tag); + + Some(OutlinedResult { + instrs: new_instrs, + func: outlined_fn, + }) +} + +fn collect_props( + func: &HirFunction, + env: &mut Environment, + jsx_group: &[JsxInstrInfo], +) -> Option> { + let mut id_counter = 1u32; + let mut seen: HashSet = HashSet::new(); + let mut attributes = Vec::new(); + let jsx_ids: HashSet = jsx_group.iter().map(|j| j.lvalue_id).collect(); + + let mut generate_name = |old_name: &str, _env: &mut Environment| -> String { + let mut new_name = old_name.to_string(); + while seen.contains(&new_name) { + new_name = format!("{}{}", old_name, id_counter); + id_counter += 1; + } + seen.insert(new_name.clone()); + // TS: env.programContext.addNewReference(newName) + // We don't have programContext in Rust, but this is needed for unique name + // tracking + new_name + }; + + for info in jsx_group { + let instr = &func.instructions[info.instr_idx]; + if let InstructionValue::JsxExpression { + props, children, .. + } = &instr.value + { + for attr in props { + match attr { + JsxAttribute::SpreadAttribute { .. 
} => return None, + JsxAttribute::Attribute { name, place } => { + let new_name = generate_name(name, env); + attributes.push(OutlinedJsxAttribute { + original_name: name.clone(), + new_name, + place: place.clone(), + }); + } + } + } + + if let Some(kids) = children { + for child in kids { + if jsx_ids.contains(&child.identifier) { + continue; + } + // Promote the child's identifier to a named temporary + let child_id = child.identifier; + let decl_id = env.identifiers[child_id.0 as usize].declaration_id; + if env.identifiers[child_id.0 as usize].name.is_none() { + env.identifiers[child_id.0 as usize].name = + Some(IdentifierName::Promoted(format!("#t{}", decl_id.0))); + } + + let child_name = match &env.identifiers[child_id.0 as usize].name { + Some(IdentifierName::Named(n)) => n.clone(), + Some(IdentifierName::Promoted(n)) => n.clone(), + None => format!("#t{}", decl_id.0), + }; + let new_name = generate_name("t", env); + attributes.push(OutlinedJsxAttribute { + original_name: child_name, + new_name, + place: child.clone(), + }); + } + } + } + } + + Some(attributes) +} + +fn emit_outlined_jsx( + func: &HirFunction, + env: &mut Environment, + jsx_group: &[JsxInstrInfo], + outlined_props: &[OutlinedJsxAttribute], + outlined_tag: &str, +) -> Option> { + let props: Vec = outlined_props + .iter() + .map(|p| JsxAttribute::Attribute { + name: p.new_name.clone(), + place: p.place.clone(), + }) + .collect(); + + // Create LoadGlobal for the outlined component + let load_id = env.next_identifier_id(); + // Promote it as a JSX tag temporary + let decl_id = env.identifiers[load_id.0 as usize].declaration_id; + env.identifiers[load_id.0 as usize].name = + Some(IdentifierName::Promoted(format!("#T{}", decl_id.0))); + + let load_place = Place { + identifier: load_id, + effect: react_compiler_hir::Effect::Unknown, + reactive: false, + loc: None, + }; + + let load_jsx = Instruction { + id: EvaluationOrder(0), + lvalue: load_place.clone(), + value: InstructionValue::LoadGlobal { + 
binding: NonLocalBinding::ModuleLocal { + name: outlined_tag.to_string(), + }, + loc: None, + }, + loc: None, + effects: None, + }; + + // Create the replacement JsxExpression using the last JSX instruction's lvalue + let last_info = jsx_group.last().unwrap(); + let last_instr = &func.instructions[last_info.instr_idx]; + let jsx_expr = Instruction { + id: EvaluationOrder(0), + lvalue: last_instr.lvalue.clone(), + value: InstructionValue::JsxExpression { + tag: JsxTag::Place(load_place), + props, + children: None, + loc: None, + opening_loc: None, + closing_loc: None, + }, + loc: None, + effects: None, + }; + + Some(vec![load_jsx, jsx_expr]) +} + +fn emit_outlined_fn( + func: &HirFunction, + env: &mut Environment, + jsx_group: &[JsxInstrInfo], + old_props: &[OutlinedJsxAttribute], + globals: &HashMap, +) -> Option { + let old_to_new_props = create_old_to_new_props_mapping(env, old_props); + + // Create props parameter + let props_obj_id = env.next_identifier_id(); + let decl_id = env.identifiers[props_obj_id.0 as usize].declaration_id; + env.identifiers[props_obj_id.0 as usize].name = + Some(IdentifierName::Promoted(format!("#t{}", decl_id.0))); + let props_obj = Place { + identifier: props_obj_id, + effect: react_compiler_hir::Effect::Unknown, + reactive: false, + loc: None, + }; + + // Create destructure instruction + let destructure_instr = emit_destructure_props(env, &props_obj, &old_to_new_props); + + // Emit load globals for JSX tags + let load_global_instrs = emit_load_globals(func, jsx_group, globals)?; + + // Emit updated JSX instructions + let updated_jsx_instrs = emit_updated_jsx(func, jsx_group, &old_to_new_props); + + // Build instructions list + let mut instructions = Vec::new(); + instructions.push(destructure_instr); + instructions.extend(load_global_instrs); + instructions.extend(updated_jsx_instrs); + + // Build instruction table and instruction IDs + let mut instr_table = Vec::new(); + let mut instr_ids = Vec::new(); + for instr in instructions { 
+ let idx = instr_table.len(); + instr_table.push(instr); + instr_ids.push(InstructionId(idx as u32)); + } + + // Return terminal uses the last instruction's lvalue + let last_lvalue = instr_table.last().unwrap().lvalue.clone(); + + // Create return place + let returns_id = env.next_identifier_id(); + let returns_place = Place { + identifier: returns_id, + effect: react_compiler_hir::Effect::Unknown, + reactive: false, + loc: None, + }; + + let block = BasicBlock { + kind: BlockKind::Block, + id: BlockId(0), + instructions: instr_ids, + preds: indexmap::IndexSet::new(), + terminal: Terminal::Return { + value: last_lvalue, + return_variant: ReturnVariant::Explicit, + id: EvaluationOrder(0), + loc: None, + effects: None, + }, + phis: Vec::new(), + }; + + let mut blocks = IndexMap::new(); + blocks.insert(BlockId(0), block); + + let outlined_fn = HirFunction { + id: None, + name_hint: None, + fn_type: ReactFunctionType::Other, + params: vec![ParamPattern::Place(props_obj)], + return_type_annotation: None, + returns: returns_place, + context: Vec::new(), + body: HIR { + entry: BlockId(0), + blocks, + }, + instructions: instr_table, + generator: false, + is_async: false, + directives: Vec::new(), + aliasing_effects: Some(vec![]), + loc: None, + }; + + Some(outlined_fn) +} + +fn emit_load_globals( + func: &HirFunction, + jsx_group: &[JsxInstrInfo], + globals: &HashMap, +) -> Option> { + let mut instructions = Vec::new(); + for info in jsx_group { + let instr = &func.instructions[info.instr_idx]; + if let InstructionValue::JsxExpression { tag, .. 
} = &instr.value { + if let JsxTag::Place(tag_place) = tag { + let global_instr_idx = globals.get(&tag_place.identifier)?; + instructions.push(func.instructions[*global_instr_idx].clone()); + } + } + } + Some(instructions) +} + +fn emit_updated_jsx( + func: &HirFunction, + jsx_group: &[JsxInstrInfo], + old_to_new_props: &IndexMap, +) -> Vec { + let jsx_ids: HashSet = jsx_group.iter().map(|j| j.lvalue_id).collect(); + let mut new_instrs = Vec::new(); + + for info in jsx_group { + let instr = &func.instructions[info.instr_idx]; + if let InstructionValue::JsxExpression { + tag, + props, + children, + loc, + opening_loc, + closing_loc, + } = &instr.value + { + let mut new_props = Vec::new(); + for prop in props { + // TS: invariant(prop.kind === 'JsxAttribute', ...) + // Spread attributes would have caused collectProps to return null earlier + let (name, place) = match prop { + JsxAttribute::Attribute { name, place } => (name, place), + JsxAttribute::SpreadAttribute { .. } => { + unreachable!("Expected only JsxAttribute, not spread") + } + }; + if name == "key" { + continue; + } + // TS: invariant(newProp !== undefined, ...) + let new_prop = old_to_new_props + .get(&place.identifier) + .expect("Expected a new property for identifier"); + new_props.push(JsxAttribute::Attribute { + name: new_prop.original_name.clone(), + place: new_prop.place.clone(), + }); + } + + let new_children = children.as_ref().map(|kids| { + kids.iter() + .map(|child| { + if jsx_ids.contains(&child.identifier) { + child.clone() + } else { + // TS: invariant(newChild !== undefined, ...) 
+ let new_prop = old_to_new_props + .get(&child.identifier) + .expect("Expected a new prop for child identifier"); + new_prop.place.clone() + } + }) + .collect() + }); + + new_instrs.push(Instruction { + id: instr.id, + lvalue: instr.lvalue.clone(), + value: InstructionValue::JsxExpression { + tag: tag.clone(), + props: new_props, + children: new_children, + loc: *loc, + opening_loc: *opening_loc, + closing_loc: *closing_loc, + }, + loc: instr.loc, + effects: instr.effects.clone(), + }); + } + } + + new_instrs +} + +fn create_old_to_new_props_mapping( + env: &mut Environment, + old_props: &[OutlinedJsxAttribute], +) -> IndexMap { + let mut old_to_new = IndexMap::new(); + + for old_prop in old_props { + if old_prop.original_name == "key" { + continue; + } + + let new_id = env.next_identifier_id(); + env.identifiers[new_id.0 as usize].name = + Some(IdentifierName::Named(old_prop.new_name.clone())); + + let new_place = Place { + identifier: new_id, + effect: react_compiler_hir::Effect::Unknown, + reactive: false, + loc: None, + }; + + old_to_new.insert( + old_prop.place.identifier, + OutlinedJsxAttribute { + original_name: old_prop.original_name.clone(), + new_name: old_prop.new_name.clone(), + place: new_place, + }, + ); + } + + old_to_new +} + +fn emit_destructure_props( + env: &mut Environment, + props_obj: &Place, + old_to_new_props: &IndexMap, +) -> Instruction { + let mut properties = Vec::new(); + for prop in old_to_new_props.values() { + properties.push(ObjectPropertyOrSpread::Property(ObjectProperty { + key: ObjectPropertyKey::String { + name: prop.new_name.clone(), + }, + property_type: ObjectPropertyType::Property, + place: prop.place.clone(), + })); + } + + let lvalue_id = env.next_identifier_id(); + let lvalue = Place { + identifier: lvalue_id, + effect: react_compiler_hir::Effect::Unknown, + reactive: false, + loc: None, + }; + + Instruction { + id: EvaluationOrder(0), + lvalue, + value: InstructionValue::Destructure { + lvalue: LValuePattern { + 
pattern: Pattern::Object(ObjectPattern { + properties, + loc: None, + }), + kind: InstructionKind::Let, + }, + value: props_obj.clone(), + loc: None, + }, + loc: None, + effects: None, + } +} diff --git a/crates/react_compiler_optimization/src/prune_maybe_throws.rs b/crates/react_compiler_optimization/src/prune_maybe_throws.rs new file mode 100644 index 000000000000..95916f6f4442 --- /dev/null +++ b/crates/react_compiler_optimization/src/prune_maybe_throws.rs @@ -0,0 +1,135 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Prunes `MaybeThrow` terminals for blocks that can provably never throw. +//! +//! Currently very conservative: only affects blocks with primitives or +//! array/object literals. Even a variable reference could throw due to TDZ. +//! +//! Analogous to TS `Optimization/PruneMaybeThrows.ts`. + +use std::collections::HashMap; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory, GENERATED_SOURCE, +}; +use react_compiler_hir::{BlockId, HirFunction, Instruction, InstructionValue, Terminal}; +use react_compiler_lowering::{ + get_reverse_postordered_blocks, mark_instruction_ids, remove_dead_do_while_statements, + remove_unnecessary_try_catch, remove_unreachable_for_updates, +}; + +use crate::merge_consecutive_blocks::merge_consecutive_blocks; + +/// Prune `MaybeThrow` terminals for blocks that cannot throw, then clean up the +/// CFG. +pub fn prune_maybe_throws( + func: &mut HirFunction, + functions: &mut [HirFunction], +) -> Result<(), CompilerDiagnostic> { + let terminal_mapping = prune_maybe_throws_impl(func); + if let Some(terminal_mapping) = terminal_mapping { + // If terminals have changed then blocks may have become newly unreachable. + // Re-run minification of the graph (incl reordering instruction ids). 
+ func.body.blocks = get_reverse_postordered_blocks(&func.body, &func.instructions); + remove_unreachable_for_updates(&mut func.body); + remove_dead_do_while_statements(&mut func.body); + remove_unnecessary_try_catch(&mut func.body); + mark_instruction_ids(&mut func.body, &mut func.instructions); + merge_consecutive_blocks(func, functions); + + // Rewrite phi operands to reference the updated predecessor blocks + for block in func.body.blocks.values_mut() { + let preds = &block.preds; + let mut phi_updates: Vec<(usize, Vec<(BlockId, BlockId)>)> = Vec::new(); + + for (phi_idx, phi) in block.phis.iter().enumerate() { + let mut updates = Vec::new(); + for (predecessor, _) in &phi.operands { + if !preds.contains(predecessor) { + let mapped_terminal = + terminal_mapping.get(predecessor).copied().ok_or_else(|| { + CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected non-existing phi operand's predecessor to have been \ + mapped to a new terminal", + Some(format!( + "Could not find mapping for predecessor bb{} in block bb{}", + predecessor.0, block.id.0, + )), + ) + .with_detail( + CompilerDiagnosticDetail::Error { + loc: GENERATED_SOURCE, + message: None, + identifier_name: None, + }, + ) + })?; + updates.push((*predecessor, mapped_terminal)); + } + } + if !updates.is_empty() { + phi_updates.push((phi_idx, updates)); + } + } + + for (phi_idx, updates) in phi_updates { + for (old_pred, new_pred) in updates { + let operand = block.phis[phi_idx] + .operands + .shift_remove(&old_pred) + .unwrap(); + block.phis[phi_idx].operands.insert(new_pred, operand); + } + } + } + } + Ok(()) +} + +fn prune_maybe_throws_impl(func: &mut HirFunction) -> Option> { + let mut terminal_mapping: HashMap = HashMap::new(); + let instructions = &func.instructions; + + for block in func.body.blocks.values_mut() { + let continuation = match &block.terminal { + Terminal::MaybeThrow { continuation, .. 
} => *continuation, + _ => continue, + }; + + let can_throw = block + .instructions + .iter() + .any(|instr_id| instruction_may_throw(&instructions[instr_id.0 as usize])); + + if !can_throw { + let source = terminal_mapping.get(&block.id).copied().unwrap_or(block.id); + terminal_mapping.insert(continuation, source); + // Null out the handler rather than replacing with Goto. + // Preserving the MaybeThrow makes the continuations clear for + // BuildReactiveFunction, while nulling out the handler tells us + // that control cannot flow to the handler. + if let Terminal::MaybeThrow { handler, .. } = &mut block.terminal { + *handler = None; + } + } + } + + if terminal_mapping.is_empty() { + None + } else { + Some(terminal_mapping) + } +} + +fn instruction_may_throw(instr: &Instruction) -> bool { + match &instr.value { + InstructionValue::Primitive { .. } + | InstructionValue::ArrayExpression { .. } + | InstructionValue::ObjectExpression { .. } => false, + _ => true, + } +} diff --git a/crates/react_compiler_optimization/src/prune_unused_labels_hir.rs b/crates/react_compiler_optimization/src/prune_unused_labels_hir.rs new file mode 100644 index 000000000000..ba29fdda54aa --- /dev/null +++ b/crates/react_compiler_optimization/src/prune_unused_labels_hir.rs @@ -0,0 +1,105 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Removes unused labels from the HIR. +//! +//! A label terminal whose body block immediately breaks to the label's +//! fallthrough (with no other predecessors) is effectively a no-op label. +//! This pass merges such label/body/fallthrough triples into a single block. +//! +//! Analogous to TS `PruneUnusedLabelsHIR.ts`. 
+ +use std::collections::HashMap; + +use react_compiler_hir::{BlockId, BlockKind, GotoVariant, HirFunction, Terminal}; + +pub fn prune_unused_labels_hir(func: &mut HirFunction) { + // Phase 1: Identify label terminals whose body block immediately breaks + // to the fallthrough, and both body and fallthrough are normal blocks. + let mut merged: Vec<(BlockId, BlockId, BlockId)> = Vec::new(); // (label, next, fallthrough) + + for (&block_id, block) in &func.body.blocks { + if let Terminal::Label { + block: next_id, + fallthrough: fallthrough_id, + .. + } = &block.terminal + { + let next = &func.body.blocks[next_id]; + let fallthrough = &func.body.blocks[fallthrough_id]; + if let Terminal::Goto { + block: goto_target, + variant: GotoVariant::Break, + .. + } = &next.terminal + { + if goto_target == fallthrough_id + && next.kind == BlockKind::Block + && fallthrough.kind == BlockKind::Block + { + merged.push((block_id, *next_id, *fallthrough_id)); + } + } + } + } + + // Phase 2: Apply merges + let mut rewrites: HashMap = HashMap::new(); + + for (original_label_id, next_id, fallthrough_id) in &merged { + let label_id = rewrites + .get(original_label_id) + .copied() + .unwrap_or(*original_label_id); + + // Validate: no phis in next or fallthrough + let next_phis_empty = func.body.blocks[next_id].phis.is_empty(); + let fallthrough_phis_empty = func.body.blocks[fallthrough_id].phis.is_empty(); + assert!( + next_phis_empty && fallthrough_phis_empty, + "Unexpected phis when merging label blocks" + ); + + // Validate: single predecessors + let next_preds_ok = func.body.blocks[next_id].preds.len() == 1 + && func.body.blocks[next_id].preds.contains(original_label_id); + let fallthrough_preds_ok = func.body.blocks[fallthrough_id].preds.len() == 1 + && func.body.blocks[fallthrough_id].preds.contains(next_id); + assert!( + next_preds_ok && fallthrough_preds_ok, + "Unexpected block predecessors when merging label blocks" + ); + + // Collect instructions from next and fallthrough + let 
next_instructions = func.body.blocks[next_id].instructions.clone(); + let fallthrough_instructions = func.body.blocks[fallthrough_id].instructions.clone(); + let fallthrough_terminal = func.body.blocks[fallthrough_id].terminal.clone(); + + // Merge into the label block + let label_block = func.body.blocks.get_mut(&label_id).unwrap(); + label_block.instructions.extend(next_instructions); + label_block.instructions.extend(fallthrough_instructions); + label_block.terminal = fallthrough_terminal; + + // Remove merged blocks + func.body.blocks.shift_remove(next_id); + func.body.blocks.shift_remove(fallthrough_id); + + rewrites.insert(*fallthrough_id, label_id); + } + + // Phase 3: Rewrite predecessor sets + for block in func.body.blocks.values_mut() { + let preds_to_rewrite: Vec<(BlockId, BlockId)> = block + .preds + .iter() + .filter_map(|pred| rewrites.get(pred).map(|rewritten| (*pred, *rewritten))) + .collect(); + for (old, new) in preds_to_rewrite { + block.preds.shift_remove(&old); + block.preds.insert(new); + } + } +} diff --git a/crates/react_compiler_reactive_scopes/Cargo.toml b/crates/react_compiler_reactive_scopes/Cargo.toml new file mode 100644 index 000000000000..7e0a191b79f7 --- /dev/null +++ b/crates/react_compiler_reactive_scopes/Cargo.toml @@ -0,0 +1,16 @@ +[package] +description = "Vendored React Compiler reactive scopes from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_reactive_scopes" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +react_compiler_ast = { path = "../react_compiler_ast" } +react_compiler_diagnostics = { path = "../react_compiler_diagnostics" } +react_compiler_hir = { path = "../react_compiler_hir" } +indexmap = { workspace = true } +serde_json = { workspace = true } +sha2 = { workspace = true } +hmac = "0.12" diff --git a/crates/react_compiler_reactive_scopes/src/assert_scope_instructions_within_scopes.rs 
b/crates/react_compiler_reactive_scopes/src/assert_scope_instructions_within_scopes.rs new file mode 100644 index 000000000000..9a723db0beba --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/assert_scope_instructions_within_scopes.rs @@ -0,0 +1,120 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Assert that all instructions involved in creating values for a given scope +//! are within the corresponding ReactiveScopeBlock. +//! +//! Corresponds to `src/ReactiveScopes/AssertScopeInstructionsWithinScope.ts`. + +use std::collections::HashSet; + +use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, EvaluationOrder, Place, ReactiveFunction, ReactiveScopeBlock, ScopeId, +}; + +use crate::visitors::{visit_reactive_function, ReactiveFunctionVisitor}; + +/// Assert that scope instructions are within their scopes. +/// Two-pass visitor: +/// 1. Collect all scope IDs +/// 2. 
Check that places referencing those scopes are within active scope blocks +pub fn assert_scope_instructions_within_scopes( + func: &ReactiveFunction, + env: &Environment, +) -> Result<(), CompilerDiagnostic> { + // Pass 1: Collect all scope IDs + let mut existing_scopes: HashSet = HashSet::new(); + let find_visitor = FindAllScopesVisitor { env }; + visit_reactive_function(func, &find_visitor, &mut existing_scopes); + + // Pass 2: Check instructions against scopes + let check_visitor = CheckInstructionsAgainstScopesVisitor { env }; + let mut check_state = CheckState { + existing_scopes, + active_scopes: HashSet::new(), + error: None, + }; + visit_reactive_function(func, &check_visitor, &mut check_state); + if let Some(err) = check_state.error { + return Err(err); + } + Ok(()) +} + +// ============================================================================= +// Pass 1: Find all scopes +// ============================================================================= + +struct FindAllScopesVisitor<'a> { + env: &'a Environment, +} + +impl<'a> ReactiveFunctionVisitor for FindAllScopesVisitor<'a> { + type State = HashSet; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_scope(&self, scope: &ReactiveScopeBlock, state: &mut HashSet) { + self.traverse_scope(scope, state); + state.insert(scope.scope); + } +} + +// ============================================================================= +// Pass 2: Check instructions against scopes +// ============================================================================= + +struct CheckState { + existing_scopes: HashSet, + active_scopes: HashSet, + error: Option, +} + +struct CheckInstructionsAgainstScopesVisitor<'a> { + env: &'a Environment, +} + +impl<'a> ReactiveFunctionVisitor for CheckInstructionsAgainstScopesVisitor<'a> { + type State = CheckState; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_place(&self, id: EvaluationOrder, place: &Place, state: &mut CheckState) { + // getPlaceScope: 
check if the place's identifier has a scope that is active at + // this id + let identifier = &self.env.identifiers[place.identifier.0 as usize]; + if let Some(scope_id) = identifier.scope { + let scope = &self.env.scopes[scope_id.0 as usize]; + // isScopeActive: id >= scope.range.start && id < scope.range.end + let is_active_at_id = id >= scope.range.start && id < scope.range.end; + if is_active_at_id + && state.existing_scopes.contains(&scope_id) + && !state.active_scopes.contains(&scope_id) + { + state.error = Some(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Encountered an instruction that should be part of a scope, but where that \ + scope has already completed", + Some(format!( + "Instruction [{:?}] is part of scope @{:?}, but that scope has already \ + completed", + id, scope_id + )), + )); + } + } + } + + fn visit_scope(&self, scope: &ReactiveScopeBlock, state: &mut CheckState) { + state.active_scopes.insert(scope.scope); + self.traverse_scope(scope, state); + state.active_scopes.remove(&scope.scope); + } +} diff --git a/crates/react_compiler_reactive_scopes/src/assert_well_formed_break_targets.rs b/crates/react_compiler_reactive_scopes/src/assert_well_formed_break_targets.rs new file mode 100644 index 000000000000..49602384e574 --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/assert_well_formed_break_targets.rs @@ -0,0 +1,59 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Assert that all break/continue targets reference existent labels. +//! +//! Corresponds to `src/ReactiveScopes/AssertWellFormedBreakTargets.ts`. 
+ +use std::collections::HashSet; + +use react_compiler_hir::{ + environment::Environment, BlockId, ReactiveFunction, ReactiveTerminal, + ReactiveTerminalStatement, +}; + +use crate::visitors::{visit_reactive_function, ReactiveFunctionVisitor}; + +/// Assert that all break/continue targets reference existent labels. +pub fn assert_well_formed_break_targets(func: &ReactiveFunction, env: &Environment) { + let visitor = Visitor { env }; + let mut state: HashSet = HashSet::new(); + visit_reactive_function(func, &visitor, &mut state); +} + +struct Visitor<'a> { + env: &'a Environment, +} + +impl<'a> ReactiveFunctionVisitor for Visitor<'a> { + type State = HashSet; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_terminal(&self, stmt: &ReactiveTerminalStatement, seen_labels: &mut HashSet) { + if let Some(label) = &stmt.label { + seen_labels.insert(label.id); + } + let terminal = &stmt.terminal; + match terminal { + ReactiveTerminal::Break { target, .. } | ReactiveTerminal::Continue { target, .. } => { + assert!( + seen_labels.contains(target), + "Unexpected break/continue to invalid label: {:?}", + target + ); + } + _ => {} + } + // Note: intentionally NOT calling self.traverse_terminal() here, + // matching TS behavior where visitTerminal override does not call + // traverseTerminal. Recursion into child blocks happens via + // traverseBlock→visitTerminal for nested blocks. The TS visitor + // only checks break/continue at the block level, not terminal child + // blocks. + } +} diff --git a/crates/react_compiler_reactive_scopes/src/build_reactive_function.rs b/crates/react_compiler_reactive_scopes/src/build_reactive_function.rs new file mode 100644 index 000000000000..1b6282d649a1 --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/build_reactive_function.rs @@ -0,0 +1,1683 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. 
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Converts the HIR CFG into a tree-structured ReactiveFunction. +//! +//! Corresponds to `src/ReactiveScopes/BuildReactiveFunction.ts`. + +use std::collections::HashSet; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory, SourceLocation, +}; +use react_compiler_hir::{ + environment::Environment, BasicBlock, BlockId, EvaluationOrder, GotoVariant, HirFunction, + InstructionValue, Place, PrunedReactiveScopeBlock, ReactiveBlock, ReactiveFunction, + ReactiveInstruction, ReactiveLabel, ReactiveScopeBlock, ReactiveStatement, ReactiveSwitchCase, + ReactiveTerminal, ReactiveTerminalStatement, ReactiveTerminalTargetKind, ReactiveValue, + Terminal, +}; + +/// Convert the HIR CFG into a tree-structured ReactiveFunction. +pub fn build_reactive_function( + hir: &HirFunction, + env: &Environment, +) -> Result { + let mut ctx = Context::new(hir); + let mut driver = Driver { + cx: &mut ctx, + hir, + env, + }; + + let entry_block_id = hir.body.entry; + let mut body = Vec::new(); + driver.visit_block(entry_block_id, &mut body)?; + + Ok(ReactiveFunction { + loc: hir.loc, + id: hir.id.clone(), + name_hint: hir.name_hint.clone(), + params: hir.params.clone(), + generator: hir.generator, + is_async: hir.is_async, + body, + directives: hir.directives.clone(), + }) +} + +// ============================================================================= +// ControlFlowTarget +// ============================================================================= + +#[derive(Debug)] +enum ControlFlowTarget { + If { + block: BlockId, + id: u32, + }, + Switch { + block: BlockId, + id: u32, + }, + Case { + block: BlockId, + id: u32, + }, + Loop { + block: BlockId, + #[allow(dead_code)] + owns_block: bool, + continue_block: BlockId, + loop_block: Option, + owns_loop: bool, + id: u32, + }, +} + +impl ControlFlowTarget { + 
fn block(&self) -> BlockId { + match self { + ControlFlowTarget::If { block, .. } + | ControlFlowTarget::Switch { block, .. } + | ControlFlowTarget::Case { block, .. } + | ControlFlowTarget::Loop { block, .. } => *block, + } + } + + fn id(&self) -> u32 { + match self { + ControlFlowTarget::If { id, .. } + | ControlFlowTarget::Switch { id, .. } + | ControlFlowTarget::Case { id, .. } + | ControlFlowTarget::Loop { id, .. } => *id, + } + } + + fn is_loop(&self) -> bool { + matches!(self, ControlFlowTarget::Loop { .. }) + } +} + +// ============================================================================= +// Context +// ============================================================================= + +struct Context<'a> { + ir: &'a HirFunction, + next_schedule_id: u32, + emitted: HashSet, + scope_fallthroughs: HashSet, + scheduled: HashSet, + catch_handlers: HashSet, + control_flow_stack: Vec, +} + +impl<'a> Context<'a> { + fn new(ir: &'a HirFunction) -> Self { + Self { + ir, + next_schedule_id: 0, + emitted: HashSet::new(), + scope_fallthroughs: HashSet::new(), + scheduled: HashSet::new(), + catch_handlers: HashSet::new(), + control_flow_stack: Vec::new(), + } + } + + fn block(&self, id: BlockId) -> &BasicBlock { + &self.ir.body.blocks[&id] + } + + fn schedule_catch_handler(&mut self, block: BlockId) { + self.catch_handlers.insert(block); + } + + fn reachable(&self, id: BlockId) -> bool { + let block = self.block(id); + !matches!(block.terminal, Terminal::Unreachable { .. 
}) + } + + fn schedule(&mut self, block: BlockId, target_type: &str) -> Result { + let id = self.next_schedule_id; + self.next_schedule_id += 1; + if self.scheduled.contains(&block) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!("Break block is already scheduled: bb{}", block.0), + None, + )); + } + self.scheduled.insert(block); + let target = match target_type { + "if" => ControlFlowTarget::If { block, id }, + "switch" => ControlFlowTarget::Switch { block, id }, + "case" => ControlFlowTarget::Case { block, id }, + _ => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!("Unknown target type: {}", target_type), + None, + )) + } + }; + self.control_flow_stack.push(target); + Ok(id) + } + + fn schedule_loop( + &mut self, + fallthrough_block: BlockId, + continue_block: BlockId, + loop_block: Option, + ) -> Result { + let id = self.next_schedule_id; + self.next_schedule_id += 1; + let owns_block = !self.scheduled.contains(&fallthrough_block); + self.scheduled.insert(fallthrough_block); + if self.scheduled.contains(&continue_block) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "Continue block is already scheduled: bb{}", + continue_block.0 + ), + None, + )); + } + self.scheduled.insert(continue_block); + let mut owns_loop = false; + if let Some(lb) = loop_block { + owns_loop = !self.scheduled.contains(&lb); + self.scheduled.insert(lb); + } + + self.control_flow_stack.push(ControlFlowTarget::Loop { + block: fallthrough_block, + owns_block, + continue_block, + loop_block, + owns_loop, + id, + }); + Ok(id) + } + + fn unschedule(&mut self, schedule_id: u32) -> Result<(), CompilerDiagnostic> { + let last = self + .control_flow_stack + .pop() + .expect("Can only unschedule the last target"); + if last.id() != schedule_id { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Can only unschedule the last target".to_string(), + None, + )); + } + match &last { + 
ControlFlowTarget::Loop { + block, + continue_block, + loop_block, + owns_loop, + .. + } => { + // TS: always removes block from scheduled for loops + // (ownsBlock is boolean, so `!== null` is always true) + self.scheduled.remove(block); + self.scheduled.remove(continue_block); + if *owns_loop { + if let Some(lb) = loop_block { + self.scheduled.remove(lb); + } + } + } + _ => { + self.scheduled.remove(&last.block()); + } + } + Ok(()) + } + + fn unschedule_all(&mut self, schedule_ids: &[u32]) -> Result<(), CompilerDiagnostic> { + for &id in schedule_ids.iter().rev() { + self.unschedule(id)?; + } + Ok(()) + } + + fn is_scheduled(&self, block: BlockId) -> bool { + self.scheduled.contains(&block) || self.catch_handlers.contains(&block) + } + + fn get_break_target( + &self, + block: BlockId, + ) -> Result<(BlockId, ReactiveTerminalTargetKind), CompilerDiagnostic> { + let mut has_preceding_loop = false; + for i in (0..self.control_flow_stack.len()).rev() { + let target = &self.control_flow_stack[i]; + if target.block() == block { + let kind = if target.is_loop() { + if has_preceding_loop { + ReactiveTerminalTargetKind::Labeled + } else { + ReactiveTerminalTargetKind::Unlabeled + } + } else if i == self.control_flow_stack.len() - 1 { + ReactiveTerminalTargetKind::Implicit + } else { + ReactiveTerminalTargetKind::Labeled + }; + return Ok((target.block(), kind)); + } + has_preceding_loop = has_preceding_loop || target.is_loop(); + } + Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!("Expected a break target for bb{}", block.0), + None, + )) + } + + fn get_continue_target(&self, block: BlockId) -> Option<(BlockId, ReactiveTerminalTargetKind)> { + let mut has_preceding_loop = false; + for i in (0..self.control_flow_stack.len()).rev() { + let target = &self.control_flow_stack[i]; + if let ControlFlowTarget::Loop { + block: fallthrough_block, + continue_block, + .. 
+ } = target + { + if *continue_block == block { + let kind = if has_preceding_loop { + ReactiveTerminalTargetKind::Labeled + } else if i == self.control_flow_stack.len() - 1 { + ReactiveTerminalTargetKind::Implicit + } else { + ReactiveTerminalTargetKind::Unlabeled + }; + return Some((*fallthrough_block, kind)); + } + } + has_preceding_loop = has_preceding_loop || target.is_loop(); + } + None + } +} + +// ============================================================================= +// Driver +// ============================================================================= + +struct Driver<'a, 'b> { + cx: &'b mut Context<'a>, + hir: &'a HirFunction, + #[allow(dead_code)] + env: &'a Environment, +} + +impl<'a, 'b> Driver<'a, 'b> { + fn traverse_block(&mut self, block_id: BlockId) -> Result { + let mut block_value = Vec::new(); + self.visit_block(block_id, &mut block_value)?; + Ok(block_value) + } + + fn visit_block( + &mut self, + block_id: BlockId, + block_value: &mut ReactiveBlock, + ) -> Result<(), CompilerDiagnostic> { + // Extract data from block before any mutable operations + let block = &self.hir.body.blocks[&block_id]; + let block_id_val = block.id; + let instructions: Vec<_> = block.instructions.clone(); + let terminal = block.terminal.clone(); + + if !self.cx.emitted.insert(block_id_val) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!("Block bb{} was already emitted", block_id_val.0), + None, + )); + } + + // Emit instructions + for instr_id in &instructions { + let instr = &self.hir.instructions[instr_id.0 as usize]; + block_value.push(ReactiveStatement::Instruction(ReactiveInstruction { + id: instr.id, + lvalue: Some(instr.lvalue.clone()), + value: ReactiveValue::Instruction(instr.value.clone()), + effects: instr.effects.clone(), + loc: instr.loc, + })); + } + + // Process terminal + let mut schedule_ids: Vec = Vec::new(); + + match &terminal { + Terminal::If { + test, + consequent, + alternate, + fallthrough, + id, + loc, 
+ } => { + // TS: reachable(fallthrough) && !isScheduled(fallthrough) + let fallthrough_id = + if self.cx.reachable(*fallthrough) && !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + // TS: alternate !== fallthrough ? alternate : null + let alternate_id = if *alternate != *fallthrough { + Some(*alternate) + } else { + None + }; + + if let Some(ft) = fallthrough_id { + schedule_ids.push(self.cx.schedule(ft, "if")?); + } + + let consequent_block = if self.cx.is_scheduled(*consequent) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "Unexpected 'if' where consequent is already scheduled (bb{})", + consequent.0 + ), + None, + )); + } else { + self.traverse_block(*consequent)? + }; + + let alternate_block = if let Some(alt) = alternate_id { + if self.cx.is_scheduled(alt) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "Unexpected 'if' where the alternate is already scheduled (bb{})", + alt.0 + ), + None, + )); + } else { + Some(self.traverse_block(alt)?) + } + } else { + None + }; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::If { + test: test.clone(), + consequent: consequent_block, + alternate: alternate_block, + id: *id, + loc: *loc, + }, + label: fallthrough_id.map(|ft| ReactiveLabel { + id: ft, + implicit: false, + }), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::Switch { + test, + cases, + fallthrough, + id, + loc, + } => { + // TS: reachable(fallthrough) && !isScheduled(fallthrough) + let fallthrough_id = + if self.cx.reachable(*fallthrough) && !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + if let Some(ft) = fallthrough_id { + schedule_ids.push(self.cx.schedule(ft, "switch")?); + } + + // TS processes cases in reverse order, then reverses the result. 
+ // This ensures that later cases are scheduled when earlier cases + // are traversed, matching fallthrough semantics. + let mut reactive_cases = Vec::new(); + for case in cases.iter().rev() { + let case_block_id = case.block; + + if self.cx.is_scheduled(case_block_id) { + // TS: asserts case.block === fallthrough, then skips (return) + if case_block_id != *fallthrough { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'switch' where a case is already scheduled and block \ + is not the fallthrough" + .to_string(), + None, + )); + } + continue; + } + + let consequent = self.traverse_block(case_block_id)?; + let case_schedule_id = self.cx.schedule(case_block_id, "case")?; + schedule_ids.push(case_schedule_id); + + reactive_cases.push(ReactiveSwitchCase { + test: case.test.clone(), + block: Some(consequent), + }); + } + reactive_cases.reverse(); + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::Switch { + test: test.clone(), + cases: reactive_cases, + id: *id, + loc: *loc, + }, + label: fallthrough_id.map(|ft| ReactiveLabel { + id: ft, + implicit: false, + }), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::DoWhile { + loop_block, + test, + fallthrough, + id, + loc, + } => { + let fallthrough_id = if !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + let loop_id = if !self.cx.is_scheduled(*loop_block) && *loop_block != *fallthrough { + Some(*loop_block) + } else { + None + }; + + schedule_ids.push( + self.cx + .schedule_loop(*fallthrough, *test, Some(*loop_block))?, + ); + + let loop_body = if let Some(lid) = loop_id { + self.traverse_block(lid)? 
+ } else { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'do-while' where the loop is already scheduled", + None, + )); + }; + let test_result = self.visit_value_block(*test, *loc, None)?; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::DoWhile { + loop_block: loop_body, + test: test_result.value, + id: *id, + loc: *loc, + }, + label: fallthrough_id.map(|ft| ReactiveLabel { + id: ft, + implicit: false, + }), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::While { + test, + loop_block, + fallthrough, + id, + loc, + } => { + // TS: reachable(fallthrough) && !isScheduled(fallthrough) + let fallthrough_id = + if self.cx.reachable(*fallthrough) && !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + let loop_id = if !self.cx.is_scheduled(*loop_block) && *loop_block != *fallthrough { + Some(*loop_block) + } else { + None + }; + + schedule_ids.push( + self.cx + .schedule_loop(*fallthrough, *test, Some(*loop_block))?, + ); + + let test_result = self.visit_value_block(*test, *loc, None)?; + + let loop_body = if let Some(lid) = loop_id { + self.traverse_block(lid)? 
+ } else { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'while' where the loop is already scheduled", + None, + )); + }; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::While { + test: test_result.value, + loop_block: loop_body, + id: *id, + loc: *loc, + }, + label: fallthrough_id.map(|ft| ReactiveLabel { + id: ft, + implicit: false, + }), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::For { + init, + test, + update, + loop_block, + fallthrough, + id, + loc, + } => { + let loop_id = if !self.cx.is_scheduled(*loop_block) && *loop_block != *fallthrough { + Some(*loop_block) + } else { + None + }; + + let fallthrough_id = if !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + + // Continue block is update (if present) or test + let continue_block = update.unwrap_or(*test); + schedule_ids.push(self.cx.schedule_loop( + *fallthrough, + continue_block, + Some(*loop_block), + )?); + + let init_result = self.visit_value_block(*init, *loc, None)?; + let init_value = self.value_block_result_to_sequence(init_result, *loc); + + let test_result = self.visit_value_block(*test, *loc, None)?; + + let update_result = match update { + Some(u) => Some(self.visit_value_block(*u, *loc, None)?), + None => None, + }; + + let loop_body = if let Some(lid) = loop_id { + self.traverse_block(lid)? 
+ } else { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'for' where the loop is already scheduled", + None, + )); + }; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::For { + init: init_value, + test: test_result.value, + update: update_result.map(|r| r.value), + loop_block: loop_body, + id: *id, + loc: *loc, + }, + label: fallthrough_id.map(|ft| ReactiveLabel { + id: ft, + implicit: false, + }), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::ForOf { + init, + test, + loop_block, + fallthrough, + id, + loc, + } => { + let loop_id = if !self.cx.is_scheduled(*loop_block) && *loop_block != *fallthrough { + Some(*loop_block) + } else { + None + }; + + let fallthrough_id = if !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + + // TS: scheduleLoop(fallthrough, init, loop) + schedule_ids.push( + self.cx + .schedule_loop(*fallthrough, *init, Some(*loop_block))?, + ); + + let init_result = self.visit_value_block(*init, *loc, None)?; + let init_value = self.value_block_result_to_sequence(init_result, *loc); + + let test_result = self.visit_value_block(*test, *loc, None)?; + let test_value = self.value_block_result_to_sequence(test_result, *loc); + + let loop_body = if let Some(lid) = loop_id { + self.traverse_block(lid)? 
+ } else { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'for-of' where the loop is already scheduled", + None, + )); + }; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::ForOf { + init: init_value, + test: test_value, + loop_block: loop_body, + id: *id, + loc: *loc, + }, + label: fallthrough_id.map(|ft| ReactiveLabel { + id: ft, + implicit: false, + }), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::ForIn { + init, + loop_block, + fallthrough, + id, + loc, + } => { + let loop_id = if !self.cx.is_scheduled(*loop_block) && *loop_block != *fallthrough { + Some(*loop_block) + } else { + None + }; + + let fallthrough_id = if !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + + schedule_ids.push( + self.cx + .schedule_loop(*fallthrough, *init, Some(*loop_block))?, + ); + + let init_result = self.visit_value_block(*init, *loc, None)?; + let init_value = self.value_block_result_to_sequence(init_result, *loc); + + let loop_body = if let Some(lid) = loop_id { + self.traverse_block(lid)? 
+ } else { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'for-in' where the loop is already scheduled", + None, + )); + }; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::ForIn { + init: init_value, + loop_block: loop_body, + id: *id, + loc: *loc, + }, + label: fallthrough_id.map(|ft| ReactiveLabel { + id: ft, + implicit: false, + }), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::Label { + block: label_block, + fallthrough, + id, + loc, + } => { + // TS: reachable(fallthrough) && !isScheduled(fallthrough) + let fallthrough_id = + if self.cx.reachable(*fallthrough) && !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + if let Some(ft) = fallthrough_id { + schedule_ids.push(self.cx.schedule(ft, "if")?); + } + + if self.cx.is_scheduled(*label_block) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'label' where the block is already scheduled".to_string(), + None, + )); + } + let label_body = self.traverse_block(*label_block)?; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::Label { + block: label_body, + id: *id, + loc: *loc, + }, + label: fallthrough_id.map(|ft| ReactiveLabel { + id: ft, + implicit: false, + }), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::Sequence { .. } + | Terminal::Optional { .. } + | Terminal::Ternary { .. } + | Terminal::Logical { .. } => { + let fallthrough = match &terminal { + Terminal::Sequence { fallthrough, .. } + | Terminal::Optional { fallthrough, .. } + | Terminal::Ternary { fallthrough, .. } + | Terminal::Logical { fallthrough, .. 
} => *fallthrough, + _ => unreachable!(), + }; + let fallthrough_id = if !self.cx.is_scheduled(fallthrough) { + Some(fallthrough) + } else { + None + }; + if let Some(ft) = fallthrough_id { + schedule_ids.push(self.cx.schedule(ft, "if")?); + } + + let result = self.visit_value_block_terminal(&terminal)?; + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Instruction(ReactiveInstruction { + id: result.id, + lvalue: Some(result.place), + value: result.value, + effects: None, + loc: *terminal_loc(&terminal), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::Goto { + block: goto_block, + variant, + id, + loc, + } => { + match variant { + GotoVariant::Break => { + if let Some(stmt) = self.visit_break(*goto_block, *id, *loc)? { + block_value.push(stmt); + } + } + GotoVariant::Continue => { + let stmt = self.visit_continue(*goto_block, *id, *loc)?; + block_value.push(stmt); + } + GotoVariant::Try => { + // noop + } + } + } + + Terminal::MaybeThrow { continuation, .. 
} => { + if !self.cx.is_scheduled(*continuation) { + self.visit_block(*continuation, block_value)?; + } + } + + Terminal::Try { + block: try_block, + handler_binding, + handler, + fallthrough, + id, + loc, + } => { + let fallthrough_id = + if self.cx.reachable(*fallthrough) && !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + if let Some(ft) = fallthrough_id { + schedule_ids.push(self.cx.schedule(ft, "if")?); + } + self.cx.schedule_catch_handler(*handler); + + let try_body = self.traverse_block(*try_block)?; + let handler_body = self.traverse_block(*handler)?; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::Try { + block: try_body, + handler_binding: handler_binding.clone(), + handler: handler_body, + id: *id, + loc: *loc, + }, + label: fallthrough_id.map(|ft| ReactiveLabel { + id: ft, + implicit: false, + }), + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::Scope { + fallthrough, + block: scope_block, + scope, + .. + } => { + let fallthrough_id = if !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + if let Some(ft) = fallthrough_id { + schedule_ids.push(self.cx.schedule(ft, "if")?); + self.cx.scope_fallthroughs.insert(ft); + } + + if self.cx.is_scheduled(*scope_block) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'scope' where the block is already scheduled".to_string(), + None, + )); + } + let scope_body = self.traverse_block(*scope_block)?; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::Scope(ReactiveScopeBlock { + scope: *scope, + instructions: scope_body, + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::PrunedScope { + fallthrough, + block: scope_block, + scope, + .. 
+ } => { + let fallthrough_id = if !self.cx.is_scheduled(*fallthrough) { + Some(*fallthrough) + } else { + None + }; + if let Some(ft) = fallthrough_id { + schedule_ids.push(self.cx.schedule(ft, "if")?); + self.cx.scope_fallthroughs.insert(ft); + } + + if self.cx.is_scheduled(*scope_block) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'scope' where the block is already scheduled".to_string(), + None, + )); + } + let scope_body = self.traverse_block(*scope_block)?; + + self.cx.unschedule_all(&schedule_ids)?; + block_value.push(ReactiveStatement::PrunedScope(PrunedReactiveScopeBlock { + scope: *scope, + instructions: scope_body, + })); + + if let Some(ft) = fallthrough_id { + self.visit_block(ft, block_value)?; + } + } + + Terminal::Return { value, id, loc, .. } => { + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::Return { + value: value.clone(), + id: *id, + loc: *loc, + }, + label: None, + })); + } + + Terminal::Throw { value, id, loc } => { + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::Throw { + value: value.clone(), + id: *id, + loc: *loc, + }, + label: None, + })); + } + + Terminal::Unreachable { .. } => { + // noop + } + + Terminal::Unsupported { .. } => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected unsupported terminal", + None, + )); + } + + Terminal::Branch { + test, + consequent, + alternate, + id, + loc, + .. + } => { + let consequent_block = if self.cx.is_scheduled(*consequent) { + if let Some(stmt) = self.visit_break(*consequent, *id, *loc)? { + vec![stmt] + } else { + Vec::new() + } + } else { + self.traverse_block(*consequent)? 
+ }; + + if self.cx.is_scheduled(*alternate) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected 'branch' where the alternate is already scheduled".to_string(), + None, + )); + } + let alternate_block = self.traverse_block(*alternate)?; + + block_value.push(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::If { + test: test.clone(), + consequent: consequent_block, + alternate: Some(alternate_block), + id: *id, + loc: *loc, + }, + label: None, + })); + } + } + Ok(()) + } + + // ========================================================================= + // Value block processing + // ========================================================================= + + fn visit_value_block( + &mut self, + block_id: BlockId, + loc: Option, + fallthrough: Option, + ) -> Result { + let block = &self.hir.body.blocks[&block_id]; + let block_id_val = block.id; + let terminal = block.terminal.clone(); + let instructions: Vec<_> = block.instructions.clone(); + + // If we've reached the fallthrough, stop + if let Some(ft) = fallthrough { + if block_id == ft { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "Did not expect to reach the fallthrough of a value block (bb{})", + block_id.0 + ), + None, + )); + } + } + + match &terminal { + Terminal::Branch { + test, id: term_id, .. + } => { + if instructions.is_empty() { + Ok(ValueBlockResult { + block: block_id_val, + place: test.clone(), + value: ReactiveValue::Instruction(InstructionValue::LoadLocal { + place: test.clone(), + loc: test.loc, + }), + id: *term_id, + }) + } else { + Ok(self.extract_value_block_result(&instructions, block_id_val, loc)) + } + } + Terminal::Goto { .. 
} => { + if instructions.is_empty() { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected empty block with `goto` terminal", + Some(format!("Block bb{} is empty", block_id.0)), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc, + message: Some("Unexpected empty block with `goto` terminal".to_string()), + identifier_name: None, + })); + } + Ok(self.extract_value_block_result(&instructions, block_id_val, loc)) + } + Terminal::MaybeThrow { continuation, .. } => { + let continuation_id = *continuation; + let continuation_block = self.cx.block(continuation_id); + let cont_instructions_empty = continuation_block.instructions.is_empty(); + let cont_is_goto = matches!(continuation_block.terminal, Terminal::Goto { .. }); + let cont_block_id = continuation_block.id; + + if cont_instructions_empty && cont_is_goto { + Ok(self.extract_value_block_result(&instructions, cont_block_id, loc)) + } else { + let continuation = self.visit_value_block(continuation_id, loc, fallthrough)?; + Ok(self.wrap_with_sequence(&instructions, continuation, loc)) + } + } + _ => { + // Value block ended in a value terminal, recurse to get the value + // of that terminal and stitch them together in a sequence. 
+ // TS: visitValueBlock(init.fallthrough, loc) — does NOT propagate fallthrough + let init = self.visit_value_block_terminal(&terminal)?; + let init_fallthrough = init.fallthrough; + let init_instr = ReactiveInstruction { + id: init.id, + lvalue: Some(init.place), + value: init.value, + effects: None, + loc, + }; + let final_result = self.visit_value_block(init_fallthrough, loc, None)?; + + // Combine block instructions + init instruction, then wrap + let mut all_instrs: Vec = instructions + .iter() + .map(|iid| { + let instr = &self.hir.instructions[iid.0 as usize]; + ReactiveInstruction { + id: instr.id, + lvalue: Some(instr.lvalue.clone()), + value: ReactiveValue::Instruction(instr.value.clone()), + effects: instr.effects.clone(), + loc: instr.loc, + } + }) + .collect(); + all_instrs.push(init_instr); + + if all_instrs.is_empty() { + Ok(final_result) + } else { + Ok(ValueBlockResult { + block: final_result.block, + place: final_result.place.clone(), + value: ReactiveValue::SequenceExpression { + instructions: all_instrs, + id: final_result.id, + value: Box::new(final_result.value), + loc, + }, + id: final_result.id, + }) + } + } + } + } + + fn visit_test_block( + &mut self, + test_block_id: BlockId, + loc: Option, + terminal_kind: &str, + ) -> Result { + let test = self.visit_value_block(test_block_id, loc, None)?; + let test_block = &self.hir.body.blocks[&test.block]; + match &test_block.terminal { + Terminal::Branch { + consequent, + alternate, + loc: branch_loc, + .. 
+ } => Ok(TestBlockResult { + test, + consequent: *consequent, + alternate: *alternate, + branch_loc: *branch_loc, + }), + other => Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!( + "Expected a branch terminal for {} test block, got {:?}", + terminal_kind, + std::mem::discriminant(other) + ), + None, + )), + } + } + + fn visit_value_block_terminal( + &mut self, + terminal: &Terminal, + ) -> Result { + match terminal { + Terminal::Sequence { + block, + fallthrough, + id, + loc, + } => { + let block_result = self.visit_value_block(*block, *loc, Some(*fallthrough))?; + Ok(ValueTerminalResult { + value: block_result.value, + place: block_result.place, + fallthrough: *fallthrough, + id: *id, + }) + } + Terminal::Optional { + optional, + test, + fallthrough, + id, + loc, + } => { + let test_result = self.visit_test_block(*test, *loc, "optional")?; + let consequent = + self.visit_value_block(test_result.consequent, *loc, Some(*fallthrough))?; + let call = ReactiveValue::SequenceExpression { + instructions: vec![ReactiveInstruction { + id: test_result.test.id, + lvalue: Some(test_result.test.place.clone()), + value: test_result.test.value, + effects: None, + loc: test_result.branch_loc, + }], + id: consequent.id, + value: Box::new(consequent.value), + loc: *loc, + }; + Ok(ValueTerminalResult { + place: consequent.place, + value: ReactiveValue::OptionalExpression { + optional: *optional, + value: Box::new(call), + id: *id, + loc: *loc, + }, + fallthrough: *fallthrough, + id: *id, + }) + } + Terminal::Logical { + operator, + test, + fallthrough, + id, + loc, + } => { + let test_result = self.visit_test_block(*test, *loc, "logical")?; + let left_final = + self.visit_value_block(test_result.consequent, *loc, Some(*fallthrough))?; + let left = ReactiveValue::SequenceExpression { + instructions: vec![ReactiveInstruction { + id: test_result.test.id, + lvalue: Some(test_result.test.place.clone()), + value: test_result.test.value, + effects: None, + loc: *loc, + 
}], + id: left_final.id, + value: Box::new(left_final.value), + loc: *loc, + }; + let right = + self.visit_value_block(test_result.alternate, *loc, Some(*fallthrough))?; + Ok(ValueTerminalResult { + place: left_final.place, + value: ReactiveValue::LogicalExpression { + operator: *operator, + left: Box::new(left), + right: Box::new(right.value), + loc: *loc, + }, + fallthrough: *fallthrough, + id: *id, + }) + } + Terminal::Ternary { + test, + fallthrough, + id, + loc, + } => { + let test_result = self.visit_test_block(*test, *loc, "ternary")?; + let consequent = + self.visit_value_block(test_result.consequent, *loc, Some(*fallthrough))?; + let alternate = + self.visit_value_block(test_result.alternate, *loc, Some(*fallthrough))?; + Ok(ValueTerminalResult { + place: consequent.place, + value: ReactiveValue::ConditionalExpression { + test: Box::new(test_result.test.value), + consequent: Box::new(consequent.value), + alternate: Box::new(alternate.value), + loc: *loc, + }, + fallthrough: *fallthrough, + id: *id, + }) + } + Terminal::MaybeThrow { .. } => Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected maybe-throw in visit_value_block_terminal", + None, + )), + Terminal::Label { .. 
} => Err(CompilerDiagnostic::new( + ErrorCategory::Todo, + "Support labeled statements combined with value blocks is not yet implemented", + None, + )), + _ => Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unsupported terminal kind in value block", + None, + )), + } + } + + fn extract_value_block_result( + &self, + instructions: &[react_compiler_hir::InstructionId], + block_id: BlockId, + loc: Option, + ) -> ValueBlockResult { + let last_id = instructions + .last() + .expect("Expected non-empty instructions"); + let last_instr = &self.hir.instructions[last_id.0 as usize]; + + let remaining: Vec = instructions[..instructions.len() - 1] + .iter() + .map(|iid| { + let instr = &self.hir.instructions[iid.0 as usize]; + ReactiveInstruction { + id: instr.id, + lvalue: Some(instr.lvalue.clone()), + value: ReactiveValue::Instruction(instr.value.clone()), + effects: instr.effects.clone(), + loc: instr.loc, + } + }) + .collect(); + + // If the last instruction is a StoreLocal to a temporary (unnamed identifier), + // convert it to a LoadLocal of the value being stored, matching the TS + // behavior. + let (value, place) = match &last_instr.value { + InstructionValue::StoreLocal { + lvalue, + value: store_value, + .. 
+ } => { + let ident = &self.env.identifiers[lvalue.place.identifier.0 as usize]; + if ident.name.is_none() { + ( + ReactiveValue::Instruction(InstructionValue::LoadLocal { + place: store_value.clone(), + loc: store_value.loc, + }), + lvalue.place.clone(), + ) + } else { + ( + ReactiveValue::Instruction(last_instr.value.clone()), + last_instr.lvalue.clone(), + ) + } + } + _ => ( + ReactiveValue::Instruction(last_instr.value.clone()), + last_instr.lvalue.clone(), + ), + }; + let id = last_instr.id; + + if remaining.is_empty() { + ValueBlockResult { + block: block_id, + place, + value, + id, + } + } else { + ValueBlockResult { + block: block_id, + place: place.clone(), + value: ReactiveValue::SequenceExpression { + instructions: remaining, + id, + value: Box::new(value), + loc, + }, + id, + } + } + } + + fn wrap_with_sequence( + &self, + instructions: &[react_compiler_hir::InstructionId], + continuation: ValueBlockResult, + loc: Option, + ) -> ValueBlockResult { + if instructions.is_empty() { + return continuation; + } + + let reactive_instrs: Vec = instructions + .iter() + .map(|iid| { + let instr = &self.hir.instructions[iid.0 as usize]; + ReactiveInstruction { + id: instr.id, + lvalue: Some(instr.lvalue.clone()), + value: ReactiveValue::Instruction(instr.value.clone()), + effects: instr.effects.clone(), + loc: instr.loc, + } + }) + .collect(); + + ValueBlockResult { + block: continuation.block, + place: continuation.place.clone(), + value: ReactiveValue::SequenceExpression { + instructions: reactive_instrs, + id: continuation.id, + value: Box::new(continuation.value), + loc, + }, + id: continuation.id, + } + } + + /// Converts the result of visit_value_block into a SequenceExpression that + /// includes the instruction with its lvalue. This is needed for + /// for/for-of/for-in init/test blocks where the instruction's lvalue + /// assignment must be preserved. 
+ /// + /// This also flattens nested SequenceExpressions that can occur from + /// MaybeThrow handling in try-catch blocks. + /// + /// TS: valueBlockResultToSequence() + fn value_block_result_to_sequence( + &self, + result: ValueBlockResult, + loc: Option, + ) -> ReactiveValue { + // Collect all instructions from potentially nested SequenceExpressions + let mut instructions: Vec = Vec::new(); + let mut inner_value = result.value; + + // Flatten nested SequenceExpressions + loop { + match inner_value { + ReactiveValue::SequenceExpression { + instructions: seq_instrs, + value, + .. + } => { + instructions.extend(seq_instrs); + inner_value = *value; + } + _ => break, + } + } + + // Only add the final instruction if the innermost value is not just a LoadLocal + // of the same place we're storing to (which would be a no-op). + let is_load_of_same_place = match &inner_value { + ReactiveValue::Instruction(InstructionValue::LoadLocal { place, .. }) => { + place.identifier == result.place.identifier + } + _ => false, + }; + + if !is_load_of_same_place { + instructions.push(ReactiveInstruction { + id: result.id, + lvalue: Some(result.place), + value: inner_value, + effects: None, + loc, + }); + } + + ReactiveValue::SequenceExpression { + instructions, + id: result.id, + value: Box::new(ReactiveValue::Instruction(InstructionValue::Primitive { + value: react_compiler_hir::PrimitiveValue::Undefined, + loc, + })), + loc, + } + } + + fn visit_break( + &self, + block: BlockId, + id: EvaluationOrder, + loc: Option, + ) -> Result, CompilerDiagnostic> { + let (target_block, target_kind) = self.cx.get_break_target(block)?; + if self.cx.scope_fallthroughs.contains(&target_block) { + if target_kind != ReactiveTerminalTargetKind::Implicit { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected reactive scope to implicitly break to fallthrough".to_string(), + None, + )); + } + return Ok(None); + } + Ok(Some(ReactiveStatement::Terminal( + ReactiveTerminalStatement 
{ + terminal: ReactiveTerminal::Break { + target: target_block, + id, + target_kind, + loc, + }, + label: None, + }, + ))) + } + + fn visit_continue( + &self, + block: BlockId, + id: EvaluationOrder, + loc: Option, + ) -> Result { + let (target_block, target_kind) = match self.cx.get_continue_target(block) { + Some(result) => result, + None => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!("Expected continue target to be scheduled for bb{}", block.0), + None, + )); + } + }; + + Ok(ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::Continue { + target: target_block, + id, + target_kind, + loc, + }, + label: None, + })) + } +} + +// ============================================================================= +// Helper types +// ============================================================================= + +struct ValueBlockResult { + block: BlockId, + place: Place, + value: ReactiveValue, + id: EvaluationOrder, +} + +struct TestBlockResult { + test: ValueBlockResult, + consequent: BlockId, + alternate: BlockId, + branch_loc: Option, +} + +struct ValueTerminalResult { + value: ReactiveValue, + place: Place, + fallthrough: BlockId, + id: EvaluationOrder, +} + +/// Helper to get loc from a terminal +fn terminal_loc(terminal: &Terminal) -> &Option { + match terminal { + Terminal::If { loc, .. } + | Terminal::Branch { loc, .. } + | Terminal::Logical { loc, .. } + | Terminal::Ternary { loc, .. } + | Terminal::Optional { loc, .. } + | Terminal::Throw { loc, .. } + | Terminal::Return { loc, .. } + | Terminal::Goto { loc, .. } + | Terminal::Switch { loc, .. } + | Terminal::DoWhile { loc, .. } + | Terminal::While { loc, .. } + | Terminal::For { loc, .. } + | Terminal::ForOf { loc, .. } + | Terminal::ForIn { loc, .. } + | Terminal::Label { loc, .. } + | Terminal::Sequence { loc, .. } + | Terminal::Unreachable { loc, .. } + | Terminal::Unsupported { loc, .. } + | Terminal::MaybeThrow { loc, .. 
} + | Terminal::Scope { loc, .. } + | Terminal::PrunedScope { loc, .. } + | Terminal::Try { loc, .. } => loc, + } +} diff --git a/crates/react_compiler_reactive_scopes/src/codegen_reactive_function.rs b/crates/react_compiler_reactive_scopes/src/codegen_reactive_function.rs new file mode 100644 index 000000000000..7134e43487f8 --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/codegen_reactive_function.rs @@ -0,0 +1,3994 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Code generation pass: converts a `ReactiveFunction` tree back into a +//! Babel-compatible AST with memoization (useMemoCache) wired in. +//! +//! This is the final pass in the compilation pipeline. +//! +//! Corresponds to `src/ReactiveScopes/CodegenReactiveFunction.ts` in the TS +//! compiler. + +use std::collections::{HashMap, HashSet}; + +use react_compiler_ast::{ + common::{BaseNode, Position as AstPosition, SourceLocation as AstSourceLocation}, + expressions::{self as ast_expr, ArrowFunctionBody, Expression, Identifier as AstIdentifier}, + jsx::{ + JSXAttribute as AstJSXAttribute, JSXAttributeItem, JSXAttributeName, JSXAttributeValue, + JSXChild, JSXClosingElement, JSXClosingFragment, JSXElement, JSXElementName, + JSXExpressionContainer, JSXExpressionContainerExpr, JSXFragment, JSXIdentifier, + JSXMemberExprObject, JSXMemberExpression, JSXNamespacedName, JSXOpeningElement, + JSXOpeningFragment, JSXSpreadAttribute, JSXText, + }, + literals::{ + BooleanLiteral, NullLiteral, NumericLiteral, RegExpLiteral as AstRegExpLiteral, + StringLiteral, TemplateElement, TemplateElementValue, + }, + operators::{ + AssignmentOperator, BinaryOperator as AstBinaryOperator, + LogicalOperator as AstLogicalOperator, UnaryOperator as AstUnaryOperator, + UpdateOperator as AstUpdateOperator, + }, + patterns::{ + ArrayPattern as AstArrayPattern, ObjectPatternProp, 
ObjectPatternProperty, PatternLike, + RestElement, + }, + statements::{ + BlockStatement, BreakStatement, CatchClause, ContinueStatement, DebuggerStatement, + Directive, DirectiveLiteral, DoWhileStatement, EmptyStatement, ExpressionStatement, + ForInStatement, ForInit, ForOfStatement, ForStatement, FunctionDeclaration, IfStatement, + LabeledStatement, ReturnStatement, Statement, SwitchCase, SwitchStatement, ThrowStatement, + TryStatement, VariableDeclaration, VariableDeclarationKind, VariableDeclarator, + WhileStatement, + }, +}; +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, CompilerErrorDetail, + ErrorCategory, SourceLocation as DiagSourceLocation, +}; +use react_compiler_hir::{ + environment::Environment, + reactive::{ + PrunedReactiveScopeBlock, ReactiveBlock, ReactiveFunction, ReactiveInstruction, + ReactiveScopeBlock, ReactiveStatement, ReactiveTerminal, ReactiveTerminalTargetKind, + ReactiveValue, + }, + ArrayElement, ArrayPattern, BlockId, DeclarationId, FunctionExpressionType, IdentifierId, + InstructionKind, InstructionValue, JsxAttribute, JsxTag, LogicalOperator, ObjectPattern, + ObjectPropertyKey, ObjectPropertyOrSpread, ObjectPropertyType, ParamPattern, Pattern, Place, + PlaceOrSpread, PrimitiveValue, PropertyLiteral, ScopeId, SpreadPattern, +}; + +use crate::{ + build_reactive_function::build_reactive_function, + prune_hoisted_contexts::prune_hoisted_contexts, + prune_unused_labels::prune_unused_labels, + prune_unused_lvalues::prune_unused_lvalues, + rename_variables::rename_variables, + visitors::{visit_reactive_function, ReactiveFunctionVisitor}, +}; + +// ============================================================================= +// Public API +// ============================================================================= + +pub const MEMO_CACHE_SENTINEL: &str = "react.memo_cache_sentinel"; +pub const EARLY_RETURN_SENTINEL: &str = "react.early_return_sentinel"; + +/// FBT tags whose children 
get special codegen treatment. +const SINGLE_CHILD_FBT_TAGS: &[&str] = &["fbt:param", "fbs:param"]; + +/// Result of code generation for a single function. +pub struct CodegenFunction { + pub loc: Option, + pub id: Option, + pub name_hint: Option, + pub params: Vec, + pub body: BlockStatement, + pub generator: bool, + pub is_async: bool, + pub memo_slots_used: u32, + pub memo_blocks: u32, + pub memo_values: u32, + pub pruned_memo_blocks: u32, + pub pruned_memo_values: u32, + pub outlined: Vec, +} + +impl std::fmt::Debug for CodegenFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CodegenFunction") + .field("memo_slots_used", &self.memo_slots_used) + .field("memo_blocks", &self.memo_blocks) + .field("memo_values", &self.memo_values) + .field("pruned_memo_blocks", &self.pruned_memo_blocks) + .field("pruned_memo_values", &self.pruned_memo_values) + .finish() + } +} + +/// An outlined function extracted during compilation. +pub struct OutlinedFunction { + pub func: CodegenFunction, + pub fn_type: Option, +} + +/// Top-level entry point: generates code for a reactive function. +pub fn codegen_function( + func: &ReactiveFunction, + env: &mut Environment, + unique_identifiers: HashSet, + fbt_operands: HashSet, +) -> Result { + let fn_name = func.id.as_deref().unwrap_or("[[ anonymous ]]"); + let mut cx = Context::new(env, fn_name.to_string(), unique_identifiers, fbt_operands); + + // Fast Refresh: compute source hash and reserve a cache slot if enabled + let fast_refresh_state: Option<(u32, String)> = + if cx.env.config.enable_reset_cache_on_source_file_changes == Some(true) { + if let Some(ref code) = cx.env.code { + use hmac::{Hmac, Mac}; + use sha2::Sha256; + type HmacSha256 = Hmac; + // Match TS: createHmac('sha256', code).digest('hex') + // Node's createHmac uses the code as the HMAC key and hashes empty data. 
+ let mac = HmacSha256::new_from_slice(code.as_bytes()) + .expect("HMAC can take key of any size"); + let hash = format!("{:x}", mac.finalize().into_bytes()); + let cache_index = cx.alloc_cache_index(); // Reserve slot 0 for the hash check + Some((cache_index, hash)) + } else { + None + } + } else { + None + }; + + let mut compiled = codegen_reactive_function(&mut cx, func)?; + + // enableEmitHookGuards: wrap entire function body in try/finally with + // $dispatcherGuard(PushHookGuard=0) / $dispatcherGuard(PopHookGuard=1). + // Per-hook-call wrapping is done inline during codegen + // (CallExpression/MethodCall). + if cx.env.hook_guard_name.is_some() + && cx.env.output_mode == react_compiler_hir::environment::OutputMode::Client + { + let guard_name = cx.env.hook_guard_name.as_ref().unwrap().clone(); + let body_stmts = std::mem::replace(&mut compiled.body.body, Vec::new()); + compiled.body.body = vec![create_function_body_hook_guard( + &guard_name, + body_stmts, + 0, + 1, + )]; + } + + let cache_count = compiled.memo_slots_used; + if cache_count != 0 { + let mut preface: Vec = Vec::new(); + let cache_name = cx.synthesize_name("$"); + + // const $ = useMemoCache(N) + preface.push(Statement::VariableDeclaration(VariableDeclaration { + base: BaseNode::typed("VariableDeclaration"), + declarations: vec![VariableDeclarator { + base: BaseNode::typed("VariableDeclarator"), + id: PatternLike::Identifier(make_identifier(&cache_name)), + init: Some(Box::new(Expression::CallExpression( + ast_expr::CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::Identifier(make_identifier("useMemoCache"))), + arguments: vec![Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: cache_count as f64, + })], + type_parameters: None, + type_arguments: None, + optional: None, + }, + ))), + definite: None, + }], + kind: VariableDeclarationKind::Const, + declare: None, + })); + + // Fast Refresh: emit cache invalidation 
check after useMemoCache + if let Some((cache_index, ref hash)) = fast_refresh_state { + let index_var = cx.synthesize_name("$i"); + // if ($[cacheIndex] !== "hash") { for (let $i = 0; $i < N; $i += 1) { $[$i] = + // Symbol.for("react.memo_cache_sentinel"); } $[cacheIndex] = "hash"; } + preface.push(Statement::IfStatement(IfStatement { + base: BaseNode::typed("IfStatement"), + test: Box::new(Expression::BinaryExpression(ast_expr::BinaryExpression { + base: BaseNode::typed("BinaryExpression"), + operator: AstBinaryOperator::StrictNeq, + left: Box::new(Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::Identifier(make_identifier(&cache_name))), + property: Box::new(Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: cache_index as f64, + })), + computed: true, + })), + right: Box::new(Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: hash.clone(), + })), + })), + consequent: Box::new(Statement::BlockStatement(BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: vec![ + // for (let $i = 0; $i < N; $i += 1) { $[$i] = + // Symbol.for("react.memo_cache_sentinel"); } + Statement::ForStatement(ForStatement { + base: BaseNode::typed("ForStatement"), + init: Some(Box::new(ForInit::VariableDeclaration( + VariableDeclaration { + base: BaseNode::typed("VariableDeclaration"), + declarations: vec![VariableDeclarator { + base: BaseNode::typed("VariableDeclarator"), + id: PatternLike::Identifier(make_identifier(&index_var)), + init: Some(Box::new(Expression::NumericLiteral( + NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: 0.0, + }, + ))), + definite: None, + }], + kind: VariableDeclarationKind::Let, + declare: None, + }, + ))), + test: Some(Box::new(Expression::BinaryExpression( + ast_expr::BinaryExpression { + base: BaseNode::typed("BinaryExpression"), + operator: 
AstBinaryOperator::Lt, + left: Box::new(Expression::Identifier(make_identifier( + &index_var, + ))), + right: Box::new(Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: cache_count as f64, + })), + }, + ))), + update: Some(Box::new(Expression::AssignmentExpression( + ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::AddAssign, + left: Box::new(PatternLike::Identifier(make_identifier( + &index_var, + ))), + right: Box::new(Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: 1.0, + })), + }, + ))), + body: Box::new(Statement::BlockStatement(BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: vec![Statement::ExpressionStatement(ExpressionStatement { + base: BaseNode::typed("ExpressionStatement"), + expression: Box::new(Expression::AssignmentExpression( + ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(PatternLike::MemberExpression( + ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::Identifier( + make_identifier(&cache_name), + )), + property: Box::new(Expression::Identifier( + make_identifier(&index_var), + )), + computed: true, + }, + )), + right: Box::new(Expression::CallExpression( + ast_expr::CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::MemberExpression( + ast_expr::MemberExpression { + base: BaseNode::typed( + "MemberExpression", + ), + object: Box::new( + Expression::Identifier( + make_identifier("Symbol"), + ), + ), + property: Box::new( + Expression::Identifier( + make_identifier("for"), + ), + ), + computed: false, + }, + )), + arguments: vec![Expression::StringLiteral( + StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: MEMO_CACHE_SENTINEL.to_string(), + }, + )], + type_parameters: 
None, + type_arguments: None, + optional: None, + }, + )), + }, + )), + })], + directives: Vec::new(), + })), + }), + // $[cacheIndex] = "hash" + Statement::ExpressionStatement(ExpressionStatement { + base: BaseNode::typed("ExpressionStatement"), + expression: Box::new(Expression::AssignmentExpression( + ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(PatternLike::MemberExpression( + ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::Identifier( + make_identifier(&cache_name), + )), + property: Box::new(Expression::NumericLiteral( + NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: cache_index as f64, + }, + )), + computed: true, + }, + )), + right: Box::new(Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: hash.clone(), + })), + }, + )), + }), + ], + directives: Vec::new(), + })), + alternate: None, + })); + } + + // Insert preface at the beginning of the body + let mut new_body = preface; + new_body.append(&mut compiled.body.body); + compiled.body.body = new_body; + } + + // Instrument forget: emit instrumentation call at the top of the function body + let emit_instrument_forget = cx.env.config.enable_emit_instrument_forget.clone(); + if let Some(ref instrument_config) = emit_instrument_forget { + if func.id.is_some() + && cx.env.output_mode == react_compiler_hir::environment::OutputMode::Client + { + // Use pre-resolved import names from environment (set by program-level code) + let instrument_fn_local = cx + .env + .instrument_fn_name + .clone() + .unwrap_or_else(|| instrument_config.fn_.import_specifier_name.clone()); + let instrument_gating_local = cx.env.instrument_gating_name.clone(); + + // Build the gating condition + let gating_expr: Option = + instrument_gating_local.map(|name| Expression::Identifier(make_identifier(&name))); + let 
global_gating_expr: Option = instrument_config + .global_gating + .as_ref() + .map(|g| Expression::Identifier(make_identifier(g))); + + let if_test = match (gating_expr, global_gating_expr) { + (Some(gating), Some(global)) => { + Expression::LogicalExpression(ast_expr::LogicalExpression { + base: BaseNode::typed("LogicalExpression"), + operator: AstLogicalOperator::And, + left: Box::new(global), + right: Box::new(gating), + }) + } + (Some(gating), None) => gating, + (None, Some(global)) => global, + (None, None) => unreachable!( + "InstrumentationConfig requires at least one of gating or globalGating" + ), + }; + + let fn_name_str = func.id.as_deref().unwrap_or(""); + let filename_str = cx.env.filename.as_deref().unwrap_or(""); + + let instrument_call = Statement::IfStatement(IfStatement { + base: BaseNode::typed("IfStatement"), + test: Box::new(if_test), + consequent: Box::new(Statement::ExpressionStatement(ExpressionStatement { + base: BaseNode::typed("ExpressionStatement"), + expression: Box::new(Expression::CallExpression(ast_expr::CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::Identifier(make_identifier( + &instrument_fn_local, + ))), + arguments: vec![ + Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: fn_name_str.to_string(), + }), + Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: filename_str.to_string(), + }), + ], + type_parameters: None, + type_arguments: None, + optional: None, + })), + })), + alternate: None, + }); + compiled.body.body.insert(0, instrument_call); + } + } + + // Process outlined functions + let outlined_entries = cx.env.take_outlined_functions(); + let mut outlined: Vec = Vec::new(); + for entry in outlined_entries { + let reactive_fn = build_reactive_function(&entry.func, cx.env)?; + let mut reactive_fn_mut = reactive_fn; + prune_unused_labels(&mut reactive_fn_mut, cx.env)?; + prune_unused_lvalues(&mut 
reactive_fn_mut, cx.env); + prune_hoisted_contexts(&mut reactive_fn_mut, cx.env)?; + + let identifiers = rename_variables(&mut reactive_fn_mut, cx.env); + let mut outlined_cx = Context::new( + cx.env, + reactive_fn_mut + .id + .as_deref() + .unwrap_or("[[ anonymous ]]") + .to_string(), + identifiers, + cx.fbt_operands.clone(), + ); + let codegen = codegen_reactive_function(&mut outlined_cx, &reactive_fn_mut)?; + outlined.push(OutlinedFunction { + func: codegen, + fn_type: entry.fn_type, + }); + } + compiled.outlined = outlined; + + Ok(compiled) +} + +// ============================================================================= +// Context +// ============================================================================= + +type Temporaries = HashMap>; + +#[derive(Clone)] +enum ExpressionOrJsxText { + Expression(Expression), + JsxText(JSXText), +} + +struct Context<'env> { + env: &'env mut Environment, + #[allow(dead_code)] + fn_name: String, + next_cache_index: u32, + declarations: HashSet, + temp: Temporaries, + object_methods: HashMap< + IdentifierId, + ( + InstructionValue, + Option, + ), + >, + unique_identifiers: HashSet, + fbt_operands: HashSet, + synthesized_names: HashMap, +} + +impl<'env> Context<'env> { + fn new( + env: &'env mut Environment, + fn_name: String, + unique_identifiers: HashSet, + fbt_operands: HashSet, + ) -> Self { + Context { + env, + fn_name, + next_cache_index: 0, + declarations: HashSet::new(), + temp: HashMap::new(), + object_methods: HashMap::new(), + unique_identifiers, + fbt_operands, + synthesized_names: HashMap::new(), + } + } + + fn alloc_cache_index(&mut self) -> u32 { + let idx = self.next_cache_index; + self.next_cache_index += 1; + idx + } + + fn declare(&mut self, identifier_id: IdentifierId) { + let ident = &self.env.identifiers[identifier_id.0 as usize]; + self.declarations.insert(ident.declaration_id); + } + + fn has_declared(&self, identifier_id: IdentifierId) -> bool { + let ident = 
&self.env.identifiers[identifier_id.0 as usize]; + self.declarations.contains(&ident.declaration_id) + } + + fn synthesize_name(&mut self, name: &str) -> String { + if let Some(prev) = self.synthesized_names.get(name) { + return prev.clone(); + } + let mut validated = name.to_string(); + let mut index = 0u32; + while self.unique_identifiers.contains(&validated) { + validated = format!("{name}{index}"); + index += 1; + } + self.unique_identifiers.insert(validated.clone()); + self.synthesized_names + .insert(name.to_string(), validated.clone()); + validated + } + + fn record_error(&mut self, detail: CompilerErrorDetail) -> Result<(), CompilerError> { + self.env.record_error(detail) + } +} + +// ============================================================================= +// Core codegen functions +// ============================================================================= + +fn codegen_reactive_function( + cx: &mut Context, + func: &ReactiveFunction, +) -> Result { + // Register parameters + for param in &func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(sp) => &sp.place, + }; + let ident = &cx.env.identifiers[place.identifier.0 as usize]; + cx.temp.insert(ident.declaration_id, None); + cx.declare(place.identifier); + } + + let params: Vec = func + .params + .iter() + .map(|p| convert_parameter(p, cx.env)) + .collect::>()?; + let mut body = codegen_block(cx, &func.body)?; + + // Add directives + body.directives = func + .directives + .iter() + .map(|d| Directive { + base: BaseNode::typed("Directive"), + value: DirectiveLiteral { + base: BaseNode::typed("DirectiveLiteral"), + value: d.clone(), + }, + }) + .collect(); + + // Remove trailing `return undefined` + if let Some(last) = body.body.last() { + if matches!(last, Statement::ReturnStatement(ret) if ret.argument.is_none()) { + body.body.pop(); + } + } + + // Count memo blocks + let (memo_blocks, memo_values, pruned_memo_blocks, pruned_memo_values) = + 
count_memo_blocks(func, cx.env); + + Ok(CodegenFunction { + loc: func.loc, + id: func.id.as_ref().map(|name| make_identifier(name)), + name_hint: func.name_hint.clone(), + params, + body, + generator: func.generator, + is_async: func.is_async, + memo_slots_used: cx.next_cache_index, + memo_blocks, + memo_values, + pruned_memo_blocks, + pruned_memo_values, + outlined: Vec::new(), + }) +} + +fn convert_parameter( + param: &ParamPattern, + env: &Environment, +) -> Result { + match param { + ParamPattern::Place(place) => Ok(PatternLike::Identifier(convert_identifier( + place.identifier, + env, + )?)), + ParamPattern::Spread(spread) => Ok(PatternLike::RestElement(RestElement { + base: BaseNode::typed("RestElement"), + argument: Box::new(PatternLike::Identifier(convert_identifier( + spread.place.identifier, + env, + )?)), + type_annotation: None, + decorators: None, + })), + } +} + +// ============================================================================= +// Block codegen +// ============================================================================= + +fn codegen_block(cx: &mut Context, block: &ReactiveBlock) -> Result { + let temp_snapshot: Temporaries = cx.temp.clone(); + let result = codegen_block_no_reset(cx, block)?; + cx.temp = temp_snapshot; + Ok(result) +} + +fn codegen_block_no_reset( + cx: &mut Context, + block: &ReactiveBlock, +) -> Result { + let mut statements: Vec = Vec::new(); + for item in block { + match item { + ReactiveStatement::Instruction(instr) => { + if let Some(stmt) = codegen_instruction_nullable(cx, instr)? { + statements.push(stmt); + } + } + ReactiveStatement::PrunedScope(PrunedReactiveScopeBlock { instructions, .. 
}) => { + let scope_block = codegen_block_no_reset(cx, instructions)?; + statements.extend(scope_block.body); + } + ReactiveStatement::Scope(ReactiveScopeBlock { + scope, + instructions, + }) => { + let temp_snapshot = cx.temp.clone(); + codegen_reactive_scope(cx, &mut statements, *scope, instructions)?; + cx.temp = temp_snapshot; + } + ReactiveStatement::Terminal(term_stmt) => { + let stmt = codegen_terminal(cx, &term_stmt.terminal)?; + let Some(stmt) = stmt else { + continue; + }; + if let Some(ref label) = term_stmt.label { + if !label.implicit { + let inner = if let Statement::BlockStatement(bs) = &stmt { + if bs.body.len() == 1 { + bs.body[0].clone() + } else { + stmt + } + } else { + stmt + }; + statements.push(Statement::LabeledStatement(LabeledStatement { + base: BaseNode::typed("LabeledStatement"), + label: make_identifier(&codegen_label(label.id)), + body: Box::new(inner), + })); + } else if let Statement::BlockStatement(bs) = stmt { + statements.extend(bs.body); + } else { + statements.push(stmt); + } + } else if let Statement::BlockStatement(bs) = stmt { + statements.extend(bs.body); + } else { + statements.push(stmt); + } + } + } + } + Ok(BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: statements, + directives: Vec::new(), + }) +} + +// ============================================================================= +// Reactive scope codegen (memoization) +// ============================================================================= + +fn codegen_reactive_scope( + cx: &mut Context, + statements: &mut Vec, + scope_id: ScopeId, + block: &ReactiveBlock, +) -> Result<(), CompilerError> { + // Clone scope data upfront to avoid holding a borrow on cx.env + let scope_deps = cx.env.scopes[scope_id.0 as usize].dependencies.clone(); + let scope_decls = cx.env.scopes[scope_id.0 as usize].declarations.clone(); + let scope_reassignments = cx.env.scopes[scope_id.0 as usize].reassignments.clone(); + + let mut cache_store_stmts: Vec = Vec::new(); 
+ let mut cache_load_stmts: Vec = Vec::new(); + let mut cache_loads: Vec<(AstIdentifier, u32, Expression)> = Vec::new(); + let mut change_exprs: Vec = Vec::new(); + + // Sort dependencies + let mut deps = scope_deps; + deps.sort_by(|a, b| compare_scope_dependency(a, b, cx.env)); + + for dep in &deps { + let index = cx.alloc_cache_index(); + let cache_name = cx.synthesize_name("$"); + let comparison = Expression::BinaryExpression(ast_expr::BinaryExpression { + base: BaseNode::typed("BinaryExpression"), + operator: AstBinaryOperator::StrictNeq, + left: Box::new(Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::Identifier(make_identifier(&cache_name))), + property: Box::new(Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: index as f64, + })), + computed: true, + })), + right: Box::new(codegen_dependency(cx, dep)?), + }); + change_exprs.push(comparison); + + // Store dependency value into cache + let dep_value = codegen_dependency(cx, dep)?; + cache_store_stmts.push(Statement::ExpressionStatement(ExpressionStatement { + base: BaseNode::typed("ExpressionStatement"), + expression: Box::new(Expression::AssignmentExpression( + ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(PatternLike::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::Identifier(make_identifier(&cache_name))), + property: Box::new(Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: index as f64, + })), + computed: true, + })), + right: Box::new(dep_value), + }, + )), + })); + } + + let mut first_output_index: Option = None; + + // Sort declarations + let mut decls = scope_decls; + decls.sort_by(|(_id_a, a), (_id_b, b)| compare_scope_declaration(a, b, cx.env)); + + for 
(_ident_id, decl) in &decls { + let index = cx.alloc_cache_index(); + if first_output_index.is_none() { + first_output_index = Some(index); + } + + let ident = &cx.env.identifiers[decl.identifier.0 as usize]; + invariant( + ident.name.is_some(), + &format!( + "Expected scope declaration identifier to be named, id={}", + decl.identifier.0 + ), + None, + )?; + + let name = convert_identifier(decl.identifier, cx.env)?; + if !cx.has_declared(decl.identifier) { + statements.push(Statement::VariableDeclaration(VariableDeclaration { + base: BaseNode::typed("VariableDeclaration"), + declarations: vec![make_var_declarator( + PatternLike::Identifier(name.clone()), + None, + )], + kind: VariableDeclarationKind::Let, + declare: None, + })); + } + cache_loads.push((name.clone(), index, Expression::Identifier(name.clone()))); + cx.declare(decl.identifier); + } + + for reassignment_id in scope_reassignments { + let index = cx.alloc_cache_index(); + if first_output_index.is_none() { + first_output_index = Some(index); + } + let name = convert_identifier(reassignment_id, cx.env)?; + cache_loads.push((name.clone(), index, Expression::Identifier(name))); + } + + // Build test condition + let test_condition = if change_exprs.is_empty() { + let first_idx = first_output_index.ok_or_else(|| { + invariant_err("Expected scope to have at least one declaration", None) + })?; + let cache_name = cx.synthesize_name("$"); + Expression::BinaryExpression(ast_expr::BinaryExpression { + base: BaseNode::typed("BinaryExpression"), + operator: AstBinaryOperator::StrictEq, + left: Box::new(Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::Identifier(make_identifier(&cache_name))), + property: Box::new(Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: first_idx as f64, + })), + computed: true, + })), + right: Box::new(symbol_for(MEMO_CACHE_SENTINEL)), + }) + } else { + 
change_exprs + .into_iter() + .reduce(|acc, expr| { + Expression::LogicalExpression(ast_expr::LogicalExpression { + base: BaseNode::typed("LogicalExpression"), + operator: AstLogicalOperator::Or, + left: Box::new(acc), + right: Box::new(expr), + }) + }) + .unwrap() + }; + + let mut computation_block = codegen_block(cx, block)?; + + // Build cache store and load statements for declarations + for (name, index, value) in &cache_loads { + let cache_name = cx.synthesize_name("$"); + cache_store_stmts.push(Statement::ExpressionStatement(ExpressionStatement { + base: BaseNode::typed("ExpressionStatement"), + expression: Box::new(Expression::AssignmentExpression( + ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(PatternLike::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::Identifier(make_identifier(&cache_name))), + property: Box::new(Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: *index as f64, + })), + computed: true, + })), + right: Box::new(value.clone()), + }, + )), + })); + cache_load_stmts.push(Statement::ExpressionStatement(ExpressionStatement { + base: BaseNode::typed("ExpressionStatement"), + expression: Box::new(Expression::AssignmentExpression( + ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(PatternLike::Identifier(name.clone())), + right: Box::new(Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::Identifier(make_identifier(&cache_name))), + property: Box::new(Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: *index as f64, + })), + computed: true, + })), + }, + )), + })); + } + + computation_block.body.extend(cache_store_stmts); + 
+ let memo_stmt = Statement::IfStatement(IfStatement { + base: BaseNode::typed("IfStatement"), + test: Box::new(test_condition), + consequent: Box::new(Statement::BlockStatement(computation_block)), + alternate: Some(Box::new(Statement::BlockStatement(BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: cache_load_stmts, + directives: Vec::new(), + }))), + }); + statements.push(memo_stmt); + + // Handle early return + let early_return_value = cx.env.scopes[scope_id.0 as usize] + .early_return_value + .clone(); + if let Some(ref early_return) = early_return_value { + let early_ident = &cx.env.identifiers[early_return.value.0 as usize]; + let name = match &early_ident.name { + Some(react_compiler_hir::IdentifierName::Named(n)) => n.clone(), + Some(react_compiler_hir::IdentifierName::Promoted(n)) => n.clone(), + None => { + return Err(invariant_err( + "Expected early return value to be promoted to a named variable", + early_return.loc, + )); + } + }; + statements.push(Statement::IfStatement(IfStatement { + base: BaseNode::typed("IfStatement"), + test: Box::new(Expression::BinaryExpression(ast_expr::BinaryExpression { + base: BaseNode::typed("BinaryExpression"), + operator: AstBinaryOperator::StrictNeq, + left: Box::new(Expression::Identifier(make_identifier(&name))), + right: Box::new(symbol_for(EARLY_RETURN_SENTINEL)), + })), + consequent: Box::new(Statement::BlockStatement(BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: vec![Statement::ReturnStatement(ReturnStatement { + base: BaseNode::typed("ReturnStatement"), + argument: Some(Box::new(Expression::Identifier(make_identifier(&name)))), + })], + directives: Vec::new(), + })), + alternate: None, + })); + } + + Ok(()) +} + +// ============================================================================= +// Terminal codegen +// ============================================================================= + +fn codegen_terminal( + cx: &mut Context, + terminal: &ReactiveTerminal, +) -> 
Result, CompilerError> { + match terminal { + ReactiveTerminal::Break { + target, + target_kind, + loc, + .. + } => { + if *target_kind == ReactiveTerminalTargetKind::Implicit { + return Ok(None); + } + Ok(Some(Statement::BreakStatement(BreakStatement { + base: base_node_with_loc("BreakStatement", *loc), + label: if *target_kind == ReactiveTerminalTargetKind::Labeled { + Some(make_identifier(&codegen_label(*target))) + } else { + None + }, + }))) + } + ReactiveTerminal::Continue { + target, + target_kind, + loc, + .. + } => { + if *target_kind == ReactiveTerminalTargetKind::Implicit { + return Ok(None); + } + Ok(Some(Statement::ContinueStatement(ContinueStatement { + base: base_node_with_loc("ContinueStatement", *loc), + label: if *target_kind == ReactiveTerminalTargetKind::Labeled { + Some(make_identifier(&codegen_label(*target))) + } else { + None + }, + }))) + } + ReactiveTerminal::Return { value, loc, .. } => { + let expr = codegen_place_to_expression(cx, value)?; + if let Expression::Identifier(ref ident) = expr { + if ident.name == "undefined" { + return Ok(Some(Statement::ReturnStatement(ReturnStatement { + base: base_node_with_loc("ReturnStatement", *loc), + argument: None, + }))); + } + } + Ok(Some(Statement::ReturnStatement(ReturnStatement { + base: base_node_with_loc("ReturnStatement", *loc), + argument: Some(Box::new(expr)), + }))) + } + ReactiveTerminal::Throw { value, loc, .. } => { + let expr = codegen_place_to_expression(cx, value)?; + Ok(Some(Statement::ThrowStatement(ThrowStatement { + base: base_node_with_loc("ThrowStatement", *loc), + argument: Box::new(expr), + }))) + } + ReactiveTerminal::If { + test, + consequent, + alternate, + loc, + .. 
+ } => { + let test_expr = codegen_place_to_expression(cx, test)?; + let consequent_block = codegen_block(cx, consequent)?; + let alternate_stmt = if let Some(alt) = alternate { + let block = codegen_block(cx, alt)?; + if block.body.is_empty() { + None + } else { + Some(Box::new(Statement::BlockStatement(block))) + } + } else { + None + }; + Ok(Some(Statement::IfStatement(IfStatement { + base: base_node_with_loc("IfStatement", *loc), + test: Box::new(test_expr), + consequent: Box::new(Statement::BlockStatement(consequent_block)), + alternate: alternate_stmt, + }))) + } + ReactiveTerminal::Switch { + test, cases, loc, .. + } => { + let test_expr = codegen_place_to_expression(cx, test)?; + let switch_cases: Vec = cases + .iter() + .map(|case| { + let test = case + .test + .as_ref() + .map(|t| codegen_place_to_expression(cx, t)) + .transpose()?; + let block = case + .block + .as_ref() + .map(|b| codegen_block(cx, b)) + .transpose()?; + let consequent = match block { + Some(b) if b.body.is_empty() => Vec::new(), + Some(b) => vec![Statement::BlockStatement(b)], + None => Vec::new(), + }; + Ok(SwitchCase { + base: BaseNode::typed("SwitchCase"), + test: test.map(Box::new), + consequent, + }) + }) + .collect::>()?; + Ok(Some(Statement::SwitchStatement(SwitchStatement { + base: base_node_with_loc("SwitchStatement", *loc), + discriminant: Box::new(test_expr), + cases: switch_cases, + }))) + } + ReactiveTerminal::DoWhile { + loop_block, + test, + loc, + .. + } => { + let test_expr = codegen_instruction_value_to_expression(cx, test)?; + let body = codegen_block(cx, loop_block)?; + Ok(Some(Statement::DoWhileStatement(DoWhileStatement { + base: base_node_with_loc("DoWhileStatement", *loc), + test: Box::new(test_expr), + body: Box::new(Statement::BlockStatement(body)), + }))) + } + ReactiveTerminal::While { + test, + loop_block, + loc, + .. 
+ } => { + let test_expr = codegen_instruction_value_to_expression(cx, test)?; + let body = codegen_block(cx, loop_block)?; + Ok(Some(Statement::WhileStatement(WhileStatement { + base: base_node_with_loc("WhileStatement", *loc), + test: Box::new(test_expr), + body: Box::new(Statement::BlockStatement(body)), + }))) + } + ReactiveTerminal::For { + init, + test, + update, + loop_block, + loc, + .. + } => { + let init_val = codegen_for_init(cx, init)?; + let test_expr = codegen_instruction_value_to_expression(cx, test)?; + let update_expr = update + .as_ref() + .map(|u| codegen_instruction_value_to_expression(cx, u)) + .transpose()?; + let body = codegen_block(cx, loop_block)?; + Ok(Some(Statement::ForStatement(ForStatement { + base: base_node_with_loc("ForStatement", *loc), + init: init_val.map(|v| Box::new(v)), + test: Some(Box::new(test_expr)), + update: update_expr.map(Box::new), + body: Box::new(Statement::BlockStatement(body)), + }))) + } + ReactiveTerminal::ForIn { + init, + loop_block, + loc, + .. + } => codegen_for_in(cx, init, loop_block, *loc), + ReactiveTerminal::ForOf { + init, + test, + loop_block, + loc, + .. + } => codegen_for_of(cx, init, test, loop_block, *loc), + ReactiveTerminal::Label { block, .. } => { + let body = codegen_block(cx, block)?; + Ok(Some(Statement::BlockStatement(body))) + } + ReactiveTerminal::Try { + block, + handler_binding, + handler, + loc, + .. 
+ } => { + let catch_param = match handler_binding.as_ref() { + Some(binding) => { + let ident = &cx.env.identifiers[binding.identifier.0 as usize]; + cx.temp.insert(ident.declaration_id, None); + Some(PatternLike::Identifier(convert_identifier( + binding.identifier, + cx.env, + )?)) + } + None => None, + }; + let try_block = codegen_block(cx, block)?; + let handler_block = codegen_block(cx, handler)?; + Ok(Some(Statement::TryStatement(TryStatement { + base: base_node_with_loc("TryStatement", *loc), + block: try_block, + handler: Some(CatchClause { + base: BaseNode::typed("CatchClause"), + param: catch_param, + body: handler_block, + }), + finalizer: None, + }))) + } + } +} + +fn codegen_for_in( + cx: &mut Context, + init: &ReactiveValue, + loop_block: &ReactiveBlock, + loc: Option, +) -> Result, CompilerError> { + let ReactiveValue::SequenceExpression { instructions, .. } = init else { + return Err(invariant_err( + "Expected a sequence expression init for for..in", + None, + )); + }; + if instructions.len() != 2 { + cx.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "Support non-trivial for..in inits".to_string(), + description: None, + loc, + suggestions: None, + })?; + return Ok(Some(Statement::EmptyStatement(EmptyStatement { + base: BaseNode::typed("EmptyStatement"), + }))); + } + let iterable_collection = &instructions[0]; + let iterable_item = &instructions[1]; + let instr_value = get_instruction_value(&iterable_item.value)?; + let (lval, var_decl_kind) = extract_for_in_of_lval(cx, instr_value, "for..in", loc)?; + let right = codegen_instruction_value_to_expression(cx, &iterable_collection.value)?; + let body = codegen_block(cx, loop_block)?; + Ok(Some(Statement::ForInStatement(ForInStatement { + base: base_node_with_loc("ForInStatement", loc), + left: Box::new( + react_compiler_ast::statements::ForInOfLeft::VariableDeclaration(VariableDeclaration { + base: BaseNode::typed("VariableDeclaration"), + declarations: 
vec![VariableDeclarator { + base: BaseNode::typed("VariableDeclarator"), + id: lval, + init: None, + definite: None, + }], + kind: var_decl_kind, + declare: None, + }), + ), + right: Box::new(right), + body: Box::new(Statement::BlockStatement(body)), + }))) +} + +fn codegen_for_of( + cx: &mut Context, + init: &ReactiveValue, + test: &ReactiveValue, + loop_block: &ReactiveBlock, + loc: Option, +) -> Result, CompilerError> { + // Validate init is SequenceExpression with single GetIterator instruction + let ReactiveValue::SequenceExpression { + instructions: init_instrs, + .. + } = init + else { + return Err(invariant_err( + "Expected a sequence expression init for for..of", + None, + )); + }; + if init_instrs.len() != 1 { + return Err(invariant_err( + "Expected a single-expression sequence expression init for for..of", + None, + )); + } + let get_iter_value = get_instruction_value(&init_instrs[0].value)?; + let InstructionValue::GetIterator { collection, .. } = get_iter_value else { + return Err(invariant_err("Expected GetIterator in for..of init", None)); + }; + + let ReactiveValue::SequenceExpression { + instructions: test_instrs, + .. 
+ } = test + else { + return Err(invariant_err( + "Expected a sequence expression test for for..of", + None, + )); + }; + if test_instrs.len() != 2 { + cx.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: "Support non-trivial for..of inits".to_string(), + description: None, + loc, + suggestions: None, + })?; + return Ok(Some(Statement::EmptyStatement(EmptyStatement { + base: BaseNode::typed("EmptyStatement"), + }))); + } + let iterable_item = &test_instrs[1]; + let instr_value = get_instruction_value(&iterable_item.value)?; + let (lval, var_decl_kind) = extract_for_in_of_lval(cx, instr_value, "for..of", loc)?; + + let right = codegen_place_to_expression(cx, collection)?; + let body = codegen_block(cx, loop_block)?; + Ok(Some(Statement::ForOfStatement(ForOfStatement { + base: base_node_with_loc("ForOfStatement", loc), + left: Box::new( + react_compiler_ast::statements::ForInOfLeft::VariableDeclaration(VariableDeclaration { + base: BaseNode::typed("VariableDeclaration"), + declarations: vec![VariableDeclarator { + base: BaseNode::typed("VariableDeclarator"), + id: lval, + init: None, + definite: None, + }], + kind: var_decl_kind, + declare: None, + }), + ), + right: Box::new(right), + body: Box::new(Statement::BlockStatement(body)), + is_await: false, + }))) +} + +/// Extract lval and declaration kind from a for-in/for-of iterable item +/// instruction. +fn extract_for_in_of_lval( + cx: &mut Context, + instr_value: &InstructionValue, + context_name: &str, + loc: Option, +) -> Result<(PatternLike, VariableDeclarationKind), CompilerError> { + let (lval, kind) = match instr_value { + InstructionValue::StoreLocal { lvalue, .. } => ( + codegen_lvalue(cx, &LvalueRef::Place(&lvalue.place))?, + lvalue.kind, + ), + InstructionValue::Destructure { lvalue, .. } => ( + codegen_lvalue(cx, &LvalueRef::Pattern(&lvalue.pattern))?, + lvalue.kind, + ), + InstructionValue::StoreContext { .. 
} => { + cx.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: format!("Support non-trivial {} inits", context_name), + description: None, + loc, + suggestions: None, + })?; + return Ok(( + PatternLike::Identifier(make_identifier("_")), + VariableDeclarationKind::Let, + )); + } + _ => { + return Err(invariant_err( + &format!( + "Expected a StoreLocal or Destructure in {} collection, found {:?}", + context_name, + std::mem::discriminant(instr_value) + ), + None, + )); + } + }; + let var_decl_kind = match kind { + InstructionKind::Const => VariableDeclarationKind::Const, + InstructionKind::Let => VariableDeclarationKind::Let, + _ => { + return Err(invariant_err( + &format!( + "Unexpected {:?} variable in {} collection", + kind, context_name + ), + None, + )); + } + }; + Ok((lval, var_decl_kind)) +} + +fn codegen_for_init( + cx: &mut Context, + init: &ReactiveValue, +) -> Result, CompilerError> { + if let ReactiveValue::SequenceExpression { instructions, .. } = init { + let block_items: Vec = instructions + .iter() + .map(|i| ReactiveStatement::Instruction(i.clone())) + .collect(); + let body = codegen_block(cx, &block_items)?.body; + let mut declarators: Vec = Vec::new(); + let mut kind = VariableDeclarationKind::Const; + for instr in body { + // Check if this is an assignment that can be folded into the last declarator + if let Statement::ExpressionStatement(ref expr_stmt) = instr { + if let Expression::AssignmentExpression(ref assign) = *expr_stmt.expression { + if matches!(assign.operator, AssignmentOperator::Assign) { + if let PatternLike::Identifier(ref left_ident) = *assign.left { + if let Some(top) = declarators.last_mut() { + if let PatternLike::Identifier(ref top_ident) = top.id { + if top_ident.name == left_ident.name && top.init.is_none() { + top.init = Some(assign.right.clone()); + continue; + } + } + } + } + } + } + } + + if let Statement::VariableDeclaration(var_decl) = instr { + match var_decl.kind { + 
VariableDeclarationKind::Let | VariableDeclarationKind::Const => {} + _ => { + return Err(invariant_err( + "Expected a let or const variable declaration", + None, + )); + } + } + if matches!(var_decl.kind, VariableDeclarationKind::Let) { + kind = VariableDeclarationKind::Let; + } + declarators.extend(var_decl.declarations); + } else { + let stmt_type = get_statement_type_name(&instr); + let stmt_loc = get_statement_loc(&instr); + let reason = "Expected a variable declaration".to_string(); + let mut err = CompilerError::new(); + err.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Invariant, + reason.clone(), + Some(format!("Got {}", stmt_type)), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: stmt_loc, + message: Some(reason), + identifier_name: None, + }), + ); + return Err(err); + } + } + if declarators.is_empty() { + return Err(invariant_err( + "Expected a variable declaration in for-init", + None, + )); + } + Ok(Some(ForInit::VariableDeclaration(VariableDeclaration { + base: BaseNode::typed("VariableDeclaration"), + declarations: declarators, + kind, + declare: None, + }))) + } else { + let expr = codegen_instruction_value_to_expression(cx, init)?; + Ok(Some(ForInit::Expression(Box::new(expr)))) + } +} + +// ============================================================================= +// Instruction codegen +// ============================================================================= + +fn codegen_instruction_nullable( + cx: &mut Context, + instr: &ReactiveInstruction, +) -> Result, CompilerError> { + // Only check specific InstructionValue kinds for the base Instruction variant + if let ReactiveValue::Instruction(ref value) = instr.value { + match value { + InstructionValue::StoreLocal { .. } + | InstructionValue::StoreContext { .. } + | InstructionValue::Destructure { .. } + | InstructionValue::DeclareLocal { .. } + | InstructionValue::DeclareContext { .. 
} => { + return codegen_store_or_declare(cx, instr, value); + } + InstructionValue::StartMemoize { .. } | InstructionValue::FinishMemoize { .. } => { + return Ok(None); + } + InstructionValue::Debugger { .. } => { + return Ok(Some(Statement::DebuggerStatement(DebuggerStatement { + base: base_node_with_loc("DebuggerStatement", instr.loc), + }))); + } + InstructionValue::UnsupportedNode { + original_node: Some(node), + .. + } => { + // We have the original AST node serialized as JSON; deserialize and emit it + // directly + let stmt: Statement = serde_json::from_value(node.clone()).map_err(|e| { + invariant_err( + &format!("Failed to deserialize original AST node: {}", e), + None, + ) + })?; + return Ok(Some(stmt)); + } + InstructionValue::ObjectMethod { loc, .. } => { + invariant( + instr.lvalue.is_some(), + "Expected object methods to have a temp lvalue", + None, + )?; + let lvalue = instr.lvalue.as_ref().unwrap(); + cx.object_methods + .insert(lvalue.identifier, (value.clone(), *loc)); + return Ok(None); + } + _ => {} // fall through to general codegen + } + } + // General case: codegen the full ReactiveValue + let expr_value = codegen_instruction_value(cx, &instr.value)?; + let stmt = codegen_instruction(cx, instr, expr_value)?; + if matches!(stmt, Statement::EmptyStatement(_)) { + Ok(None) + } else { + Ok(Some(stmt)) + } +} + +fn codegen_store_or_declare( + cx: &mut Context, + instr: &ReactiveInstruction, + value: &InstructionValue, +) -> Result, CompilerError> { + match value { + InstructionValue::StoreLocal { + lvalue, value: val, .. + } => { + let mut kind = lvalue.kind; + if cx.has_declared(lvalue.place.identifier) { + kind = InstructionKind::Reassign; + } + let rhs = codegen_place_to_expression(cx, val)?; + emit_store(cx, instr, kind, &LvalueRef::Place(&lvalue.place), Some(rhs)) + } + InstructionValue::StoreContext { + lvalue, value: val, .. 
+ } => { + let rhs = codegen_place_to_expression(cx, val)?; + emit_store( + cx, + instr, + lvalue.kind, + &LvalueRef::Place(&lvalue.place), + Some(rhs), + ) + } + InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::DeclareContext { lvalue, .. } => { + if cx.has_declared(lvalue.place.identifier) { + return Ok(None); + } + emit_store( + cx, + instr, + lvalue.kind, + &LvalueRef::Place(&lvalue.place), + None, + ) + } + InstructionValue::Destructure { + lvalue, value: val, .. + } => { + let kind = lvalue.kind; + // Register temporaries for unnamed pattern operands + for place in react_compiler_hir::visitors::each_pattern_operand(&lvalue.pattern) { + let ident = &cx.env.identifiers[place.identifier.0 as usize]; + if kind != InstructionKind::Reassign && ident.name.is_none() { + cx.temp.insert(ident.declaration_id, None); + } + } + let rhs = codegen_place_to_expression(cx, val)?; + emit_store( + cx, + instr, + kind, + &LvalueRef::Pattern(&lvalue.pattern), + Some(rhs), + ) + } + _ => unreachable!(), + } +} + +fn emit_store( + cx: &mut Context, + instr: &ReactiveInstruction, + kind: InstructionKind, + lvalue: &LvalueRef, + value: Option, +) -> Result, CompilerError> { + match kind { + InstructionKind::Const => { + // Invariant: Const declarations cannot also have an outer lvalue + // (i.e., cannot be referenced as an expression) + if instr.lvalue.is_some() { + return Err(invariant_err_with_detail_message( + "Const declaration cannot be referenced as an expression", + "this is Const", + instr.loc, + )); + } + let lval = codegen_lvalue(cx, lvalue)?; + Ok(Some(Statement::VariableDeclaration(VariableDeclaration { + base: base_node_with_loc("VariableDeclaration", instr.loc), + declarations: vec![make_var_declarator(lval, value)], + kind: VariableDeclarationKind::Const, + declare: None, + }))) + } + InstructionKind::Function => { + let lval = codegen_lvalue(cx, lvalue)?; + let PatternLike::Identifier(fn_id) = lval else { + return Err(invariant_err( + "Expected an 
identifier as function declaration lvalue", + None, + )); + }; + let Some(rhs) = value else { + return Err(invariant_err( + "Expected a function value for function declaration", + None, + )); + }; + match rhs { + Expression::FunctionExpression(func_expr) => { + Ok(Some(Statement::FunctionDeclaration(FunctionDeclaration { + base: base_node_with_loc("FunctionDeclaration", instr.loc), + id: Some(fn_id), + params: func_expr.params, + body: func_expr.body, + generator: func_expr.generator, + is_async: func_expr.is_async, + declare: None, + return_type: None, + type_parameters: None, + predicate: None, + component_declaration: false, + hook_declaration: false, + }))) + } + _ => Err(invariant_err( + "Expected a function expression for function declaration", + None, + )), + } + } + InstructionKind::Let => { + // Invariant: Let declarations cannot also have an outer lvalue + if instr.lvalue.is_some() { + return Err(invariant_err_with_detail_message( + "Const declaration cannot be referenced as an expression", + "this is Let", + instr.loc, + )); + } + let lval = codegen_lvalue(cx, lvalue)?; + Ok(Some(Statement::VariableDeclaration(VariableDeclaration { + base: base_node_with_loc("VariableDeclaration", instr.loc), + declarations: vec![make_var_declarator(lval, value)], + kind: VariableDeclarationKind::Let, + declare: None, + }))) + } + InstructionKind::Reassign => { + let Some(rhs) = value else { + return Err(invariant_err("Expected a value for reassignment", None)); + }; + let lval = codegen_lvalue(cx, lvalue)?; + let expr = Expression::AssignmentExpression(ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(lval), + right: Box::new(rhs), + }); + if let Some(ref lvalue_place) = instr.lvalue { + let is_store_context = matches!( + &instr.value, + ReactiveValue::Instruction(InstructionValue::StoreContext { .. 
}) + ); + if !is_store_context { + let ident = &cx.env.identifiers[lvalue_place.identifier.0 as usize]; + cx.temp.insert( + ident.declaration_id, + Some(ExpressionOrJsxText::Expression(expr)), + ); + return Ok(None); + } else { + let stmt = + codegen_instruction(cx, instr, ExpressionOrJsxText::Expression(expr))?; + if matches!(stmt, Statement::EmptyStatement(_)) { + return Ok(None); + } + return Ok(Some(stmt)); + } + } + Ok(Some(Statement::ExpressionStatement(ExpressionStatement { + base: base_node_with_loc("ExpressionStatement", instr.loc), + expression: Box::new(expr), + }))) + } + InstructionKind::Catch => Ok(Some(Statement::EmptyStatement(EmptyStatement { + base: BaseNode::typed("EmptyStatement"), + }))), + InstructionKind::HoistedLet + | InstructionKind::HoistedConst + | InstructionKind::HoistedFunction => Err(invariant_err( + &format!( + "Expected {:?} to have been pruned in PruneHoistedContexts", + kind + ), + None, + )), + } +} + +fn codegen_instruction( + cx: &mut Context, + instr: &ReactiveInstruction, + value: ExpressionOrJsxText, +) -> Result { + let Some(ref lvalue) = instr.lvalue else { + let expr = convert_value_to_expression(value); + return Ok(Statement::ExpressionStatement(ExpressionStatement { + base: base_node_with_loc("ExpressionStatement", instr.loc), + expression: Box::new(expr), + })); + }; + let ident = &cx.env.identifiers[lvalue.identifier.0 as usize]; + if ident.name.is_none() { + // temporary + cx.temp.insert(ident.declaration_id, Some(value)); + return Ok(Statement::EmptyStatement(EmptyStatement { + base: BaseNode::typed("EmptyStatement"), + })); + } + let expr_value = convert_value_to_expression(value); + if cx.has_declared(lvalue.identifier) { + Ok(Statement::ExpressionStatement(ExpressionStatement { + base: base_node_with_loc("ExpressionStatement", instr.loc), + expression: Box::new(Expression::AssignmentExpression( + ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: 
AssignmentOperator::Assign, + left: Box::new(PatternLike::Identifier(convert_identifier( + lvalue.identifier, + cx.env, + )?)), + right: Box::new(expr_value), + }, + )), + })) + } else { + Ok(Statement::VariableDeclaration(VariableDeclaration { + base: base_node_with_loc("VariableDeclaration", instr.loc), + declarations: vec![make_var_declarator( + PatternLike::Identifier(convert_identifier(lvalue.identifier, cx.env)?), + Some(expr_value), + )], + kind: VariableDeclarationKind::Const, + declare: None, + })) + } +} + +// ============================================================================= +// Instruction value codegen +// ============================================================================= + +fn codegen_instruction_value_to_expression( + cx: &mut Context, + instr_value: &ReactiveValue, +) -> Result { + let value = codegen_instruction_value(cx, instr_value)?; + Ok(convert_value_to_expression(value)) +} + +fn codegen_instruction_value( + cx: &mut Context, + instr_value: &ReactiveValue, +) -> Result { + match instr_value { + ReactiveValue::Instruction(iv) => { + let mut result = codegen_base_instruction_value(cx, iv)?; + // Propagate instrValue.loc to the generated expression, matching TS: + // if (instrValue.loc != null && instrValue.loc != GeneratedSource) { + // value.loc = instrValue.loc; + // } + if let Some(loc) = iv.loc() { + apply_loc_to_value(&mut result, *loc); + } + Ok(result) + } + ReactiveValue::LogicalExpression { + operator, + left, + right, + .. 
+ } => { + let left_expr = codegen_instruction_value_to_expression(cx, left)?; + let right_expr = codegen_instruction_value_to_expression(cx, right)?; + Ok(ExpressionOrJsxText::Expression( + Expression::LogicalExpression(ast_expr::LogicalExpression { + base: BaseNode::typed("LogicalExpression"), + operator: convert_logical_operator(operator), + left: Box::new(left_expr), + right: Box::new(right_expr), + }), + )) + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + let test_expr = codegen_instruction_value_to_expression(cx, test)?; + let cons_expr = codegen_instruction_value_to_expression(cx, consequent)?; + let alt_expr = codegen_instruction_value_to_expression(cx, alternate)?; + Ok(ExpressionOrJsxText::Expression( + Expression::ConditionalExpression(ast_expr::ConditionalExpression { + base: BaseNode::typed("ConditionalExpression"), + test: Box::new(test_expr), + consequent: Box::new(cons_expr), + alternate: Box::new(alt_expr), + }), + )) + } + ReactiveValue::SequenceExpression { + instructions, + value, + .. 
+ } => { + let block_items: Vec = instructions + .iter() + .map(|i| ReactiveStatement::Instruction(i.clone())) + .collect(); + let body = codegen_block_no_reset(cx, &block_items)?.body; + let mut expressions: Vec = Vec::new(); + for stmt in body { + match stmt { + Statement::ExpressionStatement(es) => { + expressions.push(*es.expression); + } + Statement::VariableDeclaration(ref var_decl) => { + let _declarator = &var_decl.declarations[0]; + cx.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: format!( + "(CodegenReactiveFunction::codegenInstructionValue) Cannot \ + declare variables in a value block" + ), + description: None, + loc: None, + suggestions: None, + })?; + expressions.push(Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: format!("TODO handle declaration"), + })); + } + _ => { + cx.record_error(CompilerErrorDetail { + category: ErrorCategory::Todo, + reason: format!( + "(CodegenReactiveFunction::codegenInstructionValue) Handle \ + conversion of statement to expression" + ), + description: None, + loc: None, + suggestions: None, + })?; + expressions.push(Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: format!("TODO handle statement"), + })); + } + } + } + let final_expr = codegen_instruction_value_to_expression(cx, value)?; + if expressions.is_empty() { + Ok(ExpressionOrJsxText::Expression(final_expr)) + } else { + expressions.push(final_expr); + Ok(ExpressionOrJsxText::Expression( + Expression::SequenceExpression(ast_expr::SequenceExpression { + base: BaseNode::typed("SequenceExpression"), + expressions, + }), + )) + } + } + ReactiveValue::OptionalExpression { + value, optional, .. 
+ } => { + let opt_value = codegen_instruction_value_to_expression(cx, value)?; + match opt_value { + Expression::OptionalCallExpression(oce) => Ok(ExpressionOrJsxText::Expression( + Expression::OptionalCallExpression(ast_expr::OptionalCallExpression { + base: BaseNode::typed("OptionalCallExpression"), + callee: oce.callee, + arguments: oce.arguments, + optional: *optional, + type_parameters: oce.type_parameters, + type_arguments: oce.type_arguments, + }), + )), + Expression::CallExpression(ce) => Ok(ExpressionOrJsxText::Expression( + Expression::OptionalCallExpression(ast_expr::OptionalCallExpression { + base: BaseNode::typed("OptionalCallExpression"), + callee: ce.callee, + arguments: ce.arguments, + optional: *optional, + type_parameters: None, + type_arguments: None, + }), + )), + Expression::OptionalMemberExpression(ome) => Ok(ExpressionOrJsxText::Expression( + Expression::OptionalMemberExpression(ast_expr::OptionalMemberExpression { + base: BaseNode::typed("OptionalMemberExpression"), + object: ome.object, + property: ome.property, + computed: ome.computed, + optional: *optional, + }), + )), + Expression::MemberExpression(me) => Ok(ExpressionOrJsxText::Expression( + Expression::OptionalMemberExpression(ast_expr::OptionalMemberExpression { + base: BaseNode::typed("OptionalMemberExpression"), + object: me.object, + property: me.property, + computed: me.computed, + optional: *optional, + }), + )), + other => Err(invariant_err( + &format!( + "Expected optional value to resolve to call or member expression, got {:?}", + std::mem::discriminant(&other) + ), + None, + )), + } + } + } +} + +fn codegen_base_instruction_value( + cx: &mut Context, + iv: &InstructionValue, +) -> Result { + match iv { + InstructionValue::Primitive { value, loc } => Ok(ExpressionOrJsxText::Expression( + codegen_primitive_value(value, *loc), + )), + InstructionValue::BinaryExpression { + operator, + left, + right, + .. 
+ } => { + let left_expr = codegen_place_to_expression(cx, left)?; + let right_expr = codegen_place_to_expression(cx, right)?; + Ok(ExpressionOrJsxText::Expression( + Expression::BinaryExpression(ast_expr::BinaryExpression { + base: BaseNode::typed("BinaryExpression"), + operator: convert_binary_operator(operator), + left: Box::new(left_expr), + right: Box::new(right_expr), + }), + )) + } + InstructionValue::UnaryExpression { + operator, value, .. + } => { + let arg = codegen_place_to_expression(cx, value)?; + Ok(ExpressionOrJsxText::Expression( + Expression::UnaryExpression(ast_expr::UnaryExpression { + base: BaseNode::typed("UnaryExpression"), + operator: convert_unary_operator(operator), + prefix: true, + argument: Box::new(arg), + }), + )) + } + InstructionValue::LoadLocal { place, .. } | InstructionValue::LoadContext { place, .. } => { + let expr = codegen_place_to_expression(cx, place)?; + Ok(ExpressionOrJsxText::Expression(expr)) + } + InstructionValue::LoadGlobal { binding, .. } => Ok(ExpressionOrJsxText::Expression( + Expression::Identifier(make_identifier(binding.name())), + )), + InstructionValue::CallExpression { + callee, + args, + loc: _, + } => { + let callee_expr = codegen_place_to_expression(cx, callee)?; + let arguments = args + .iter() + .map(|arg| codegen_argument(cx, arg)) + .collect::>()?; + let call_expr = Expression::CallExpression(ast_expr::CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(callee_expr), + arguments, + type_parameters: None, + type_arguments: None, + optional: None, + }); + // enableEmitHookGuards: wrap hook calls in try/finally IIFE + let result = maybe_wrap_hook_call(cx, call_expr, callee.identifier); + Ok(ExpressionOrJsxText::Expression(result)) + } + InstructionValue::MethodCall { + receiver: _, + property, + args, + loc: _, + } => { + let member_expr = codegen_place_to_expression(cx, property)?; + // Invariant: MethodCall::property must resolve to a MemberExpression + if !matches!( + 
member_expr, + Expression::MemberExpression(_) | Expression::OptionalMemberExpression(_) + ) { + let expr_type = match &member_expr { + Expression::Identifier(_) => "Identifier", + _ => "unknown", + }; + { + let msg = format!("Got: '{}'", expr_type); + let mut err = CompilerError::new(); + err.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Invariant, + "[Codegen] Internal error: MethodCall::property must be an unpromoted \ + + unmemoized MemberExpression", + None, + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: property.loc, + message: Some(msg), + identifier_name: None, + }), + ); + return Err(err); + } + } + let arguments = args + .iter() + .map(|arg| codegen_argument(cx, arg)) + .collect::>()?; + let call_expr = Expression::CallExpression(ast_expr::CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(member_expr), + arguments, + type_parameters: None, + type_arguments: None, + optional: None, + }); + // enableEmitHookGuards: wrap hook method calls in try/finally IIFE + let result = maybe_wrap_hook_call(cx, call_expr, property.identifier); + Ok(ExpressionOrJsxText::Expression(result)) + } + InstructionValue::NewExpression { callee, args, .. } => { + let callee_expr = codegen_place_to_expression(cx, callee)?; + let arguments = args + .iter() + .map(|arg| codegen_argument(cx, arg)) + .collect::>()?; + Ok(ExpressionOrJsxText::Expression(Expression::NewExpression( + ast_expr::NewExpression { + base: BaseNode::typed("NewExpression"), + callee: Box::new(callee_expr), + arguments, + type_parameters: None, + type_arguments: None, + }, + ))) + } + InstructionValue::ArrayExpression { elements, .. 
} => { + let elems: Vec> = elements + .iter() + .map(|el| match el { + ArrayElement::Place(place) => Ok(Some(codegen_place_to_expression(cx, place)?)), + ArrayElement::Spread(spread) => { + let arg = codegen_place_to_expression(cx, &spread.place)?; + Ok(Some(Expression::SpreadElement(ast_expr::SpreadElement { + base: BaseNode::typed("SpreadElement"), + argument: Box::new(arg), + }))) + } + ArrayElement::Hole => Ok(None), + }) + .collect::>()?; + Ok(ExpressionOrJsxText::Expression( + Expression::ArrayExpression(ast_expr::ArrayExpression { + base: BaseNode::typed("ArrayExpression"), + elements: elems, + }), + )) + } + InstructionValue::ObjectExpression { properties, .. } => { + codegen_object_expression(cx, properties) + } + InstructionValue::PropertyLoad { + object, property, .. + } => { + let obj = codegen_place_to_expression(cx, object)?; + let (prop, computed) = property_literal_to_expression(property); + Ok(ExpressionOrJsxText::Expression( + Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(obj), + property: Box::new(prop), + computed, + }), + )) + } + InstructionValue::PropertyStore { + object, + property, + value, + .. + } => { + let obj = codegen_place_to_expression(cx, object)?; + let (prop, computed) = property_literal_to_expression(property); + let val = codegen_place_to_expression(cx, value)?; + Ok(ExpressionOrJsxText::Expression( + Expression::AssignmentExpression(ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(PatternLike::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(obj), + property: Box::new(prop), + computed, + })), + right: Box::new(val), + }), + )) + } + InstructionValue::PropertyDelete { + object, property, .. 
+ } => { + let obj = codegen_place_to_expression(cx, object)?; + let (prop, computed) = property_literal_to_expression(property); + Ok(ExpressionOrJsxText::Expression( + Expression::UnaryExpression(ast_expr::UnaryExpression { + base: BaseNode::typed("UnaryExpression"), + operator: AstUnaryOperator::Delete, + prefix: true, + argument: Box::new(Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(obj), + property: Box::new(prop), + computed, + })), + }), + )) + } + InstructionValue::ComputedLoad { + object, property, .. + } => { + let obj = codegen_place_to_expression(cx, object)?; + let prop = codegen_place_to_expression(cx, property)?; + Ok(ExpressionOrJsxText::Expression( + Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(obj), + property: Box::new(prop), + computed: true, + }), + )) + } + InstructionValue::ComputedStore { + object, + property, + value, + .. + } => { + let obj = codegen_place_to_expression(cx, object)?; + let prop = codegen_place_to_expression(cx, property)?; + let val = codegen_place_to_expression(cx, value)?; + Ok(ExpressionOrJsxText::Expression( + Expression::AssignmentExpression(ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(PatternLike::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(obj), + property: Box::new(prop), + computed: true, + })), + right: Box::new(val), + }), + )) + } + InstructionValue::ComputedDelete { + object, property, .. 
+ } => { + let obj = codegen_place_to_expression(cx, object)?; + let prop = codegen_place_to_expression(cx, property)?; + Ok(ExpressionOrJsxText::Expression( + Expression::UnaryExpression(ast_expr::UnaryExpression { + base: BaseNode::typed("UnaryExpression"), + operator: AstUnaryOperator::Delete, + prefix: true, + argument: Box::new(Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(obj), + property: Box::new(prop), + computed: true, + })), + }), + )) + } + InstructionValue::RegExpLiteral { pattern, flags, .. } => Ok( + ExpressionOrJsxText::Expression(Expression::RegExpLiteral(AstRegExpLiteral { + base: BaseNode::typed("RegExpLiteral"), + pattern: pattern.clone(), + flags: flags.clone(), + })), + ), + InstructionValue::MetaProperty { meta, property, .. } => Ok( + ExpressionOrJsxText::Expression(Expression::MetaProperty(ast_expr::MetaProperty { + base: BaseNode::typed("MetaProperty"), + meta: make_identifier(meta), + property: make_identifier(property), + })), + ), + InstructionValue::Await { value, .. } => { + let arg = codegen_place_to_expression(cx, value)?; + Ok(ExpressionOrJsxText::Expression( + Expression::AwaitExpression(ast_expr::AwaitExpression { + base: BaseNode::typed("AwaitExpression"), + argument: Box::new(arg), + }), + )) + } + InstructionValue::GetIterator { collection, .. } => { + let expr = codegen_place_to_expression(cx, collection)?; + Ok(ExpressionOrJsxText::Expression(expr)) + } + InstructionValue::IteratorNext { iterator, .. } => { + let expr = codegen_place_to_expression(cx, iterator)?; + Ok(ExpressionOrJsxText::Expression(expr)) + } + InstructionValue::NextPropertyOf { value, .. } => { + let expr = codegen_place_to_expression(cx, value)?; + Ok(ExpressionOrJsxText::Expression(expr)) + } + InstructionValue::PostfixUpdate { + operation, lvalue, .. 
+ } => { + let arg = codegen_place_to_expression(cx, lvalue)?; + Ok(ExpressionOrJsxText::Expression( + Expression::UpdateExpression(ast_expr::UpdateExpression { + base: BaseNode::typed("UpdateExpression"), + operator: convert_update_operator(operation), + argument: Box::new(arg), + prefix: false, + }), + )) + } + InstructionValue::PrefixUpdate { + operation, lvalue, .. + } => { + let arg = codegen_place_to_expression(cx, lvalue)?; + Ok(ExpressionOrJsxText::Expression( + Expression::UpdateExpression(ast_expr::UpdateExpression { + base: BaseNode::typed("UpdateExpression"), + operator: convert_update_operator(operation), + argument: Box::new(arg), + prefix: true, + }), + )) + } + InstructionValue::StoreLocal { lvalue, value, .. } => { + invariant( + lvalue.kind == InstructionKind::Reassign, + "Unexpected StoreLocal in codegenInstructionValue", + None, + )?; + let lval = codegen_lvalue(cx, &LvalueRef::Place(&lvalue.place))?; + let rhs = codegen_place_to_expression(cx, value)?; + Ok(ExpressionOrJsxText::Expression( + Expression::AssignmentExpression(ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(lval), + right: Box::new(rhs), + }), + )) + } + InstructionValue::StoreGlobal { name, value, .. } => { + let rhs = codegen_place_to_expression(cx, value)?; + Ok(ExpressionOrJsxText::Expression( + Expression::AssignmentExpression(ast_expr::AssignmentExpression { + base: BaseNode::typed("AssignmentExpression"), + operator: AssignmentOperator::Assign, + left: Box::new(PatternLike::Identifier(make_identifier(name))), + right: Box::new(rhs), + }), + )) + } + InstructionValue::FunctionExpression { + name, + name_hint, + lowered_func, + expr_type, + .. + } => codegen_function_expression(cx, name, name_hint, lowered_func, expr_type), + InstructionValue::TaggedTemplateExpression { tag, value, .. 
} => { + let tag_expr = codegen_place_to_expression(cx, tag)?; + Ok(ExpressionOrJsxText::Expression( + Expression::TaggedTemplateExpression(ast_expr::TaggedTemplateExpression { + base: BaseNode::typed("TaggedTemplateExpression"), + tag: Box::new(tag_expr), + quasi: ast_expr::TemplateLiteral { + base: BaseNode::typed("TemplateLiteral"), + quasis: vec![TemplateElement { + base: BaseNode::typed("TemplateElement"), + value: TemplateElementValue { + raw: value.raw.clone(), + cooked: value.cooked.clone(), + }, + tail: true, + }], + expressions: Vec::new(), + }, + type_parameters: None, + }), + )) + } + InstructionValue::TemplateLiteral { + subexprs, quasis, .. + } => { + let exprs: Vec = subexprs + .iter() + .map(|p| codegen_place_to_expression(cx, p)) + .collect::>()?; + let template_elems: Vec = quasis + .iter() + .enumerate() + .map(|(i, q)| TemplateElement { + base: BaseNode::typed("TemplateElement"), + value: TemplateElementValue { + raw: q.raw.clone(), + cooked: q.cooked.clone(), + }, + tail: i == quasis.len() - 1, + }) + .collect(); + Ok(ExpressionOrJsxText::Expression( + Expression::TemplateLiteral(ast_expr::TemplateLiteral { + base: BaseNode::typed("TemplateLiteral"), + quasis: template_elems, + expressions: exprs, + }), + )) + } + InstructionValue::TypeCastExpression { + value, + type_annotation_kind, + type_annotation, + .. 
+ } => { + let expr = codegen_place_to_expression(cx, value)?; + // Wrap in the appropriate type cast expression if we have the + // original type annotation AST node + let wrapped = match (type_annotation_kind.as_deref(), type_annotation) { + (Some("satisfies"), Some(ta)) => { + Expression::TSSatisfiesExpression(ast_expr::TSSatisfiesExpression { + base: BaseNode::typed("TSSatisfiesExpression"), + expression: Box::new(expr), + type_annotation: ta.clone(), + }) + } + (Some("as"), Some(ta)) => Expression::TSAsExpression(ast_expr::TSAsExpression { + base: BaseNode::typed("TSAsExpression"), + expression: Box::new(expr), + type_annotation: ta.clone(), + }), + (Some("cast"), Some(ta)) => { + Expression::TypeCastExpression(ast_expr::TypeCastExpression { + base: BaseNode::typed("TypeCastExpression"), + expression: Box::new(expr), + type_annotation: ta.clone(), + }) + } + _ => expr, + }; + Ok(ExpressionOrJsxText::Expression(wrapped)) + } + InstructionValue::JSXText { value, loc } => Ok(ExpressionOrJsxText::JsxText(JSXText { + base: base_node_with_loc("JSXText", *loc), + value: value.clone(), + })), + InstructionValue::JsxExpression { + tag, + props, + children, + loc, + opening_loc, + closing_loc, + } => codegen_jsx_expression(cx, tag, props, children, *loc, *opening_loc, *closing_loc), + InstructionValue::JsxFragment { children, .. } => { + let child_elems: Vec = children + .iter() + .map(|child| codegen_jsx_element(cx, child)) + .collect::>()?; + Ok(ExpressionOrJsxText::Expression(Expression::JSXFragment( + JSXFragment { + base: BaseNode::typed("JSXFragment"), + opening_fragment: JSXOpeningFragment { + base: BaseNode::typed("JSXOpeningFragment"), + }, + closing_fragment: JSXClosingFragment { + base: BaseNode::typed("JSXClosingFragment"), + }, + children: child_elems, + }, + ))) + } + InstructionValue::UnsupportedNode { + original_node, + node_type, + .. 
+ } => { + // Try to deserialize the original AST node from JSON (mirrors statement-level + // handler) + match original_node { + Some(node) => { + match serde_json::from_value::(node.clone()) { + Ok(expr) => Ok(ExpressionOrJsxText::Expression(expr)), + Err(_) => { + // Not a valid expression — fall back to placeholder + Ok(ExpressionOrJsxText::Expression(Expression::Identifier( + make_identifier(&format!( + "__unsupported_{}", + node_type.as_deref().unwrap_or("unknown") + )), + ))) + } + } + } + None => { + // No original node available — fall back to placeholder + Ok(ExpressionOrJsxText::Expression(Expression::Identifier( + make_identifier(&format!( + "__unsupported_{}", + node_type.as_deref().unwrap_or("unknown") + )), + ))) + } + } + } + InstructionValue::StartMemoize { .. } + | InstructionValue::FinishMemoize { .. } + | InstructionValue::Debugger { .. } + | InstructionValue::DeclareLocal { .. } + | InstructionValue::DeclareContext { .. } + | InstructionValue::Destructure { .. } + | InstructionValue::ObjectMethod { .. } + | InstructionValue::StoreContext { .. 
} => Err(invariant_err( + &format!( + "Unexpected {:?} in codegenInstructionValue", + std::mem::discriminant(iv) + ), + None, + )), + } +} + +// ============================================================================= +// Function expression codegen +// ============================================================================= + +fn codegen_function_expression( + cx: &mut Context, + name: &Option, + name_hint: &Option, + lowered_func: &react_compiler_hir::LoweredFunction, + expr_type: &FunctionExpressionType, +) -> Result { + let func = &cx.env.functions[lowered_func.func.0 as usize]; + let reactive_fn = build_reactive_function(func, cx.env)?; + let mut reactive_fn_mut = reactive_fn; + prune_unused_labels(&mut reactive_fn_mut, cx.env)?; + prune_unused_lvalues(&mut reactive_fn_mut, cx.env); + prune_hoisted_contexts(&mut reactive_fn_mut, cx.env)?; + + let mut inner_cx = Context::new( + cx.env, + reactive_fn_mut + .id + .as_deref() + .unwrap_or("[[ anonymous ]]") + .to_string(), + cx.unique_identifiers.clone(), + cx.fbt_operands.clone(), + ); + inner_cx.temp = cx.temp.clone(); + + let fn_result = codegen_reactive_function(&mut inner_cx, &reactive_fn_mut)?; + + let value = match expr_type { + FunctionExpressionType::ArrowFunctionExpression => { + let mut body: ArrowFunctionBody = + ArrowFunctionBody::BlockStatement(fn_result.body.clone()); + // Optimize single-return arrow functions + if fn_result.body.body.len() == 1 && reactive_fn_mut.directives.is_empty() { + if let Statement::ReturnStatement(ret) = &fn_result.body.body[0] { + if let Some(ref arg) = ret.argument { + body = ArrowFunctionBody::Expression(arg.clone()); + } + } + } + let is_expression = matches!(body, ArrowFunctionBody::Expression(_)); + Expression::ArrowFunctionExpression(ast_expr::ArrowFunctionExpression { + base: BaseNode::typed("ArrowFunctionExpression"), + params: fn_result.params, + body: Box::new(body), + id: None, + generator: false, + is_async: fn_result.is_async, + expression: 
Some(is_expression), + return_type: None, + type_parameters: None, + predicate: None, + }) + } + _ => Expression::FunctionExpression(ast_expr::FunctionExpression { + base: BaseNode::typed("FunctionExpression"), + params: fn_result.params, + body: fn_result.body, + id: name.as_ref().map(|n| make_identifier(n)), + generator: fn_result.generator, + is_async: fn_result.is_async, + return_type: None, + type_parameters: None, + }), + }; + + // Handle enableNameAnonymousFunctions + if cx.env.config.enable_name_anonymous_functions && name.is_none() && name_hint.is_some() { + let hint = name_hint.as_ref().unwrap(); + let wrapped = Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::ObjectExpression(ast_expr::ObjectExpression { + base: BaseNode::typed("ObjectExpression"), + properties: vec![ast_expr::ObjectExpressionProperty::ObjectProperty( + ast_expr::ObjectProperty { + base: BaseNode::typed("ObjectProperty"), + key: Box::new(Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: hint.clone(), + })), + value: Box::new(value), + computed: false, + shorthand: false, + decorators: None, + method: None, + }, + )], + })), + property: Box::new(Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: hint.clone(), + })), + computed: true, + }); + return Ok(ExpressionOrJsxText::Expression(wrapped)); + } + + Ok(ExpressionOrJsxText::Expression(value)) +} + +// ============================================================================= +// Object expression codegen +// ============================================================================= + +fn codegen_object_expression( + cx: &mut Context, + properties: &[ObjectPropertyOrSpread], +) -> Result { + let mut ast_properties: Vec = Vec::new(); + for prop in properties { + match prop { + ObjectPropertyOrSpread::Property(obj_prop) => { + let key = codegen_object_property_key(cx, 
&obj_prop.key)?; + match obj_prop.property_type { + ObjectPropertyType::Property => { + let value = codegen_place_to_expression(cx, &obj_prop.place)?; + let is_shorthand = matches!(&key, Expression::Identifier(k_id) + if matches!(&value, Expression::Identifier(v_id) if v_id.name == k_id.name)); + ast_properties.push(ast_expr::ObjectExpressionProperty::ObjectProperty( + ast_expr::ObjectProperty { + base: BaseNode::typed("ObjectProperty"), + key: Box::new(key), + value: Box::new(value), + computed: matches!( + obj_prop.key, + ObjectPropertyKey::Computed { .. } + ), + shorthand: is_shorthand, + decorators: None, + method: None, + }, + )); + } + ObjectPropertyType::Method => { + let method_data = cx.object_methods.get(&obj_prop.place.identifier); + let method_data = method_data.cloned(); + let Some((InstructionValue::ObjectMethod { lowered_func, .. }, _)) = + method_data + else { + return Err(invariant_err("Expected ObjectMethod instruction", None)); + }; + + let func = &cx.env.functions[lowered_func.func.0 as usize]; + let reactive_fn = build_reactive_function(func, cx.env)?; + let mut reactive_fn_mut = reactive_fn; + prune_unused_labels(&mut reactive_fn_mut, cx.env)?; + prune_unused_lvalues(&mut reactive_fn_mut, cx.env); + + let mut inner_cx = Context::new( + cx.env, + reactive_fn_mut + .id + .as_deref() + .unwrap_or("[[ anonymous ]]") + .to_string(), + cx.unique_identifiers.clone(), + cx.fbt_operands.clone(), + ); + inner_cx.temp = cx.temp.clone(); + + let fn_result = codegen_reactive_function(&mut inner_cx, &reactive_fn_mut)?; + + ast_properties.push(ast_expr::ObjectExpressionProperty::ObjectMethod( + ast_expr::ObjectMethod { + base: BaseNode::typed("ObjectMethod"), + method: true, + kind: ast_expr::ObjectMethodKind::Method, + key: Box::new(key), + params: fn_result.params, + body: fn_result.body, + computed: matches!( + obj_prop.key, + ObjectPropertyKey::Computed { .. 
} + ), + id: None, + generator: fn_result.generator, + is_async: fn_result.is_async, + decorators: None, + return_type: None, + type_parameters: None, + }, + )); + } + } + } + ObjectPropertyOrSpread::Spread(spread) => { + let arg = codegen_place_to_expression(cx, &spread.place)?; + ast_properties.push(ast_expr::ObjectExpressionProperty::SpreadElement( + ast_expr::SpreadElement { + base: BaseNode::typed("SpreadElement"), + argument: Box::new(arg), + }, + )); + } + } + } + Ok(ExpressionOrJsxText::Expression( + Expression::ObjectExpression(ast_expr::ObjectExpression { + base: BaseNode::typed("ObjectExpression"), + properties: ast_properties, + }), + )) +} + +fn codegen_object_property_key( + cx: &mut Context, + key: &ObjectPropertyKey, +) -> Result { + match key { + ObjectPropertyKey::String { name } => Ok(Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: name.clone(), + })), + ObjectPropertyKey::Identifier { name } => Ok(Expression::Identifier(make_identifier(name))), + ObjectPropertyKey::Computed { name } => { + let expr = codegen_place(cx, name)?; + match expr { + ExpressionOrJsxText::Expression(e) => Ok(e), + ExpressionOrJsxText::JsxText(_) => Err(invariant_err( + "Expected object property key to be an expression", + None, + )), + } + } + ObjectPropertyKey::Number { name } => Ok(Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: name.value(), + })), + } +} + +// ============================================================================= +// JSX codegen +// ============================================================================= + +fn codegen_jsx_expression( + cx: &mut Context, + tag: &JsxTag, + props: &[JsxAttribute], + children: &Option>, + loc: Option, + opening_loc: Option, + closing_loc: Option, +) -> Result { + let mut attributes: Vec = Vec::new(); + for attr in props { + attributes.push(codegen_jsx_attribute(cx, attr)?); + } + + let (tag_value, _tag_loc) = match 
tag { + JsxTag::Place(place) => (codegen_place_to_expression(cx, place)?, place.loc), + JsxTag::Builtin(builtin) => ( + Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: builtin.name.clone(), + }), + None, + ), + }; + + let jsx_tag = expression_to_jsx_tag(&tag_value, jsx_tag_loc(tag))?; + + let is_fbt_tag = if let Expression::StringLiteral(ref s) = tag_value { + SINGLE_CHILD_FBT_TAGS.contains(&s.value.as_str()) + } else { + false + }; + + let child_nodes = if is_fbt_tag { + children + .as_ref() + .map(|c| { + c.iter() + .map(|child| codegen_jsx_fbt_child_element(cx, child)) + .collect::, _>>() + }) + .transpose()? + .unwrap_or_default() + } else { + children + .as_ref() + .map(|c| { + c.iter() + .map(|child| codegen_jsx_element(cx, child)) + .collect::, _>>() + }) + .transpose()? + .unwrap_or_default() + }; + + let is_self_closing = children.is_none(); + + let element = JSXElement { + base: base_node_with_loc("JSXElement", loc), + opening_element: JSXOpeningElement { + base: base_node_with_loc("JSXOpeningElement", opening_loc), + name: jsx_tag.clone(), + attributes, + self_closing: is_self_closing, + type_parameters: None, + }, + closing_element: if !is_self_closing { + Some(JSXClosingElement { + base: base_node_with_loc("JSXClosingElement", closing_loc), + name: jsx_tag, + }) + } else { + None + }, + children: child_nodes, + self_closing: if is_self_closing { Some(true) } else { None }, + }; + + Ok(ExpressionOrJsxText::Expression(Expression::JSXElement( + Box::new(element), + ))) +} + +const JSX_TEXT_CHILD_REQUIRES_EXPR_CONTAINER_PATTERN: &[char] = &['<', '>', '&', '{', '}']; +const STRING_REQUIRES_EXPR_CONTAINER_CHARS: &str = "\"\\"; + +fn string_requires_expr_container(s: &str) -> bool { + for c in s.chars() { + if STRING_REQUIRES_EXPR_CONTAINER_CHARS.contains(c) { + return true; + } + // Check for control chars and non-basic-latin + let code = c as u32; + if code <= 0x1f || code == 0x7f || (code >= 0x80 && code <= 
0x9f) || (code >= 0xa0) { + return true; + } + } + false +} + +fn codegen_jsx_attribute( + cx: &mut Context, + attr: &JsxAttribute, +) -> Result { + match attr { + JsxAttribute::Attribute { name, place } => { + let prop_name = if name.contains(':') { + let parts: Vec<&str> = name.splitn(2, ':').collect(); + JSXAttributeName::JSXNamespacedName(JSXNamespacedName { + base: BaseNode::typed("JSXNamespacedName"), + namespace: JSXIdentifier { + base: BaseNode::typed("JSXIdentifier"), + name: parts[0].to_string(), + }, + name: JSXIdentifier { + base: BaseNode::typed("JSXIdentifier"), + name: parts[1].to_string(), + }, + }) + } else { + JSXAttributeName::JSXIdentifier(JSXIdentifier { + base: BaseNode::typed("JSXIdentifier"), + name: name.clone(), + }) + }; + + let inner_value = codegen_place_to_expression(cx, place)?; + let attr_value = match &inner_value { + Expression::StringLiteral(s) => { + if string_requires_expr_container(&s.value) + && !cx.fbt_operands.contains(&place.identifier) + { + Some(JSXAttributeValue::JSXExpressionContainer( + JSXExpressionContainer { + base: base_node_with_loc("JSXExpressionContainer", place.loc), + expression: JSXExpressionContainerExpr::Expression(Box::new( + inner_value, + )), + }, + )) + } else { + // Preserve loc from the inner StringLiteral (or fall back to + // the place's loc) so downstream plugins (e.g., babel-plugin-fbt) + // can read loc on attribute values. 
+ let base = if s.base.loc.is_some() { + s.base.clone() + } else { + base_node_with_loc("StringLiteral", place.loc) + }; + Some(JSXAttributeValue::StringLiteral(StringLiteral { + base, + value: s.value.clone(), + })) + } + } + _ => Some(JSXAttributeValue::JSXExpressionContainer( + JSXExpressionContainer { + base: base_node_with_loc("JSXExpressionContainer", place.loc), + expression: JSXExpressionContainerExpr::Expression(Box::new(inner_value)), + }, + )), + }; + Ok(JSXAttributeItem::JSXAttribute(AstJSXAttribute { + base: base_node_with_loc("JSXAttribute", place.loc), + name: prop_name, + value: attr_value, + })) + } + JsxAttribute::SpreadAttribute { argument } => { + let expr = codegen_place_to_expression(cx, argument)?; + Ok(JSXAttributeItem::JSXSpreadAttribute(JSXSpreadAttribute { + base: BaseNode::typed("JSXSpreadAttribute"), + argument: Box::new(expr), + })) + } + } +} + +fn codegen_jsx_element(cx: &mut Context, place: &Place) -> Result { + let loc = place.loc; + let value = codegen_place(cx, place)?; + match value { + ExpressionOrJsxText::JsxText(text) => { + if text + .value + .contains(JSX_TEXT_CHILD_REQUIRES_EXPR_CONTAINER_PATTERN) + { + Ok(JSXChild::JSXExpressionContainer(JSXExpressionContainer { + base: base_node_with_loc("JSXExpressionContainer", loc), + expression: JSXExpressionContainerExpr::Expression(Box::new( + Expression::StringLiteral(StringLiteral { + base: base_node_with_loc("StringLiteral", loc), + value: text.value.clone(), + }), + )), + })) + } else { + Ok(JSXChild::JSXText(text)) + } + } + ExpressionOrJsxText::Expression(Expression::JSXElement(elem)) => { + Ok(JSXChild::JSXElement(elem)) + } + ExpressionOrJsxText::Expression(Expression::JSXFragment(frag)) => { + Ok(JSXChild::JSXFragment(frag)) + } + ExpressionOrJsxText::Expression(expr) => { + Ok(JSXChild::JSXExpressionContainer(JSXExpressionContainer { + base: base_node_with_loc("JSXExpressionContainer", loc), + expression: JSXExpressionContainerExpr::Expression(Box::new(expr)), + })) + } + 
} +} + +fn codegen_jsx_fbt_child_element( + cx: &mut Context, + place: &Place, +) -> Result { + let loc = place.loc; + let value = codegen_place(cx, place)?; + match value { + ExpressionOrJsxText::JsxText(text) => Ok(JSXChild::JSXText(text)), + ExpressionOrJsxText::Expression(Expression::JSXElement(elem)) => { + Ok(JSXChild::JSXElement(elem)) + } + ExpressionOrJsxText::Expression(expr) => { + Ok(JSXChild::JSXExpressionContainer(JSXExpressionContainer { + base: base_node_with_loc("JSXExpressionContainer", loc), + expression: JSXExpressionContainerExpr::Expression(Box::new(expr)), + })) + } + } +} + +fn expression_to_jsx_tag( + expr: &Expression, + loc: Option, +) -> Result { + match expr { + Expression::Identifier(ident) => Ok(JSXElementName::JSXIdentifier(JSXIdentifier { + base: base_node_with_loc("JSXIdentifier", loc), + name: ident.name.clone(), + })), + Expression::MemberExpression(me) => Ok(JSXElementName::JSXMemberExpression( + convert_member_expression_to_jsx(me)?, + )), + Expression::StringLiteral(s) => { + if s.value.contains(':') { + let parts: Vec<&str> = s.value.splitn(2, ':').collect(); + Ok(JSXElementName::JSXNamespacedName(JSXNamespacedName { + base: base_node_with_loc("JSXNamespacedName", loc), + namespace: JSXIdentifier { + base: base_node_with_loc("JSXIdentifier", loc), + name: parts[0].to_string(), + }, + name: JSXIdentifier { + base: base_node_with_loc("JSXIdentifier", loc), + name: parts[1].to_string(), + }, + })) + } else { + Ok(JSXElementName::JSXIdentifier(JSXIdentifier { + base: base_node_with_loc("JSXIdentifier", loc), + name: s.value.clone(), + })) + } + } + _ => Err(invariant_err( + &format!("Expected JSX tag to be an identifier or string"), + None, + )), + } +} + +fn convert_member_expression_to_jsx( + me: &ast_expr::MemberExpression, +) -> Result { + let Expression::Identifier(ref prop_ident) = *me.property else { + return Err(invariant_err( + "Expected JSX member expression property to be a string", + None, + )); + }; + let property = 
JSXIdentifier { + base: BaseNode::typed("JSXIdentifier"), + name: prop_ident.name.clone(), + }; + match &*me.object { + Expression::Identifier(ident) => Ok(JSXMemberExpression { + base: BaseNode::typed("JSXMemberExpression"), + object: Box::new(JSXMemberExprObject::JSXIdentifier(JSXIdentifier { + base: BaseNode::typed("JSXIdentifier"), + name: ident.name.clone(), + })), + property, + }), + Expression::MemberExpression(inner_me) => { + let inner = convert_member_expression_to_jsx(inner_me)?; + Ok(JSXMemberExpression { + base: BaseNode::typed("JSXMemberExpression"), + object: Box::new(JSXMemberExprObject::JSXMemberExpression(Box::new(inner))), + property, + }) + } + _ => Err(invariant_err( + "Expected JSX member expression to be an identifier or nested member expression", + None, + )), + } +} + +// ============================================================================= +// Pattern codegen (lvalues) +// ============================================================================= + +enum LvalueRef<'a> { + Place(&'a Place), + Pattern(&'a Pattern), + Spread(&'a SpreadPattern), +} + +fn codegen_lvalue(cx: &mut Context, pattern: &LvalueRef) -> Result { + match pattern { + LvalueRef::Place(place) => Ok(PatternLike::Identifier(convert_identifier( + place.identifier, + cx.env, + )?)), + LvalueRef::Pattern(pat) => match pat { + Pattern::Array(arr) => codegen_array_pattern(cx, arr), + Pattern::Object(obj) => codegen_object_pattern(cx, obj), + }, + LvalueRef::Spread(spread) => { + let inner = codegen_lvalue(cx, &LvalueRef::Place(&spread.place))?; + Ok(PatternLike::RestElement(RestElement { + base: BaseNode::typed("RestElement"), + argument: Box::new(inner), + type_annotation: None, + decorators: None, + })) + } + } +} + +fn codegen_array_pattern( + cx: &mut Context, + pattern: &ArrayPattern, +) -> Result { + let elements: Vec> = pattern + .items + .iter() + .map(|item| match item { + react_compiler_hir::ArrayPatternElement::Place(place) => { + Ok(Some(codegen_lvalue(cx, 
&LvalueRef::Place(place))?)) + } + react_compiler_hir::ArrayPatternElement::Spread(spread) => { + Ok(Some(codegen_lvalue(cx, &LvalueRef::Spread(spread))?)) + } + react_compiler_hir::ArrayPatternElement::Hole => Ok(None), + }) + .collect::>()?; + Ok(PatternLike::ArrayPattern(AstArrayPattern { + base: base_node_with_loc("ArrayPattern", pattern.loc), + elements, + type_annotation: None, + decorators: None, + })) +} + +fn codegen_object_pattern( + cx: &mut Context, + pattern: &ObjectPattern, +) -> Result { + let properties: Vec = pattern + .properties + .iter() + .map(|prop| match prop { + ObjectPropertyOrSpread::Property(obj_prop) => { + let key = codegen_object_property_key(cx, &obj_prop.key)?; + let value = codegen_lvalue(cx, &LvalueRef::Place(&obj_prop.place))?; + let is_shorthand = matches!(&key, Expression::Identifier(k_id) + if matches!(&value, PatternLike::Identifier(v_id) if v_id.name == k_id.name)); + Ok(ObjectPatternProperty::ObjectProperty(ObjectPatternProp { + base: BaseNode::typed("ObjectProperty"), + key: Box::new(key), + value: Box::new(value), + computed: matches!(obj_prop.key, ObjectPropertyKey::Computed { .. 
}), + shorthand: is_shorthand, + decorators: None, + method: None, + })) + } + ObjectPropertyOrSpread::Spread(spread) => { + let inner = codegen_lvalue(cx, &LvalueRef::Place(&spread.place))?; + Ok(ObjectPatternProperty::RestElement(RestElement { + base: BaseNode::typed("RestElement"), + argument: Box::new(inner), + type_annotation: None, + decorators: None, + })) + } + }) + .collect::>()?; + Ok(PatternLike::ObjectPattern( + react_compiler_ast::patterns::ObjectPattern { + base: base_node_with_loc("ObjectPattern", pattern.loc), + properties, + type_annotation: None, + decorators: None, + }, + )) +} + +// ============================================================================= +// Place / identifier codegen +// ============================================================================= + +fn codegen_place_to_expression( + cx: &mut Context, + place: &Place, +) -> Result { + let value = codegen_place(cx, place)?; + Ok(convert_value_to_expression(value)) +} + +fn codegen_place(cx: &mut Context, place: &Place) -> Result { + let ident = &cx.env.identifiers[place.identifier.0 as usize]; + if let Some(tmp) = cx.temp.get(&ident.declaration_id) { + if let Some(val) = tmp { + return Ok(val.clone()); + } + // tmp is None — means declared but no temp value, fall through + } + // Check if it's an unnamed identifier without a temp + if ident.name.is_none() && !cx.temp.contains_key(&ident.declaration_id) { + return Err(invariant_err( + &format!( + "[Codegen] No value found for temporary, identifier id={}", + place.identifier.0 + ), + place.loc, + )); + } + let mut ast_ident = convert_identifier(place.identifier, cx.env)?; + // Override identifier loc with place.loc, matching TS: identifier.loc = + // place.loc + if let Some(loc) = place.loc { + ast_ident.base.loc = Some(AstSourceLocation { + start: AstPosition { + line: loc.start.line, + column: loc.start.column, + index: None, + }, + end: AstPosition { + line: loc.end.line, + column: loc.end.column, + index: None, + }, + 
filename: None, + identifier_name: None, + }); + } + Ok(ExpressionOrJsxText::Expression(Expression::Identifier( + ast_ident, + ))) +} + +fn convert_identifier( + identifier_id: IdentifierId, + env: &Environment, +) -> Result { + let ident = &env.identifiers[identifier_id.0 as usize]; + let name = match &ident.name { + Some(react_compiler_hir::IdentifierName::Named(n)) => n.clone(), + Some(react_compiler_hir::IdentifierName::Promoted(n)) => n.clone(), + None => { + // Use CompilerDiagnostic (with details array) to match TS + // CompilerError.invariant() which creates a CompilerDiagnostic with + // details: [{kind: "error", loc, message}]. + let reason = "Expected temporaries to be promoted to named identifiers in an earlier \ + pass" + .to_string(); + let description = format!("identifier {} is unnamed", identifier_id.0); + let mut err = CompilerError::new(); + err.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Invariant, + reason.clone(), + Some(description), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: None, + message: Some(reason), + identifier_name: None, + }), + ); + return Err(err); + } + }; + Ok(make_identifier_with_loc(&name, ident.loc)) +} + +fn codegen_argument(cx: &mut Context, arg: &PlaceOrSpread) -> Result { + match arg { + PlaceOrSpread::Place(place) => codegen_place_to_expression(cx, place), + PlaceOrSpread::Spread(spread) => { + let expr = codegen_place_to_expression(cx, &spread.place)?; + Ok(Expression::SpreadElement(ast_expr::SpreadElement { + base: BaseNode::typed("SpreadElement"), + argument: Box::new(expr), + })) + } + } +} + +// ============================================================================= +// Dependency codegen +// ============================================================================= + +fn codegen_dependency( + cx: &mut Context, + dep: &react_compiler_hir::ReactiveScopeDependency, +) -> Result { + let mut object: Expression = + Expression::Identifier(convert_identifier(dep.identifier, 
cx.env)?); + if !dep.path.is_empty() { + let has_optional = dep.path.iter().any(|p| p.optional); + for path_entry in &dep.path { + let (property, is_computed) = property_literal_to_expression(&path_entry.property); + if has_optional { + object = Expression::OptionalMemberExpression(ast_expr::OptionalMemberExpression { + base: BaseNode::typed("OptionalMemberExpression"), + object: Box::new(object), + property: Box::new(property), + computed: is_computed, + optional: path_entry.optional, + }); + } else { + object = Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(object), + property: Box::new(property), + computed: is_computed, + }); + } + } + } + Ok(object) +} + +// ============================================================================= +// CountMemoBlockVisitor — uses ReactiveFunctionVisitor trait +// ============================================================================= + +/// Counts memo blocks and pruned memo blocks in a reactive function. 
+/// TS: `class CountMemoBlockVisitor extends ReactiveFunctionVisitor` +struct CountMemoBlockVisitor<'a> { + env: &'a Environment, +} + +struct CountMemoBlockState { + memo_blocks: u32, + memo_values: u32, + pruned_memo_blocks: u32, + pruned_memo_values: u32, +} + +impl<'a> ReactiveFunctionVisitor for CountMemoBlockVisitor<'a> { + type State = CountMemoBlockState; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_scope(&self, scope_block: &ReactiveScopeBlock, state: &mut CountMemoBlockState) { + state.memo_blocks += 1; + let scope = &self.env.scopes[scope_block.scope.0 as usize]; + state.memo_values += scope.declarations.len() as u32; + self.traverse_scope(scope_block, state); + } + + fn visit_pruned_scope( + &self, + scope_block: &PrunedReactiveScopeBlock, + state: &mut CountMemoBlockState, + ) { + state.pruned_memo_blocks += 1; + let scope = &self.env.scopes[scope_block.scope.0 as usize]; + state.pruned_memo_values += scope.declarations.len() as u32; + self.traverse_pruned_scope(scope_block, state); + } +} + +fn count_memo_blocks(func: &ReactiveFunction, env: &Environment) -> (u32, u32, u32, u32) { + let visitor = CountMemoBlockVisitor { env }; + let mut state = CountMemoBlockState { + memo_blocks: 0, + memo_values: 0, + pruned_memo_blocks: 0, + pruned_memo_values: 0, + }; + visit_reactive_function(func, &visitor, &mut state); + ( + state.memo_blocks, + state.memo_values, + state.pruned_memo_blocks, + state.pruned_memo_values, + ) +} + +// ============================================================================= +// Operator conversions +// ============================================================================= + +fn convert_binary_operator(op: &react_compiler_hir::BinaryOperator) -> AstBinaryOperator { + match op { + react_compiler_hir::BinaryOperator::Equal => AstBinaryOperator::Eq, + react_compiler_hir::BinaryOperator::NotEqual => AstBinaryOperator::Neq, + react_compiler_hir::BinaryOperator::StrictEqual => AstBinaryOperator::StrictEq, + 
react_compiler_hir::BinaryOperator::StrictNotEqual => AstBinaryOperator::StrictNeq, + react_compiler_hir::BinaryOperator::LessThan => AstBinaryOperator::Lt, + react_compiler_hir::BinaryOperator::LessEqual => AstBinaryOperator::Lte, + react_compiler_hir::BinaryOperator::GreaterThan => AstBinaryOperator::Gt, + react_compiler_hir::BinaryOperator::GreaterEqual => AstBinaryOperator::Gte, + react_compiler_hir::BinaryOperator::ShiftLeft => AstBinaryOperator::Shl, + react_compiler_hir::BinaryOperator::ShiftRight => AstBinaryOperator::Shr, + react_compiler_hir::BinaryOperator::UnsignedShiftRight => AstBinaryOperator::UShr, + react_compiler_hir::BinaryOperator::Add => AstBinaryOperator::Add, + react_compiler_hir::BinaryOperator::Subtract => AstBinaryOperator::Sub, + react_compiler_hir::BinaryOperator::Multiply => AstBinaryOperator::Mul, + react_compiler_hir::BinaryOperator::Divide => AstBinaryOperator::Div, + react_compiler_hir::BinaryOperator::Modulo => AstBinaryOperator::Rem, + react_compiler_hir::BinaryOperator::Exponent => AstBinaryOperator::Exp, + react_compiler_hir::BinaryOperator::BitwiseOr => AstBinaryOperator::BitOr, + react_compiler_hir::BinaryOperator::BitwiseXor => AstBinaryOperator::BitXor, + react_compiler_hir::BinaryOperator::BitwiseAnd => AstBinaryOperator::BitAnd, + react_compiler_hir::BinaryOperator::In => AstBinaryOperator::In, + react_compiler_hir::BinaryOperator::InstanceOf => AstBinaryOperator::Instanceof, + } +} + +fn convert_unary_operator(op: &react_compiler_hir::UnaryOperator) -> AstUnaryOperator { + match op { + react_compiler_hir::UnaryOperator::Minus => AstUnaryOperator::Neg, + react_compiler_hir::UnaryOperator::Plus => AstUnaryOperator::Plus, + react_compiler_hir::UnaryOperator::Not => AstUnaryOperator::Not, + react_compiler_hir::UnaryOperator::BitwiseNot => AstUnaryOperator::BitNot, + react_compiler_hir::UnaryOperator::TypeOf => AstUnaryOperator::TypeOf, + react_compiler_hir::UnaryOperator::Void => AstUnaryOperator::Void, + } +} + +fn 
convert_logical_operator(op: &LogicalOperator) -> AstLogicalOperator { + match op { + LogicalOperator::And => AstLogicalOperator::And, + LogicalOperator::Or => AstLogicalOperator::Or, + LogicalOperator::NullishCoalescing => AstLogicalOperator::NullishCoalescing, + } +} + +fn convert_update_operator(op: &react_compiler_hir::UpdateOperator) -> AstUpdateOperator { + match op { + react_compiler_hir::UpdateOperator::Increment => AstUpdateOperator::Increment, + react_compiler_hir::UpdateOperator::Decrement => AstUpdateOperator::Decrement, + } +} + +// ============================================================================= +// Helpers +// ============================================================================= + +/// Create a BaseNode with the given type name and optional source location. +/// Converts from the diagnostics SourceLocation (line, column) to the AST +/// SourceLocation format. This is critical for Babel's `retainLines: true` +/// option to insert blank lines at correct positions. 
+fn base_node_with_loc(type_name: &str, loc: Option) -> BaseNode { + match loc { + Some(loc) => BaseNode { + node_type: Some(type_name.to_string()), + loc: Some(AstSourceLocation { + start: AstPosition { + line: loc.start.line, + column: loc.start.column, + index: loc.start.index, + }, + end: AstPosition { + line: loc.end.line, + column: loc.end.column, + index: loc.end.index, + }, + filename: None, + identifier_name: None, + }), + ..Default::default() + }, + None => BaseNode::typed(type_name), + } +} + +fn make_identifier(name: &str) -> AstIdentifier { + AstIdentifier { + base: BaseNode::typed("Identifier"), + name: name.to_string(), + type_annotation: None, + optional: None, + decorators: None, + } +} + +fn make_identifier_with_loc(name: &str, loc: Option) -> AstIdentifier { + AstIdentifier { + base: base_node_with_loc("Identifier", loc), + name: name.to_string(), + type_annotation: None, + optional: None, + decorators: None, + } +} + +fn make_var_declarator(id: PatternLike, init: Option) -> VariableDeclarator { + // Reconstruct VariableDeclarator.loc from id.loc.start and init.loc.end, + // matching TS createVariableDeclarator behavior for retainLines support. + let loc = get_pattern_loc(&id).and_then(|id_loc| { + let end = match &init { + Some(expr) => get_expression_loc(expr) + .map(|l| l.end.clone()) + .unwrap_or_else(|| id_loc.end.clone()), + None => id_loc.end.clone(), + }; + Some(AstSourceLocation { + start: id_loc.start.clone(), + end, + filename: id_loc.filename.clone(), + identifier_name: None, + }) + }); + VariableDeclarator { + base: if let Some(loc) = loc { + BaseNode { + node_type: Some("VariableDeclarator".to_string()), + loc: Some(loc), + ..Default::default() + } + } else { + BaseNode::typed("VariableDeclarator") + }, + id, + init: init.map(Box::new), + definite: None, + } +} + +/// Extract the loc from a PatternLike's base node. 
+fn get_pattern_loc(pattern: &PatternLike) -> Option<&AstSourceLocation> { + match pattern { + PatternLike::Identifier(id) => id.base.loc.as_ref(), + PatternLike::ObjectPattern(p) => p.base.loc.as_ref(), + PatternLike::ArrayPattern(p) => p.base.loc.as_ref(), + PatternLike::AssignmentPattern(p) => p.base.loc.as_ref(), + PatternLike::RestElement(p) => p.base.loc.as_ref(), + _ => None, + } +} + +/// Extract the loc from an Expression's base node. +fn get_expression_loc(expr: &Expression) -> Option<&AstSourceLocation> { + match expr { + Expression::Identifier(e) => e.base.loc.as_ref(), + Expression::StringLiteral(e) => e.base.loc.as_ref(), + Expression::NumericLiteral(e) => e.base.loc.as_ref(), + Expression::BooleanLiteral(e) => e.base.loc.as_ref(), + Expression::NullLiteral(e) => e.base.loc.as_ref(), + Expression::CallExpression(e) => e.base.loc.as_ref(), + Expression::MemberExpression(e) => e.base.loc.as_ref(), + Expression::OptionalMemberExpression(e) => e.base.loc.as_ref(), + Expression::ArrayExpression(e) => e.base.loc.as_ref(), + Expression::ObjectExpression(e) => e.base.loc.as_ref(), + Expression::ArrowFunctionExpression(e) => e.base.loc.as_ref(), + Expression::FunctionExpression(e) => e.base.loc.as_ref(), + Expression::BinaryExpression(e) => e.base.loc.as_ref(), + Expression::UnaryExpression(e) => e.base.loc.as_ref(), + Expression::UpdateExpression(e) => e.base.loc.as_ref(), + Expression::LogicalExpression(e) => e.base.loc.as_ref(), + Expression::ConditionalExpression(e) => e.base.loc.as_ref(), + Expression::SequenceExpression(e) => e.base.loc.as_ref(), + Expression::AssignmentExpression(e) => e.base.loc.as_ref(), + Expression::TemplateLiteral(e) => e.base.loc.as_ref(), + Expression::TaggedTemplateExpression(e) => e.base.loc.as_ref(), + Expression::SpreadElement(e) => e.base.loc.as_ref(), + Expression::RegExpLiteral(e) => e.base.loc.as_ref(), + Expression::JSXElement(e) => e.base.loc.as_ref(), + Expression::JSXFragment(e) => e.base.loc.as_ref(), + 
Expression::NewExpression(e) => e.base.loc.as_ref(), + Expression::OptionalCallExpression(e) => e.base.loc.as_ref(), + _ => None, + } +} + +/// Apply a source location to an ExpressionOrJsxText value, matching the TS +/// behavior where `value.loc = instrValue.loc` is set at the end of +/// codegenInstructionValue. +fn apply_loc_to_value(value: &mut ExpressionOrJsxText, loc: DiagSourceLocation) { + let ast_loc = AstSourceLocation { + start: AstPosition { + line: loc.start.line, + column: loc.start.column, + index: None, + }, + end: AstPosition { + line: loc.end.line, + column: loc.end.column, + index: None, + }, + filename: None, + identifier_name: None, + }; + match value { + ExpressionOrJsxText::Expression(expr) => { + apply_loc_to_expression(expr, ast_loc); + } + ExpressionOrJsxText::JsxText(text) => { + text.base.loc = Some(ast_loc); + } + } +} + +/// Apply a source location to an Expression's base node. +fn apply_loc_to_expression(expr: &mut Expression, loc: AstSourceLocation) { + let base = match expr { + Expression::Identifier(e) => &mut e.base, + Expression::StringLiteral(e) => &mut e.base, + Expression::NumericLiteral(e) => &mut e.base, + Expression::BooleanLiteral(e) => &mut e.base, + Expression::NullLiteral(e) => &mut e.base, + Expression::CallExpression(e) => &mut e.base, + Expression::MemberExpression(e) => &mut e.base, + Expression::OptionalMemberExpression(e) => &mut e.base, + Expression::ArrayExpression(e) => &mut e.base, + Expression::ObjectExpression(e) => &mut e.base, + Expression::ArrowFunctionExpression(e) => &mut e.base, + Expression::FunctionExpression(e) => &mut e.base, + Expression::BinaryExpression(e) => &mut e.base, + Expression::UnaryExpression(e) => &mut e.base, + Expression::UpdateExpression(e) => &mut e.base, + Expression::LogicalExpression(e) => &mut e.base, + Expression::ConditionalExpression(e) => &mut e.base, + Expression::SequenceExpression(e) => &mut e.base, + Expression::AssignmentExpression(e) => &mut e.base, + 
Expression::TemplateLiteral(e) => &mut e.base, + Expression::TaggedTemplateExpression(e) => &mut e.base, + Expression::SpreadElement(e) => &mut e.base, + Expression::RegExpLiteral(e) => &mut e.base, + Expression::JSXElement(e) => &mut e.base, + Expression::JSXFragment(e) => &mut e.base, + Expression::NewExpression(e) => &mut e.base, + Expression::OptionalCallExpression(e) => &mut e.base, + _ => return, + }; + base.loc = Some(loc); +} + +fn codegen_label(id: BlockId) -> String { + format!("bb{}", id.0) +} + +fn symbol_for(name: &str) -> Expression { + Expression::CallExpression(ast_expr::CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::MemberExpression(ast_expr::MemberExpression { + base: BaseNode::typed("MemberExpression"), + object: Box::new(Expression::Identifier(make_identifier("Symbol"))), + property: Box::new(Expression::Identifier(make_identifier("for"))), + computed: false, + })), + arguments: vec![Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: name.to_string(), + })], + type_parameters: None, + type_arguments: None, + optional: None, + }) +} + +fn codegen_primitive_value(value: &PrimitiveValue, loc: Option) -> Expression { + match value { + PrimitiveValue::Number(n) => { + let f = n.value(); + if f < 0.0 { + Expression::UnaryExpression(ast_expr::UnaryExpression { + base: base_node_with_loc("UnaryExpression", loc), + operator: AstUnaryOperator::Neg, + prefix: true, + argument: Box::new(Expression::NumericLiteral(NumericLiteral { + base: base_node_with_loc("NumericLiteral", loc), + value: -f, + })), + }) + } else { + Expression::NumericLiteral(NumericLiteral { + base: base_node_with_loc("NumericLiteral", loc), + value: f, + }) + } + } + PrimitiveValue::Boolean(b) => Expression::BooleanLiteral(BooleanLiteral { + base: base_node_with_loc("BooleanLiteral", loc), + value: *b, + }), + PrimitiveValue::String(s) => Expression::StringLiteral(StringLiteral { + base: 
base_node_with_loc("StringLiteral", loc), + value: s.clone(), + }), + PrimitiveValue::Null => Expression::NullLiteral(NullLiteral { + base: base_node_with_loc("NullLiteral", loc), + }), + PrimitiveValue::Undefined => Expression::Identifier(make_identifier("undefined")), + } +} + +fn property_literal_to_expression(prop: &PropertyLiteral) -> (Expression, bool) { + match prop { + PropertyLiteral::String(s) => (Expression::Identifier(make_identifier(s)), false), + PropertyLiteral::Number(n) => ( + Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: n.value(), + }), + true, + ), + } +} + +fn convert_value_to_expression(value: ExpressionOrJsxText) -> Expression { + match value { + ExpressionOrJsxText::Expression(e) => e, + ExpressionOrJsxText::JsxText(text) => Expression::StringLiteral(StringLiteral { + base: BaseNode::typed("StringLiteral"), + value: text.value, + }), + } +} + +fn get_instruction_value( + reactive_value: &ReactiveValue, +) -> Result<&InstructionValue, CompilerError> { + match reactive_value { + ReactiveValue::Instruction(iv) => Ok(iv), + _ => Err(invariant_err("Expected base instruction value", None)), + } +} + +fn invariant( + condition: bool, + reason: &str, + loc: Option, +) -> Result<(), CompilerError> { + if !condition { + Err(invariant_err(reason, loc)) + } else { + Ok(()) + } +} + +fn invariant_err(reason: &str, loc: Option) -> CompilerError { + // Use CompilerDiagnostic (with details array) to match TS + // CompilerError.invariant() + let mut err = CompilerError::new(); + err.push_diagnostic( + CompilerDiagnostic::new(ErrorCategory::Invariant, reason, None::).with_detail( + CompilerDiagnosticDetail::Error { + loc, + message: Some(reason.to_string()), + identifier_name: None, + }, + ), + ); + err +} + +fn invariant_err_with_detail_message( + reason: &str, + message: &str, + loc: Option, +) -> CompilerError { + let mut err = CompilerError::new(); + let diagnostic = 
react_compiler_diagnostics::CompilerDiagnostic::new( + ErrorCategory::Invariant, + reason, + None::, + ) + .with_detail( + react_compiler_diagnostics::CompilerDiagnosticDetail::Error { + loc, + message: Some(message.to_string()), + identifier_name: None, + }, + ); + err.push_diagnostic(diagnostic); + err +} + +fn get_statement_type_name(stmt: &Statement) -> &'static str { + match stmt { + Statement::ExpressionStatement(_) => "ExpressionStatement", + Statement::BlockStatement(_) => "BlockStatement", + Statement::VariableDeclaration(_) => "VariableDeclaration", + Statement::ReturnStatement(_) => "ReturnStatement", + Statement::IfStatement(_) => "IfStatement", + Statement::SwitchStatement(_) => "SwitchStatement", + Statement::ForStatement(_) => "ForStatement", + Statement::ForInStatement(_) => "ForInStatement", + Statement::ForOfStatement(_) => "ForOfStatement", + Statement::WhileStatement(_) => "WhileStatement", + Statement::DoWhileStatement(_) => "DoWhileStatement", + Statement::LabeledStatement(_) => "LabeledStatement", + Statement::ThrowStatement(_) => "ThrowStatement", + Statement::TryStatement(_) => "TryStatement", + Statement::BreakStatement(_) => "BreakStatement", + Statement::ContinueStatement(_) => "ContinueStatement", + Statement::FunctionDeclaration(_) => "FunctionDeclaration", + Statement::DebuggerStatement(_) => "DebuggerStatement", + Statement::EmptyStatement(_) => "EmptyStatement", + _ => "Statement", + } +} + +fn get_statement_loc(stmt: &Statement) -> Option { + let base = match stmt { + Statement::ExpressionStatement(s) => &s.base, + Statement::BlockStatement(s) => &s.base, + Statement::VariableDeclaration(s) => &s.base, + Statement::ReturnStatement(s) => &s.base, + Statement::IfStatement(s) => &s.base, + Statement::ForStatement(s) => &s.base, + Statement::ForInStatement(s) => &s.base, + Statement::ForOfStatement(s) => &s.base, + Statement::WhileStatement(s) => &s.base, + Statement::DoWhileStatement(s) => &s.base, + Statement::LabeledStatement(s) => 
&s.base, + Statement::ThrowStatement(s) => &s.base, + Statement::TryStatement(s) => &s.base, + Statement::SwitchStatement(s) => &s.base, + Statement::BreakStatement(s) => &s.base, + Statement::ContinueStatement(s) => &s.base, + Statement::FunctionDeclaration(s) => &s.base, + Statement::DebuggerStatement(s) => &s.base, + Statement::EmptyStatement(s) => &s.base, + _ => return None, + }; + base.loc.as_ref().map(|loc| DiagSourceLocation { + start: react_compiler_diagnostics::Position { + line: loc.start.line, + column: loc.start.column, + index: loc.start.index, + }, + end: react_compiler_diagnostics::Position { + line: loc.end.line, + column: loc.end.column, + index: loc.end.index, + }, + }) +} + +fn compare_scope_dependency( + a: &react_compiler_hir::ReactiveScopeDependency, + b: &react_compiler_hir::ReactiveScopeDependency, + env: &Environment, +) -> std::cmp::Ordering { + let a_name = dep_to_sort_key(a, env); + let b_name = dep_to_sort_key(b, env); + a_name.cmp(&b_name) +} + +fn dep_to_sort_key(dep: &react_compiler_hir::ReactiveScopeDependency, env: &Environment) -> String { + let ident = &env.identifiers[dep.identifier.0 as usize]; + let base = match &ident.name { + Some(react_compiler_hir::IdentifierName::Named(n)) => n.clone(), + Some(react_compiler_hir::IdentifierName::Promoted(n)) => n.clone(), + None => format!("_t{}", dep.identifier.0), + }; + let mut parts = vec![base]; + for entry in &dep.path { + let prefix = if entry.optional { "?" 
} else { "" }; + let prop = match &entry.property { + PropertyLiteral::String(s) => s.clone(), + PropertyLiteral::Number(n) => n.value().to_string(), + }; + parts.push(format!("{prefix}{prop}")); + } + parts.join(".") +} + +fn compare_scope_declaration( + a: &react_compiler_hir::ReactiveScopeDeclaration, + b: &react_compiler_hir::ReactiveScopeDeclaration, + env: &Environment, +) -> std::cmp::Ordering { + let a_name = ident_sort_key(a.identifier, env); + let b_name = ident_sort_key(b.identifier, env); + a_name.cmp(&b_name) +} + +fn ident_sort_key(id: IdentifierId, env: &Environment) -> String { + let ident = &env.identifiers[id.0 as usize]; + match &ident.name { + Some(react_compiler_hir::IdentifierName::Named(n)) => n.clone(), + Some(react_compiler_hir::IdentifierName::Promoted(n)) => n.clone(), + None => format!("_t{}", id.0), + } +} + +fn jsx_tag_loc(tag: &JsxTag) -> Option { + match tag { + JsxTag::Place(p) => p.loc, + JsxTag::Builtin(_) => None, + } +} + +/// Conditionally wrap a call expression in a hook guard IIFE if +/// enableEmitHookGuards is enabled and the callee is a hook. +fn maybe_wrap_hook_call( + cx: &Context<'_>, + call_expr: Expression, + callee_id: IdentifierId, +) -> Expression { + if let Some(ref guard_name) = cx.env.hook_guard_name { + if cx.env.output_mode == react_compiler_hir::environment::OutputMode::Client + && is_hook_identifier(cx, callee_id) + { + return wrap_hook_call_with_guard(guard_name, call_expr, 2, 3); + } + } + call_expr +} + +/// Check if a callee identifier refers to a hook function. +fn is_hook_identifier(cx: &Context<'_>, identifier_id: IdentifierId) -> bool { + let identifier = &cx.env.identifiers[identifier_id.0 as usize]; + let type_ = &cx.env.types[identifier.type_.0 as usize]; + cx.env + .get_hook_kind_for_type(type_) + .ok() + .flatten() + .is_some() +} + +/// Create the hook guard IIFE wrapper for a hook call expression. 
+/// Wraps the call in: `(function() { try { $guard(before); return callExpr; } +/// finally { $guard(after); } })()` +fn wrap_hook_call_with_guard( + guard_name: &str, + call_expr: Expression, + before: u32, + after: u32, +) -> Expression { + let guard_call = |kind: u32| -> Statement { + Statement::ExpressionStatement(ExpressionStatement { + base: BaseNode::typed("ExpressionStatement"), + expression: Box::new(Expression::CallExpression(ast_expr::CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::Identifier(make_identifier(guard_name))), + arguments: vec![Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: kind as f64, + })], + type_parameters: None, + type_arguments: None, + optional: None, + })), + }) + }; + + let try_stmt = Statement::TryStatement(TryStatement { + base: BaseNode::typed("TryStatement"), + block: BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: vec![ + guard_call(before), + Statement::ReturnStatement(ReturnStatement { + base: BaseNode::typed("ReturnStatement"), + argument: Some(Box::new(call_expr)), + }), + ], + directives: Vec::new(), + }, + handler: None, + finalizer: Some(BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: vec![guard_call(after)], + directives: Vec::new(), + }), + }); + + let iife = Expression::FunctionExpression(ast_expr::FunctionExpression { + base: BaseNode::typed("FunctionExpression"), + id: None, + params: Vec::new(), + body: BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: vec![try_stmt], + directives: Vec::new(), + }, + generator: false, + is_async: false, + return_type: None, + type_parameters: None, + }); + + Expression::CallExpression(ast_expr::CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(iife), + arguments: vec![], + type_parameters: None, + type_arguments: None, + optional: None, + }) +} + +/// Create a try/finally wrapping for the entire function 
body. +/// `try { $guard(before); ...body...; } finally { $guard(after); }` +fn create_function_body_hook_guard( + guard_name: &str, + body_stmts: Vec, + before: u32, + after: u32, +) -> Statement { + let guard_call = |kind: u32| -> Statement { + Statement::ExpressionStatement(ExpressionStatement { + base: BaseNode::typed("ExpressionStatement"), + expression: Box::new(Expression::CallExpression(ast_expr::CallExpression { + base: BaseNode::typed("CallExpression"), + callee: Box::new(Expression::Identifier(make_identifier(guard_name))), + arguments: vec![Expression::NumericLiteral(NumericLiteral { + base: BaseNode::typed("NumericLiteral"), + value: kind as f64, + })], + type_parameters: None, + type_arguments: None, + optional: None, + })), + }) + }; + + let mut try_body = vec![guard_call(before)]; + try_body.extend(body_stmts); + + Statement::TryStatement(TryStatement { + base: BaseNode::typed("TryStatement"), + block: BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: try_body, + directives: Vec::new(), + }, + handler: None, + finalizer: Some(BlockStatement { + base: BaseNode::typed("BlockStatement"), + body: vec![guard_call(after)], + directives: Vec::new(), + }), + }) +} diff --git a/crates/react_compiler_reactive_scopes/src/extract_scope_declarations_from_destructuring.rs b/crates/react_compiler_reactive_scopes/src/extract_scope_declarations_from_destructuring.rs new file mode 100644 index 000000000000..c99188de0fab --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/extract_scope_declarations_from_destructuring.rs @@ -0,0 +1,224 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! ExtractScopeDeclarationsFromDestructuring — handles destructuring patterns +//! where some bindings are scope declarations and others aren't. +//! +//! Corresponds to +//! 
`src/ReactiveScopes/ExtractScopeDeclarationsFromDestructuring.ts`. + +use std::collections::HashSet; + +use react_compiler_hir::{ + environment::Environment, visitors, DeclarationId, IdentifierId, IdentifierName, + InstructionKind, InstructionValue, LValue, ParamPattern, Place, ReactiveFunction, + ReactiveInstruction, ReactiveScopeBlock, ReactiveStatement, ReactiveValue, +}; + +use crate::visitors::{transform_reactive_function, ReactiveFunctionTransform, Transformed}; + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Extracts scope declarations from destructuring patterns where some bindings +/// are scope declarations and others aren't. +/// TS: `extractScopeDeclarationsFromDestructuring` +pub fn extract_scope_declarations_from_destructuring( + func: &mut ReactiveFunction, + env: &mut Environment, +) -> Result<(), react_compiler_diagnostics::CompilerError> { + let mut declared: HashSet = HashSet::new(); + for param in &func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + let identifier = &env.identifiers[place.identifier.0 as usize]; + declared.insert(identifier.declaration_id); + } + let mut transform = Transform { env }; + let mut state = ExtractState { declared }; + transform_reactive_function(func, &mut transform, &mut state) +} + +struct ExtractState { + declared: HashSet, +} + +struct Transform<'a> { + env: &'a mut Environment, +} + +impl<'a> ReactiveFunctionTransform for Transform<'a> { + type State = ExtractState; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_scope( + &mut self, + scope: &mut ReactiveScopeBlock, + state: &mut ExtractState, + ) -> Result<(), react_compiler_diagnostics::CompilerError> { + let scope_data = &self.env.scopes[scope.scope.0 as usize]; + let decl_ids: Vec = scope_data + .declarations + .iter() + .map(|(_, 
d)| { + let identifier = &self.env.identifiers[d.identifier.0 as usize]; + identifier.declaration_id + }) + .collect(); + for decl_id in decl_ids { + state.declared.insert(decl_id); + } + self.traverse_scope(scope, state) + } + + fn transform_instruction( + &mut self, + instruction: &mut ReactiveInstruction, + state: &mut ExtractState, + ) -> Result, react_compiler_diagnostics::CompilerError> { + self.visit_instruction(instruction, state)?; + + let mut extra_instructions: Option> = None; + + if let ReactiveValue::Instruction(InstructionValue::Destructure { + lvalue, + value: _destr_value, + loc, + }) = &mut instruction.value + { + // Check if this is a mixed destructuring (some declared, some not) + let mut reassigned: HashSet = HashSet::new(); + let mut has_declaration = false; + + for place in visitors::each_pattern_operand(&lvalue.pattern) { + let identifier = &self.env.identifiers[place.identifier.0 as usize]; + if state.declared.contains(&identifier.declaration_id) { + reassigned.insert(place.identifier); + } else { + has_declaration = true; + } + } + + if !has_declaration { + // All reassignments + lvalue.kind = InstructionKind::Reassign; + } else if !reassigned.is_empty() { + // Mixed: replace reassigned items with temporaries and emit separate + // assignments + let mut renamed: Vec<(Place, Place)> = Vec::new(); + let instr_loc = instruction.loc.clone(); + let destr_loc = loc.clone(); + + let env = &mut *self.env; // reborrow + visitors::map_pattern_operands(&mut lvalue.pattern, &mut |place: Place| { + if !reassigned.contains(&place.identifier) { + return place; + } + // Create a temporary place (matches TS clonePlaceToTemporary) + let temp_id = env.next_identifier_id(); + let decl_id = env.identifiers[temp_id.0 as usize].declaration_id; + // Copy type from original identifier to temporary + let original_type = env.identifiers[place.identifier.0 as usize].type_; + env.identifiers[temp_id.0 as usize].type_ = original_type; + // Set identifier loc to the 
place's source location + // (matches TS makeTemporaryIdentifier which receives place.loc) + env.identifiers[temp_id.0 as usize].loc = place.loc.clone(); + // Promote the temporary + env.identifiers[temp_id.0 as usize].name = + Some(IdentifierName::Promoted(format!("#t{}", decl_id.0))); + let temporary = Place { + identifier: temp_id, + effect: place.effect, + reactive: place.reactive, + loc: None, // GeneratedSource — matches TS createTemporaryPlace + }; + let original = place; + renamed.push((original.clone(), temporary.clone())); + temporary + }); + + // Build extra StoreLocal instructions for each renamed place + let mut extra = Vec::new(); + for (original, temporary) in renamed { + extra.push(ReactiveInstruction { + id: instruction.id, + lvalue: None, + value: ReactiveValue::Instruction(InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Reassign, + place: original, + }, + value: temporary, + type_annotation: None, + loc: destr_loc.clone(), + }), + effects: None, + loc: instr_loc.clone(), + }); + } + extra_instructions = Some(extra); + } + } + + // Update state.declared with declarations from the instruction(s) + if let Some(ref extras) = extra_instructions { + // Process the original instruction + update_declared_from_instruction(instruction, &self.env, state); + // Process extra instructions + for extra_instr in extras { + update_declared_from_instruction(extra_instr, &self.env, state); + } + } else { + update_declared_from_instruction(instruction, &self.env, state); + } + + if let Some(extras) = extra_instructions { + // Clone the original instruction and build the replacement list + let mut all_instructions = Vec::new(); + all_instructions.push(ReactiveStatement::Instruction(instruction.clone())); + for extra in extras { + all_instructions.push(ReactiveStatement::Instruction(extra)); + } + Ok(Transformed::ReplaceMany(all_instructions)) + } else { + Ok(Transformed::Keep) + } + } +} + +fn update_declared_from_instruction( + instr: 
&ReactiveInstruction, + env: &Environment, + state: &mut ExtractState, +) { + if let ReactiveValue::Instruction(iv) = &instr.value { + match iv { + InstructionValue::DeclareContext { lvalue, .. } + | InstructionValue::StoreContext { lvalue, .. } + | InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::StoreLocal { lvalue, .. } => { + if lvalue.kind != InstructionKind::Reassign { + let identifier = &env.identifiers[lvalue.place.identifier.0 as usize]; + state.declared.insert(identifier.declaration_id); + } + } + InstructionValue::Destructure { lvalue, .. } => { + if lvalue.kind != InstructionKind::Reassign { + for place in visitors::each_pattern_operand(&lvalue.pattern) { + let identifier = &env.identifiers[place.identifier.0 as usize]; + state.declared.insert(identifier.declaration_id); + } + } + } + _ => {} + } + } +} diff --git a/crates/react_compiler_reactive_scopes/src/lib.rs b/crates/react_compiler_reactive_scopes/src/lib.rs new file mode 100644 index 000000000000..f1a097f7766a --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/lib.rs @@ -0,0 +1,53 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. +#![allow(clippy::all)] +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Reactive scope passes for the React Compiler. +//! +//! Converts the HIR CFG into a tree-structured `ReactiveFunction` and runs +//! scope-related transformation passes (pruning, merging, renaming, etc.). +//! +//! Corresponds to `src/ReactiveScopes/` in the TypeScript compiler. 
+ +mod assert_scope_instructions_within_scopes; +mod assert_well_formed_break_targets; +mod build_reactive_function; +pub mod codegen_reactive_function; +mod extract_scope_declarations_from_destructuring; +mod merge_reactive_scopes_that_invalidate_together; +pub mod print_reactive_function; +mod promote_used_temporaries; +mod propagate_early_returns; +mod prune_always_invalidating_scopes; +mod prune_hoisted_contexts; +mod prune_non_escaping_scopes; +mod prune_non_reactive_dependencies; +mod prune_unused_labels; +mod prune_unused_lvalues; +mod prune_unused_scopes; +mod rename_variables; +mod stabilize_block_ids; +pub mod visitors; + +pub use assert_scope_instructions_within_scopes::assert_scope_instructions_within_scopes; +pub use assert_well_formed_break_targets::assert_well_formed_break_targets; +pub use build_reactive_function::build_reactive_function; +pub use codegen_reactive_function::codegen_function; +pub use extract_scope_declarations_from_destructuring::extract_scope_declarations_from_destructuring; +pub use merge_reactive_scopes_that_invalidate_together::merge_reactive_scopes_that_invalidate_together; +pub use print_reactive_function::debug_reactive_function; +pub use promote_used_temporaries::promote_used_temporaries; +pub use propagate_early_returns::propagate_early_returns; +pub use prune_always_invalidating_scopes::prune_always_invalidating_scopes; +pub use prune_hoisted_contexts::prune_hoisted_contexts; +pub use prune_non_escaping_scopes::prune_non_escaping_scopes; +pub use prune_non_reactive_dependencies::prune_non_reactive_dependencies; +pub use prune_unused_labels::prune_unused_labels; +pub use prune_unused_lvalues::prune_unused_lvalues; +pub use prune_unused_scopes::prune_unused_scopes; +pub use rename_variables::rename_variables; +pub use stabilize_block_ids::stabilize_block_ids; diff --git a/crates/react_compiler_reactive_scopes/src/merge_reactive_scopes_that_invalidate_together.rs 
b/crates/react_compiler_reactive_scopes/src/merge_reactive_scopes_that_invalidate_together.rs new file mode 100644 index 000000000000..4dfdc4794ce3 --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/merge_reactive_scopes_that_invalidate_together.rs @@ -0,0 +1,563 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! MergeReactiveScopesThatInvalidateTogether — merges adjacent or nested scopes +//! that share dependencies (and thus invalidate together) to reduce memoization +//! overhead. +//! +//! Corresponds to +//! `src/ReactiveScopes/MergeReactiveScopesThatInvalidateTogether.ts`. + +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::CompilerError; +use react_compiler_hir::{ + environment::Environment, + object_shape::{BUILT_IN_ARRAY_ID, BUILT_IN_FUNCTION_ID, BUILT_IN_JSX_ID, BUILT_IN_OBJECT_ID}, + DeclarationId, DependencyPathEntry, EvaluationOrder, InstructionKind, InstructionValue, Place, + ReactiveBlock, ReactiveFunction, ReactiveScopeBlock, ReactiveScopeDependency, + ReactiveStatement, ReactiveValue, ScopeId, Type, +}; + +use crate::visitors::{ + transform_reactive_function, visit_reactive_function, ReactiveFunctionTransform, + ReactiveFunctionVisitor, Transformed, +}; + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Merges adjacent reactive scopes that share dependencies (invalidate +/// together). 
TS: `mergeReactiveScopesThatInvalidateTogether` +pub fn merge_reactive_scopes_that_invalidate_together( + func: &mut ReactiveFunction, + env: &mut Environment, +) -> Result<(), CompilerError> { + // Pass 1: find last usage of each declaration + let visitor = FindLastUsageVisitor { env: &*env }; + let mut last_usage: HashMap = HashMap::new(); + visit_reactive_function(func, &visitor, &mut last_usage); + + // Pass 2+3: merge scopes + let mut transform = MergeTransform { + env, + last_usage, + temporaries: HashMap::new(), + }; + let mut state: Option> = None; + transform_reactive_function(func, &mut transform, &mut state) +} + +// ============================================================================= +// Pass 1: FindLastUsageVisitor +// ============================================================================= + +/// TS: `class FindLastUsageVisitor extends ReactiveFunctionVisitor` +struct FindLastUsageVisitor<'a> { + env: &'a Environment, +} + +impl<'a> ReactiveFunctionVisitor for FindLastUsageVisitor<'a> { + type State = HashMap; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_place(&self, id: EvaluationOrder, place: &Place, state: &mut Self::State) { + let decl_id = self.env.identifiers[place.identifier.0 as usize].declaration_id; + let entry = state.entry(decl_id).or_insert(id); + if id > *entry { + *entry = id; + } + } +} + +// ============================================================================= +// Pass 2+3: MergeTransform +// ============================================================================= + +/// TS: `class Transform extends +/// ReactiveFunctionTransform` +struct MergeTransform<'a> { + env: &'a mut Environment, + last_usage: HashMap, + temporaries: HashMap, +} + +impl<'a> ReactiveFunctionTransform for MergeTransform<'a> { + type State = Option>; + + fn env(&self) -> &Environment { + self.env + } + + /// TS: `override transformScope(scopeBlock, state)` + fn transform_scope( + &mut self, + scope: &mut 
ReactiveScopeBlock, + state: &mut Self::State, + ) -> Result, CompilerError> { + let scope_deps = self.env.scopes[scope.scope.0 as usize].dependencies.clone(); + // Save parent state and recurse with this scope's deps as state + let parent_state = state.take(); + *state = Some(scope_deps.clone()); + self.visit_scope(scope, state)?; + // Restore parent state + *state = parent_state; + + // If parent has deps and they match, flatten the inner scope + if let Some(parent_deps) = state.as_ref() { + if are_equal_dependencies(parent_deps, &scope_deps, self.env) { + let instructions = std::mem::take(&mut scope.instructions); + return Ok(Transformed::ReplaceMany(instructions)); + } + } + Ok(Transformed::Keep) + } + + /// TS: `override visitBlock(block, state)` + fn visit_block( + &mut self, + block: &mut ReactiveBlock, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + // Pass 1: traverse nested (scope flattening handled by transform_scope) + self.traverse_block(block, state)?; + // Pass 2+3: merge consecutive scopes in this block + self.merge_scopes_in_block(block)?; + Ok(()) + } +} + +impl<'a> MergeTransform<'a> { + /// Identify and merge consecutive scopes that invalidate together. 
+ fn merge_scopes_in_block(&mut self, block: &mut ReactiveBlock) -> Result<(), CompilerError> { + // Pass 2: identify scopes for merging + struct MergedScope { + scope_id: ScopeId, + from: usize, + to: usize, + lvalues: HashSet, + } + + let mut current: Option = None; + let mut merged: Vec = Vec::new(); + + let block_len = block.len(); + for i in 0..block_len { + match &block[i] { + ReactiveStatement::Terminal(_) => { + // Don't merge across terminals + if let Some(c) = current.take() { + if c.to > c.from + 1 { + merged.push(c); + } + } + } + ReactiveStatement::PrunedScope(_) => { + // Don't merge across pruned scopes + if let Some(c) = current.take() { + if c.to > c.from + 1 { + merged.push(c); + } + } + } + ReactiveStatement::Instruction(instr) => { + match &instr.value { + ReactiveValue::Instruction(iv) => { + match iv { + InstructionValue::BinaryExpression { .. } + | InstructionValue::ComputedLoad { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::LoadLocal { .. } + | InstructionValue::Primitive { .. } + | InstructionValue::PropertyLoad { .. } + | InstructionValue::TemplateLiteral { .. } + | InstructionValue::UnaryExpression { .. } => { + if let Some(ref mut c) = current { + if let Some(lvalue) = &instr.lvalue { + let decl_id = self.env.identifiers + [lvalue.identifier.0 as usize] + .declaration_id; + c.lvalues.insert(decl_id); + if let InstructionValue::LoadLocal { place, .. } = iv { + let src_decl = self.env.identifiers + [place.identifier.0 as usize] + .declaration_id; + self.temporaries.insert(decl_id, src_decl); + } + } + } + } + InstructionValue::StoreLocal { lvalue, value, .. 
} => { + if let Some(ref mut c) = current { + if lvalue.kind == InstructionKind::Const { + // Add the instruction lvalue (if any) + if let Some(instr_lvalue) = &instr.lvalue { + let decl_id = self.env.identifiers + [instr_lvalue.identifier.0 as usize] + .declaration_id; + c.lvalues.insert(decl_id); + } + // Add the StoreLocal's lvalue place + let store_decl = self.env.identifiers + [lvalue.place.identifier.0 as usize] + .declaration_id; + c.lvalues.insert(store_decl); + // Track temporary mapping + let value_decl = self.env.identifiers + [value.identifier.0 as usize] + .declaration_id; + let mapped = self + .temporaries + .get(&value_decl) + .copied() + .unwrap_or(value_decl); + self.temporaries.insert(store_decl, mapped); + } else { + // Non-const StoreLocal — reset + let c = current.take().unwrap(); + if c.to > c.from + 1 { + merged.push(c); + } + } + } + } + _ => { + // Other instructions prevent merging + if let Some(c) = current.take() { + if c.to > c.from + 1 { + merged.push(c); + } + } + } + } + } + _ => { + // Non-Instruction reactive values prevent merging + if let Some(c) = current.take() { + if c.to > c.from + 1 { + merged.push(c); + } + } + } + } + } + ReactiveStatement::Scope(scope_block) => { + let next_scope_id = scope_block.scope; + if let Some(ref mut c) = current { + let current_scope_id = c.scope_id; + if can_merge_scopes( + current_scope_id, + next_scope_id, + self.env, + &self.temporaries, + ) && are_lvalues_last_used_by_scope( + next_scope_id, + &c.lvalues, + &self.last_usage, + self.env, + ) { + // Merge: extend the current scope's range + let next_range_end = + self.env.scopes[next_scope_id.0 as usize].range.end; + let current_range_end = + self.env.scopes[current_scope_id.0 as usize].range.end; + self.env.scopes[current_scope_id.0 as usize].range.end = + EvaluationOrder(current_range_end.0.max(next_range_end.0)); + + // Merge declarations from next into current + let next_decls = self.env.scopes[next_scope_id.0 as usize] + .declarations + 
.clone(); + for (key, value) in next_decls { + let current_decls = + &mut self.env.scopes[current_scope_id.0 as usize].declarations; + if let Some(existing) = + current_decls.iter_mut().find(|(k, _)| *k == key) + { + existing.1 = value; + } else { + current_decls.push((key, value)); + } + } + + // Prune declarations that are no longer used after the merged scope + update_scope_declarations(current_scope_id, &self.last_usage, self.env); + + c.to = i + 1; + c.lvalues.clear(); + + if !scope_is_eligible_for_merging(next_scope_id, self.env) { + let c = current.take().unwrap(); + if c.to > c.from + 1 { + merged.push(c); + } + } + } else { + // Cannot merge — reset + let c = current.take().unwrap(); + if c.to > c.from + 1 { + merged.push(c); + } + // Start new candidate if eligible + if scope_is_eligible_for_merging(next_scope_id, self.env) { + current = Some(MergedScope { + scope_id: next_scope_id, + from: i, + to: i + 1, + lvalues: HashSet::new(), + }); + } + } + } else { + // No current — start new candidate if eligible + if scope_is_eligible_for_merging(next_scope_id, self.env) { + current = Some(MergedScope { + scope_id: next_scope_id, + from: i, + to: i + 1, + lvalues: HashSet::new(), + }); + } + } + } + } + } + // Flush remaining + if let Some(c) = current.take() { + if c.to > c.from + 1 { + merged.push(c); + } + } + + // Pass 3: apply merges + if merged.is_empty() { + return Ok(()); + } + + let mut next_instructions: Vec = Vec::new(); + let mut index = 0; + let all_stmts: Vec = std::mem::take(block); + + for entry in &merged { + // Push everything before the merge range + while index < entry.from { + next_instructions.push(all_stmts[index].clone()); + index += 1; + } + // The first item in the merge range must be a scope + let mut merged_scope = match &all_stmts[entry.from] { + ReactiveStatement::Scope(s) => s.clone(), + _ => { + return Err(react_compiler_diagnostics::CompilerDiagnostic::new( + react_compiler_diagnostics::ErrorCategory::Invariant, + 
"MergeConsecutiveScopes: Expected scope at starting index", + None, + ) + .into()); + } + }; + index += 1; + while index < entry.to { + let stmt = &all_stmts[index]; + index += 1; + match stmt { + ReactiveStatement::Scope(inner_scope) => { + merged_scope + .instructions + .extend(inner_scope.instructions.clone()); + self.env.scopes[merged_scope.scope.0 as usize] + .merged + .push(inner_scope.scope); + } + _ => { + merged_scope.instructions.push(stmt.clone()); + } + } + } + next_instructions.push(ReactiveStatement::Scope(merged_scope)); + } + // Push remaining + while index < all_stmts.len() { + next_instructions.push(all_stmts[index].clone()); + index += 1; + } + + *block = next_instructions; + Ok(()) + } +} + +// ============================================================================= +// Helper functions +// ============================================================================= + +/// Updates scope declarations to remove any that are not used after the scope. +fn update_scope_declarations( + scope_id: ScopeId, + last_usage: &HashMap, + env: &mut Environment, +) { + let range_end = env.scopes[scope_id.0 as usize].range.end; + env.scopes[scope_id.0 as usize] + .declarations + .retain(|(_id, decl)| { + let decl_declaration_id = env.identifiers[decl.identifier.0 as usize].declaration_id; + match last_usage.get(&decl_declaration_id) { + Some(last_used_at) => *last_used_at >= range_end, + // If not tracked, keep the declaration (conservative) + None => true, + } + }); +} + +/// Returns whether all lvalues are last used at or before the given scope. +fn are_lvalues_last_used_by_scope( + scope_id: ScopeId, + lvalues: &HashSet, + last_usage: &HashMap, + env: &Environment, +) -> bool { + let range_end = env.scopes[scope_id.0 as usize].range.end; + for lvalue in lvalues { + if let Some(&last_used_at) = last_usage.get(lvalue) { + if last_used_at >= range_end { + return false; + } + } + } + true +} + +/// Check if two scopes can be merged. 
+fn can_merge_scopes( + current_id: ScopeId, + next_id: ScopeId, + env: &Environment, + temporaries: &HashMap, +) -> bool { + let current = &env.scopes[current_id.0 as usize]; + let next = &env.scopes[next_id.0 as usize]; + + // Don't merge scopes with reassignments + if !current.reassignments.is_empty() || !next.reassignments.is_empty() { + return false; + } + + // Merge scopes whose dependencies are identical + if are_equal_dependencies(¤t.dependencies, &next.dependencies, env) { + return true; + } + + // Merge scopes where outputs of current are inputs of next + // Build synthetic dependencies from current's declarations + let current_decl_deps: Vec = current + .declarations + .iter() + .map(|(_key, decl)| ReactiveScopeDependency { + identifier: decl.identifier, + reactive: true, + path: Vec::new(), + loc: None, + }) + .collect(); + + if are_equal_dependencies(¤t_decl_deps, &next.dependencies, env) { + return true; + } + + // Check if all next deps have empty paths, always-invalidating types, + // and correspond to current declarations (possibly through temporaries) + if !next.dependencies.is_empty() + && next.dependencies.iter().all(|dep| { + if !dep.path.is_empty() { + return false; + } + let dep_type = &env.types[env.identifiers[dep.identifier.0 as usize].type_.0 as usize]; + if !is_always_invalidating_type(dep_type) { + return false; + } + let dep_decl = env.identifiers[dep.identifier.0 as usize].declaration_id; + current.declarations.iter().any(|(_key, decl)| { + let decl_decl_id = env.identifiers[decl.identifier.0 as usize].declaration_id; + decl_decl_id == dep_decl + || temporaries.get(&dep_decl).copied() == Some(decl_decl_id) + }) + }) + { + return true; + } + + false +} + +/// Check if a type is always invalidating (guaranteed to change when inputs +/// change). 
+pub fn is_always_invalidating_type(ty: &Type) -> bool { + match ty { + Type::Object { shape_id } => { + if let Some(id) = shape_id { + matches!( + id.as_str(), + s if s == BUILT_IN_ARRAY_ID + || s == BUILT_IN_OBJECT_ID + || s == BUILT_IN_FUNCTION_ID + || s == BUILT_IN_JSX_ID + ) + } else { + false + } + } + Type::Function { .. } => true, + _ => false, + } +} + +/// Check if two dependency lists are equal. +fn are_equal_dependencies( + a: &[ReactiveScopeDependency], + b: &[ReactiveScopeDependency], + env: &Environment, +) -> bool { + if a.len() != b.len() { + return false; + } + for a_val in a { + let a_decl = env.identifiers[a_val.identifier.0 as usize].declaration_id; + let found = b.iter().any(|b_val| { + let b_decl = env.identifiers[b_val.identifier.0 as usize].declaration_id; + a_decl == b_decl && are_equal_paths(&a_val.path, &b_val.path) + }); + if !found { + return false; + } + } + true +} + +/// Check if two dependency paths are equal. +fn are_equal_paths(a: &[DependencyPathEntry], b: &[DependencyPathEntry]) -> bool { + a.len() == b.len() + && a.iter() + .zip(b.iter()) + .all(|(ai, bi)| ai.property == bi.property && ai.optional == bi.optional) +} + +/// Check if a scope is eligible for merging with subsequent scopes. 
+fn scope_is_eligible_for_merging(scope_id: ScopeId, env: &Environment) -> bool { + let scope = &env.scopes[scope_id.0 as usize]; + if scope.dependencies.is_empty() { + // No dependencies means output never changes — eligible + return true; + } + scope.declarations.iter().any(|(_key, decl)| { + let ty = &env.types[env.identifiers[decl.identifier.0 as usize].type_.0 as usize]; + is_always_invalidating_type(ty) + }) +} diff --git a/crates/react_compiler_reactive_scopes/src/print_reactive_function.rs b/crates/react_compiler_reactive_scopes/src/print_reactive_function.rs new file mode 100644 index 000000000000..21e4250b700e --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/print_reactive_function.rs @@ -0,0 +1,633 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Verbose debug printer for ReactiveFunction. +//! +//! Produces output identical to the TS `printDebugReactiveFunction`. +//! Delegates shared formatting (Places, Identifiers, Scopes, Types, +//! InstructionValues, Effects, Errors) to +//! `react_compiler_hir::print::PrintFormatter`. 
+ +use react_compiler_hir::{ + environment::Environment, + print::{self, PrintFormatter}, + HirFunction, ParamPattern, ReactiveBlock, ReactiveFunction, ReactiveInstruction, + ReactiveStatement, ReactiveTerminal, ReactiveTerminalStatement, ReactiveValue, +}; + +// ============================================================================= +// DebugPrinter — thin wrapper around PrintFormatter for reactive-specific logic +// ============================================================================= + +pub struct DebugPrinter<'a> { + pub fmt: PrintFormatter<'a>, + /// Optional formatter for HIR functions (used for inner functions in + /// FunctionExpression/ObjectMethod) + pub hir_formatter: Option<&'a HirFunctionFormatter>, +} + +impl<'a> DebugPrinter<'a> { + pub fn new(env: &'a Environment) -> Self { + Self { + fmt: PrintFormatter::new(env), + hir_formatter: None, + } + } + + // ========================================================================= + // ReactiveFunction + // ========================================================================= + + pub fn format_reactive_function(&mut self, func: &ReactiveFunction) { + self.fmt.indent(); + self.fmt.line(&format!( + "id: {}", + match &func.id { + Some(id) => format!("\"{}\"", id), + None => "null".to_string(), + } + )); + self.fmt.line(&format!( + "name_hint: {}", + match &func.name_hint { + Some(h) => format!("\"{}\"", h), + None => "null".to_string(), + } + )); + self.fmt.line(&format!("generator: {}", func.generator)); + self.fmt.line(&format!("is_async: {}", func.is_async)); + self.fmt + .line(&format!("loc: {}", print::format_loc(&func.loc))); + + // params + self.fmt.line("params:"); + self.fmt.indent(); + for (i, param) in func.params.iter().enumerate() { + match param { + ParamPattern::Place(place) => { + self.fmt.format_place_field(&format!("[{}]", i), place); + } + ParamPattern::Spread(spread) => { + self.fmt.line(&format!("[{}] Spread:", i)); + self.fmt.indent(); + 
self.fmt.format_place_field("place", &spread.place); + self.fmt.dedent(); + } + } + } + self.fmt.dedent(); + + // directives + self.fmt.line("directives:"); + self.fmt.indent(); + for (i, d) in func.directives.iter().enumerate() { + self.fmt.line(&format!("[{}] \"{}\"", i, d)); + } + self.fmt.dedent(); + + self.fmt.line(""); + self.fmt.line("Body:"); + self.fmt.indent(); + self.format_reactive_block(&func.body); + self.fmt.dedent(); + self.fmt.dedent(); + } + + // ========================================================================= + // ReactiveBlock + // ========================================================================= + + fn format_reactive_block(&mut self, block: &ReactiveBlock) { + for stmt in block.iter() { + self.format_reactive_statement(stmt); + } + } + + fn format_reactive_statement(&mut self, stmt: &ReactiveStatement) { + match stmt { + ReactiveStatement::Instruction(instr) => { + self.format_reactive_instruction_block(instr); + } + ReactiveStatement::Terminal(term) => { + self.fmt.line("ReactiveTerminalStatement {"); + self.fmt.indent(); + self.format_terminal_statement(term); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveStatement::Scope(scope) => { + self.fmt.line("ReactiveScopeBlock {"); + self.fmt.indent(); + self.fmt.format_scope_field("scope", scope.scope); + self.fmt.line("instructions:"); + self.fmt.indent(); + self.format_reactive_block(&scope.instructions); + self.fmt.dedent(); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveStatement::PrunedScope(scope) => { + self.fmt.line("PrunedReactiveScopeBlock {"); + self.fmt.indent(); + self.fmt.format_scope_field("scope", scope.scope); + self.fmt.line("instructions:"); + self.fmt.indent(); + self.format_reactive_block(&scope.instructions); + self.fmt.dedent(); + self.fmt.dedent(); + self.fmt.line("}"); + } + } + } + + // ========================================================================= + // ReactiveInstruction + // 
========================================================================= + + fn format_reactive_instruction_block(&mut self, instr: &ReactiveInstruction) { + self.fmt.line("ReactiveInstruction {"); + self.fmt.indent(); + self.format_reactive_instruction(instr); + self.fmt.dedent(); + self.fmt.line("}"); + } + + fn format_reactive_instruction(&mut self, instr: &ReactiveInstruction) { + self.fmt.line(&format!("id: {}", instr.id.0)); + match &instr.lvalue { + Some(place) => self.fmt.format_place_field("lvalue", place), + None => self.fmt.line("lvalue: null"), + } + self.fmt.line("value:"); + self.fmt.indent(); + self.format_reactive_value(&instr.value); + self.fmt.dedent(); + match &instr.effects { + Some(effects) => { + self.fmt.line("effects:"); + self.fmt.indent(); + for (i, eff) in effects.iter().enumerate() { + self.fmt + .line(&format!("[{}] {}", i, self.fmt.format_effect(eff))); + } + self.fmt.dedent(); + } + None => self.fmt.line("effects: null"), + } + self.fmt + .line(&format!("loc: {}", print::format_loc(&instr.loc))); + } + + // ========================================================================= + // ReactiveValue + // ========================================================================= + + fn format_reactive_value(&mut self, value: &ReactiveValue) { + match value { + ReactiveValue::Instruction(iv) => { + // Build the inner function formatter callback if we have an hir_formatter + let hir_formatter = self.hir_formatter; + let inner_func_cb: Option> = + hir_formatter.map(|hf| { + Box::new(move |fmt: &mut PrintFormatter, func: &HirFunction| { + hf(fmt, func); + }) + as Box + }); + self.fmt.format_instruction_value( + iv, + inner_func_cb + .as_ref() + .map(|cb| cb.as_ref() as &dyn Fn(&mut PrintFormatter, &HirFunction)), + ); + } + ReactiveValue::LogicalExpression { + operator, + left, + right, + loc, + } => { + self.fmt.line("LogicalExpression {"); + self.fmt.indent(); + self.fmt.line(&format!("operator: \"{}\"", operator)); + 
self.fmt.line("left:"); + self.fmt.indent(); + self.format_reactive_value(left); + self.fmt.dedent(); + self.fmt.line("right:"); + self.fmt.indent(); + self.format_reactive_value(right); + self.fmt.dedent(); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + loc, + } => { + self.fmt.line("ConditionalExpression {"); + self.fmt.indent(); + self.fmt.line("test:"); + self.fmt.indent(); + self.format_reactive_value(test); + self.fmt.dedent(); + self.fmt.line("consequent:"); + self.fmt.indent(); + self.format_reactive_value(consequent); + self.fmt.dedent(); + self.fmt.line("alternate:"); + self.fmt.indent(); + self.format_reactive_value(alternate); + self.fmt.dedent(); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveValue::SequenceExpression { + instructions, + id, + value, + loc, + } => { + self.fmt.line("SequenceExpression {"); + self.fmt.indent(); + self.fmt.line("instructions:"); + self.fmt.indent(); + for (i, instr) in instructions.iter().enumerate() { + self.fmt.line(&format!("[{}]:", i)); + self.fmt.indent(); + self.format_reactive_instruction_block(instr); + self.fmt.dedent(); + } + self.fmt.dedent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line("value:"); + self.fmt.indent(); + self.format_reactive_value(value); + self.fmt.dedent(); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveValue::OptionalExpression { + id, + value, + optional, + loc, + } => { + self.fmt.line("OptionalExpression {"); + self.fmt.indent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line("value:"); + self.fmt.indent(); + self.format_reactive_value(value); + self.fmt.dedent(); + self.fmt.line(&format!("optional: {}", optional)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + 
self.fmt.dedent(); + self.fmt.line("}"); + } + } + } + + // ========================================================================= + // ReactiveTerminal + // ========================================================================= + + fn format_terminal_statement(&mut self, stmt: &ReactiveTerminalStatement) { + match &stmt.label { + Some(label) => { + self.fmt.line(&format!( + "label: {{ id: bb{}, implicit: {} }}", + label.id.0, label.implicit + )); + } + None => self.fmt.line("label: null"), + } + self.fmt.line("terminal:"); + self.fmt.indent(); + self.format_reactive_terminal(&stmt.terminal); + self.fmt.dedent(); + } + + fn format_reactive_terminal(&mut self, terminal: &ReactiveTerminal) { + match terminal { + ReactiveTerminal::Break { + target, + id, + target_kind, + loc, + } => { + self.fmt.line("Break {"); + self.fmt.indent(); + self.fmt.line(&format!("target: bb{}", target.0)); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("targetKind: \"{}\"", target_kind)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::Continue { + target, + id, + target_kind, + loc, + } => { + self.fmt.line("Continue {"); + self.fmt.indent(); + self.fmt.line(&format!("target: bb{}", target.0)); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("targetKind: \"{}\"", target_kind)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::Return { value, id, loc } => { + self.fmt.line("Return {"); + self.fmt.indent(); + self.fmt.format_place_field("value", value); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::Throw { value, id, loc } => { + self.fmt.line("Throw {"); + self.fmt.indent(); + self.fmt.format_place_field("value", value); + self.fmt.line(&format!("id: {}", 
id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::Switch { + test, + cases, + id, + loc, + } => { + self.fmt.line("Switch {"); + self.fmt.indent(); + self.fmt.format_place_field("test", test); + self.fmt.line("cases:"); + self.fmt.indent(); + for (i, case) in cases.iter().enumerate() { + self.fmt.line(&format!("[{}] {{", i)); + self.fmt.indent(); + match &case.test { + Some(p) => { + self.fmt.format_place_field("test", p); + } + None => { + self.fmt.line("test: null"); + } + } + match &case.block { + Some(block) => { + self.fmt.line("block:"); + self.fmt.indent(); + self.format_reactive_block(block); + self.fmt.dedent(); + } + None => self.fmt.line("block: undefined"), + } + self.fmt.dedent(); + self.fmt.line("}"); + } + self.fmt.dedent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::DoWhile { + loop_block, + test, + id, + loc, + } => { + self.fmt.line("DoWhile {"); + self.fmt.indent(); + self.fmt.line("loop:"); + self.fmt.indent(); + self.format_reactive_block(loop_block); + self.fmt.dedent(); + self.fmt.line("test:"); + self.fmt.indent(); + self.format_reactive_value(test); + self.fmt.dedent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::While { + test, + loop_block, + id, + loc, + } => { + self.fmt.line("While {"); + self.fmt.indent(); + self.fmt.line("test:"); + self.fmt.indent(); + self.format_reactive_value(test); + self.fmt.dedent(); + self.fmt.line("loop:"); + self.fmt.indent(); + self.format_reactive_block(loop_block); + self.fmt.dedent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::For { + 
init, + test, + update, + loop_block, + id, + loc, + } => { + self.fmt.line("For {"); + self.fmt.indent(); + self.fmt.line("init:"); + self.fmt.indent(); + self.format_reactive_value(init); + self.fmt.dedent(); + self.fmt.line("test:"); + self.fmt.indent(); + self.format_reactive_value(test); + self.fmt.dedent(); + match update { + Some(u) => { + self.fmt.line("update:"); + self.fmt.indent(); + self.format_reactive_value(u); + self.fmt.dedent(); + } + None => self.fmt.line("update: null"), + } + self.fmt.line("loop:"); + self.fmt.indent(); + self.format_reactive_block(loop_block); + self.fmt.dedent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::ForOf { + init, + test, + loop_block, + id, + loc, + } => { + self.fmt.line("ForOf {"); + self.fmt.indent(); + self.fmt.line("init:"); + self.fmt.indent(); + self.format_reactive_value(init); + self.fmt.dedent(); + self.fmt.line("test:"); + self.fmt.indent(); + self.format_reactive_value(test); + self.fmt.dedent(); + self.fmt.line("loop:"); + self.fmt.indent(); + self.format_reactive_block(loop_block); + self.fmt.dedent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::ForIn { + init, + loop_block, + id, + loc, + } => { + self.fmt.line("ForIn {"); + self.fmt.indent(); + self.fmt.line("init:"); + self.fmt.indent(); + self.format_reactive_value(init); + self.fmt.dedent(); + self.fmt.line("loop:"); + self.fmt.indent(); + self.format_reactive_block(loop_block); + self.fmt.dedent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::If { + test, + consequent, + alternate, + id, + loc, + } => { + self.fmt.line("If {"); + self.fmt.indent(); + 
self.fmt.format_place_field("test", test); + self.fmt.line("consequent:"); + self.fmt.indent(); + self.format_reactive_block(consequent); + self.fmt.dedent(); + match alternate { + Some(alt) => { + self.fmt.line("alternate:"); + self.fmt.indent(); + self.format_reactive_block(alt); + self.fmt.dedent(); + } + None => self.fmt.line("alternate: null"), + } + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::Label { block, id, loc } => { + self.fmt.line("Label {"); + self.fmt.indent(); + self.fmt.line("block:"); + self.fmt.indent(); + self.format_reactive_block(block); + self.fmt.dedent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + ReactiveTerminal::Try { + block, + handler_binding, + handler, + id, + loc, + } => { + self.fmt.line("Try {"); + self.fmt.indent(); + self.fmt.line("block:"); + self.fmt.indent(); + self.format_reactive_block(block); + self.fmt.dedent(); + match handler_binding { + Some(p) => self.fmt.format_place_field("handlerBinding", p), + None => self.fmt.line("handlerBinding: null"), + } + self.fmt.line("handler:"); + self.fmt.indent(); + self.format_reactive_block(handler); + self.fmt.dedent(); + self.fmt.line(&format!("id: {}", id.0)); + self.fmt.line(&format!("loc: {}", print::format_loc(loc))); + self.fmt.dedent(); + self.fmt.line("}"); + } + } + } +} + +// ============================================================================= +// Entry point +// ============================================================================= + +/// Type alias for a function formatter callback that can print HIR functions. +/// Used to format inner functions in FunctionExpression/ObjectMethod values. 
+pub type HirFunctionFormatter = dyn Fn(&mut PrintFormatter, &HirFunction); + +pub fn debug_reactive_function(func: &ReactiveFunction, env: &Environment) -> String { + debug_reactive_function_with_formatter(func, env, None) +} + +pub fn debug_reactive_function_with_formatter( + func: &ReactiveFunction, + env: &Environment, + hir_formatter: Option<&HirFunctionFormatter>, +) -> String { + let mut printer = DebugPrinter::new(env); + printer.hir_formatter = hir_formatter; + printer.format_reactive_function(func); + + // TODO: Print outlined functions when they've been converted to reactive form + + printer.fmt.line(""); + printer.fmt.line("Environment:"); + printer.fmt.indent(); + printer.fmt.format_errors(&env.errors); + printer.fmt.dedent(); + + printer.fmt.to_string_output() +} diff --git a/crates/react_compiler_reactive_scopes/src/promote_used_temporaries.rs b/crates/react_compiler_reactive_scopes/src/promote_used_temporaries.rs new file mode 100644 index 000000000000..468fa0b0ceef --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/promote_used_temporaries.rs @@ -0,0 +1,1165 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! PromoteUsedTemporaries — promotes temporary variables to named variables +//! if they're used by scopes. +//! +//! Corresponds to `src/ReactiveScopes/PromoteUsedTemporaries.ts`. 
+ +use std::collections::{HashMap, HashSet}; + +use react_compiler_hir::{ + environment::Environment, DeclarationId, IdentifierId, IdentifierName, InstructionKind, + InstructionValue, JsxTag, ParamPattern, Place, ReactiveBlock, ReactiveFunction, + ReactiveInstruction, ReactiveStatement, ReactiveTerminal, ReactiveTerminalStatement, + ReactiveValue, ScopeId, +}; + +// ============================================================================= +// State +// ============================================================================= + +struct State { + tags: HashSet, + promoted: HashSet, + pruned: HashMap, +} + +struct PrunedInfo { + active_scopes: Vec, + used_outside_scope: bool, +} + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Promotes temporary (unnamed) identifiers used in scopes to named +/// identifiers. TS: `promoteUsedTemporaries` +pub fn promote_used_temporaries(func: &mut ReactiveFunction, env: &mut Environment) { + let mut state = State { + tags: HashSet::new(), + promoted: HashSet::new(), + pruned: HashMap::new(), + }; + + // Phase 1: collect promotable temporaries (jsx tags, pruned scope usage) + let mut active_scopes: Vec = Vec::new(); + collect_promotable_block(&func.body, &mut state, &mut active_scopes, env); + + // Promote params + for param in &func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + let identifier = &env.identifiers[place.identifier.0 as usize]; + if identifier.name.is_none() { + promote_identifier(place.identifier, &mut state, env); + } + } + + // Phase 2: promote identifiers used in scopes + promote_temporaries_block(&func.body, &mut state, env); + + // Phase 3: promote interposed temporaries + let mut consts: HashSet = HashSet::new(); + let mut globals: HashSet = HashSet::new(); + for param in &func.params { + match 
param { + ParamPattern::Place(p) => { + consts.insert(p.identifier); + } + ParamPattern::Spread(s) => { + consts.insert(s.place.identifier); + } + } + } + let mut inter_state: HashMap = HashMap::new(); + promote_interposed_block( + &func.body, + &mut state, + &mut inter_state, + &mut consts, + &mut globals, + env, + ); + + // Phase 4: promote all instances of promoted declaration IDs + promote_all_instances_params(func, &mut state, env); + promote_all_instances_block(&func.body, &mut state, env); +} + +// ============================================================================= +// Phase 1: CollectPromotableTemporaries +// ============================================================================= + +fn collect_promotable_block( + block: &ReactiveBlock, + state: &mut State, + active_scopes: &mut Vec, + env: &Environment, +) { + for stmt in block { + match stmt { + ReactiveStatement::Instruction(instr) => { + collect_promotable_instruction(instr, state, active_scopes, env); + } + ReactiveStatement::Scope(scope) => { + let scope_id = scope.scope; + active_scopes.push(scope_id); + collect_promotable_block(&scope.instructions, state, active_scopes, env); + active_scopes.pop(); + } + ReactiveStatement::PrunedScope(scope) => { + let scope_data = &env.scopes[scope.scope.0 as usize]; + for (_id, decl) in &scope_data.declarations { + let identifier = &env.identifiers[decl.identifier.0 as usize]; + state.pruned.insert( + identifier.declaration_id, + PrunedInfo { + active_scopes: active_scopes.clone(), + used_outside_scope: false, + }, + ); + } + collect_promotable_block(&scope.instructions, state, active_scopes, env); + } + ReactiveStatement::Terminal(terminal) => { + collect_promotable_terminal(terminal, state, active_scopes, env); + } + } + } +} + +fn collect_promotable_place( + place: &Place, + state: &mut State, + active_scopes: &[ScopeId], + env: &Environment, +) { + if !active_scopes.is_empty() { + let identifier = &env.identifiers[place.identifier.0 as usize]; + 
if let Some(pruned) = state.pruned.get_mut(&identifier.declaration_id) { + if let Some(last) = active_scopes.last() { + if !pruned.active_scopes.contains(last) { + pruned.used_outside_scope = true; + } + } + } + } +} + +fn collect_promotable_instruction( + instr: &ReactiveInstruction, + state: &mut State, + active_scopes: &mut Vec, + env: &Environment, +) { + collect_promotable_value(&instr.value, state, active_scopes, env); +} + +fn collect_promotable_value( + value: &ReactiveValue, + state: &mut State, + active_scopes: &mut Vec, + env: &Environment, +) { + match value { + ReactiveValue::Instruction(instr_value) => { + // Visit operands + for place in + react_compiler_hir::visitors::each_instruction_value_operand(instr_value, env) + { + collect_promotable_place(&place, state, active_scopes, env); + } + // Check for JSX tag + if let InstructionValue::JsxExpression { + tag: JsxTag::Place(place), + .. + } = instr_value + { + let identifier = &env.identifiers[place.identifier.0 as usize]; + state.tags.insert(identifier.declaration_id); + } + } + ReactiveValue::SequenceExpression { + instructions, + value: inner, + .. + } => { + for instr in instructions { + collect_promotable_instruction(instr, state, active_scopes, env); + } + collect_promotable_value(inner, state, active_scopes, env); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + collect_promotable_value(test, state, active_scopes, env); + collect_promotable_value(consequent, state, active_scopes, env); + collect_promotable_value(alternate, state, active_scopes, env); + } + ReactiveValue::LogicalExpression { left, right, .. } => { + collect_promotable_value(left, state, active_scopes, env); + collect_promotable_value(right, state, active_scopes, env); + } + ReactiveValue::OptionalExpression { value: inner, .. 
} => { + collect_promotable_value(inner, state, active_scopes, env); + } + } +} + +fn collect_promotable_terminal( + stmt: &ReactiveTerminalStatement, + state: &mut State, + active_scopes: &mut Vec, + env: &Environment, +) { + match &stmt.terminal { + ReactiveTerminal::Break { .. } | ReactiveTerminal::Continue { .. } => {} + ReactiveTerminal::Return { value, .. } | ReactiveTerminal::Throw { value, .. } => { + collect_promotable_place(value, state, active_scopes, env); + } + ReactiveTerminal::For { + init, + test, + update, + loop_block, + .. + } => { + collect_promotable_value(init, state, active_scopes, env); + collect_promotable_value(test, state, active_scopes, env); + collect_promotable_block(loop_block, state, active_scopes, env); + if let Some(update) = update { + collect_promotable_value(update, state, active_scopes, env); + } + } + ReactiveTerminal::ForOf { + init, + test, + loop_block, + .. + } => { + collect_promotable_value(init, state, active_scopes, env); + collect_promotable_value(test, state, active_scopes, env); + collect_promotable_block(loop_block, state, active_scopes, env); + } + ReactiveTerminal::ForIn { + init, loop_block, .. + } => { + collect_promotable_value(init, state, active_scopes, env); + collect_promotable_block(loop_block, state, active_scopes, env); + } + ReactiveTerminal::DoWhile { + loop_block, test, .. + } => { + collect_promotable_block(loop_block, state, active_scopes, env); + collect_promotable_value(test, state, active_scopes, env); + } + ReactiveTerminal::While { + test, loop_block, .. + } => { + collect_promotable_value(test, state, active_scopes, env); + collect_promotable_block(loop_block, state, active_scopes, env); + } + ReactiveTerminal::If { + test, + consequent, + alternate, + .. 
+ } => { + collect_promotable_place(test, state, active_scopes, env); + collect_promotable_block(consequent, state, active_scopes, env); + if let Some(alt) = alternate { + collect_promotable_block(alt, state, active_scopes, env); + } + } + ReactiveTerminal::Switch { test, cases, .. } => { + collect_promotable_place(test, state, active_scopes, env); + for case in cases { + if let Some(t) = &case.test { + collect_promotable_place(t, state, active_scopes, env); + } + if let Some(block) = &case.block { + collect_promotable_block(block, state, active_scopes, env); + } + } + } + ReactiveTerminal::Label { block, .. } => { + collect_promotable_block(block, state, active_scopes, env); + } + ReactiveTerminal::Try { + block, + handler_binding, + handler, + .. + } => { + collect_promotable_block(block, state, active_scopes, env); + if let Some(binding) = handler_binding { + collect_promotable_place(binding, state, active_scopes, env); + } + collect_promotable_block(handler, state, active_scopes, env); + } + } +} + +// ============================================================================= +// Phase 2: PromoteTemporaries +// ============================================================================= + +fn promote_temporaries_block(block: &ReactiveBlock, state: &mut State, env: &mut Environment) { + for stmt in block { + match stmt { + ReactiveStatement::Instruction(instr) => { + promote_temporaries_value(&instr.value, state, env); + } + ReactiveStatement::Scope(scope) => { + let scope_id = scope.scope; + let scope_data = &env.scopes[scope_id.0 as usize]; + // Collect all IDs to promote first + let mut ids_to_check: Vec = Vec::new(); + ids_to_check.extend(scope_data.dependencies.iter().map(|d| d.identifier)); + ids_to_check.extend(scope_data.declarations.iter().map(|(_, d)| d.identifier)); + for id in ids_to_check { + let identifier = &env.identifiers[id.0 as usize]; + if identifier.name.is_none() { + promote_identifier(id, state, env); + } + } + 
promote_temporaries_block(&scope.instructions, state, env); + } + ReactiveStatement::PrunedScope(scope) => { + let scope_id = scope.scope; + let scope_data = &env.scopes[scope_id.0 as usize]; + let decls: Vec<(IdentifierId, DeclarationId)> = scope_data + .declarations + .iter() + .map(|(_, d)| { + let identifier = &env.identifiers[d.identifier.0 as usize]; + (d.identifier, identifier.declaration_id) + }) + .collect(); + for (id, decl_id) in decls { + let identifier = &env.identifiers[id.0 as usize]; + if identifier.name.is_none() { + if let Some(pruned) = state.pruned.get(&decl_id) { + if pruned.used_outside_scope { + promote_identifier(id, state, env); + } + } + } + } + promote_temporaries_block(&scope.instructions, state, env); + } + ReactiveStatement::Terminal(terminal) => { + promote_temporaries_terminal(terminal, state, env); + } + } + } +} + +fn promote_temporaries_value(value: &ReactiveValue, state: &mut State, env: &mut Environment) { + match value { + ReactiveValue::Instruction(instr_value) => { + // Visit inner functions + match instr_value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + let func_id = lowered_func.func; + let inner_func = &env.functions[func_id.0 as usize]; + // Collect param IDs first to avoid borrow conflict + let param_ids: Vec = inner_func + .params + .iter() + .map(|param| match param { + ParamPattern::Place(p) => p.identifier, + ParamPattern::Spread(s) => s.place.identifier, + }) + .collect(); + for id in param_ids { + let identifier = &env.identifiers[id.0 as usize]; + if identifier.name.is_none() { + promote_identifier(id, state, env); + } + } + } + _ => {} + } + } + ReactiveValue::SequenceExpression { + instructions, + value: inner, + .. 
+ } => { + for instr in instructions { + promote_temporaries_value(&instr.value, state, env); + } + promote_temporaries_value(inner, state, env); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + promote_temporaries_value(test, state, env); + promote_temporaries_value(consequent, state, env); + promote_temporaries_value(alternate, state, env); + } + ReactiveValue::LogicalExpression { left, right, .. } => { + promote_temporaries_value(left, state, env); + promote_temporaries_value(right, state, env); + } + ReactiveValue::OptionalExpression { value: inner, .. } => { + promote_temporaries_value(inner, state, env); + } + } +} + +fn promote_temporaries_terminal( + stmt: &ReactiveTerminalStatement, + state: &mut State, + env: &mut Environment, +) { + match &stmt.terminal { + ReactiveTerminal::Break { .. } | ReactiveTerminal::Continue { .. } => {} + ReactiveTerminal::Return { .. } | ReactiveTerminal::Throw { .. } => {} + ReactiveTerminal::For { + init, + test, + update, + loop_block, + .. + } => { + promote_temporaries_value(init, state, env); + promote_temporaries_value(test, state, env); + promote_temporaries_block(loop_block, state, env); + if let Some(update) = update { + promote_temporaries_value(update, state, env); + } + } + ReactiveTerminal::ForOf { + init, + test, + loop_block, + .. + } => { + promote_temporaries_value(init, state, env); + promote_temporaries_value(test, state, env); + promote_temporaries_block(loop_block, state, env); + } + ReactiveTerminal::ForIn { + init, loop_block, .. + } => { + promote_temporaries_value(init, state, env); + promote_temporaries_block(loop_block, state, env); + } + ReactiveTerminal::DoWhile { + loop_block, test, .. + } => { + promote_temporaries_block(loop_block, state, env); + promote_temporaries_value(test, state, env); + } + ReactiveTerminal::While { + test, loop_block, .. 
+ } => { + promote_temporaries_value(test, state, env); + promote_temporaries_block(loop_block, state, env); + } + ReactiveTerminal::If { + consequent, + alternate, + .. + } => { + promote_temporaries_block(consequent, state, env); + if let Some(alt) = alternate { + promote_temporaries_block(alt, state, env); + } + } + ReactiveTerminal::Switch { cases, .. } => { + for case in cases { + if let Some(block) = &case.block { + promote_temporaries_block(block, state, env); + } + } + } + ReactiveTerminal::Label { block, .. } => { + promote_temporaries_block(block, state, env); + } + ReactiveTerminal::Try { block, handler, .. } => { + promote_temporaries_block(block, state, env); + promote_temporaries_block(handler, state, env); + } + } +} + +// ============================================================================= +// Phase 3: PromoteInterposedTemporaries +// ============================================================================= + +fn promote_interposed_block( + block: &ReactiveBlock, + state: &mut State, + inter_state: &mut HashMap, + consts: &mut HashSet, + globals: &mut HashSet, + env: &mut Environment, +) { + for stmt in block { + match stmt { + ReactiveStatement::Instruction(instr) => { + promote_interposed_instruction(instr, state, inter_state, consts, globals, env); + } + ReactiveStatement::Scope(scope) => { + promote_interposed_block( + &scope.instructions, + state, + inter_state, + consts, + globals, + env, + ); + } + ReactiveStatement::PrunedScope(scope) => { + promote_interposed_block( + &scope.instructions, + state, + inter_state, + consts, + globals, + env, + ); + } + ReactiveStatement::Terminal(terminal) => { + promote_interposed_terminal(terminal, state, inter_state, consts, globals, env); + } + } + } +} + +fn promote_interposed_place( + place: &Place, + state: &mut State, + inter_state: &mut HashMap, + consts: &HashSet, + env: &mut Environment, +) { + if let Some(&(id, needs_promotion)) = inter_state.get(&place.identifier) { + let identifier 
= &env.identifiers[id.0 as usize]; + if needs_promotion && identifier.name.is_none() && !consts.contains(&id) { + promote_identifier(id, state, env); + } + } +} + +fn promote_interposed_instruction( + instr: &ReactiveInstruction, + state: &mut State, + inter_state: &mut HashMap, + consts: &mut HashSet, + globals: &mut HashSet, + env: &mut Environment, +) { + // Check instruction value lvalues (assignment targets) + match &instr.value { + ReactiveValue::Instruction(iv) => { + // Check eachInstructionValueLValue: these should all be named + // (the TS pass asserts this but we just skip in Rust) + + match iv { + InstructionValue::CallExpression { .. } + | InstructionValue::MethodCall { .. } + | InstructionValue::Await { .. } + | InstructionValue::PropertyStore { .. } + | InstructionValue::PropertyDelete { .. } + | InstructionValue::ComputedStore { .. } + | InstructionValue::ComputedDelete { .. } + | InstructionValue::PostfixUpdate { .. } + | InstructionValue::PrefixUpdate { .. } + | InstructionValue::StoreLocal { .. } + | InstructionValue::StoreContext { .. } + | InstructionValue::StoreGlobal { .. } + | InstructionValue::Destructure { .. } => { + let mut const_store = false; + + match iv { + InstructionValue::StoreContext { lvalue, .. } + | InstructionValue::StoreLocal { lvalue, .. } => { + if lvalue.kind == InstructionKind::Const + || lvalue.kind == InstructionKind::HoistedConst + { + consts.insert(lvalue.place.identifier); + const_store = true; + } + } + _ => {} + } + if let InstructionValue::Destructure { lvalue, .. } = iv { + if lvalue.kind == InstructionKind::Const + || lvalue.kind == InstructionKind::HoistedConst + { + for operand in + react_compiler_hir::visitors::each_pattern_operand(&lvalue.pattern) + { + consts.insert(operand.identifier); + } + const_store = true; + } + } + if let InstructionValue::MethodCall { property, .. 
} = iv { + consts.insert(property.identifier); + } + + // Visit operands + for place in + react_compiler_hir::visitors::each_instruction_value_operand(iv, env) + { + promote_interposed_place(&place, state, inter_state, consts, env); + } + + if !const_store + && (instr.lvalue.is_none() + || env.identifiers + [instr.lvalue.as_ref().unwrap().identifier.0 as usize] + .name + .is_some()) + { + // Mark all tracked temporaries as needing promotion + let keys: Vec = inter_state.keys().cloned().collect(); + for key in keys { + if let Some(entry) = inter_state.get_mut(&key) { + entry.1 = true; + } + } + } + if let Some(lvalue) = &instr.lvalue { + let identifier = &env.identifiers[lvalue.identifier.0 as usize]; + if identifier.name.is_none() { + inter_state.insert(lvalue.identifier, (lvalue.identifier, false)); + } + } + } + InstructionValue::DeclareContext { lvalue, .. } + | InstructionValue::DeclareLocal { lvalue, .. } => { + if lvalue.kind == InstructionKind::Const + || lvalue.kind == InstructionKind::HoistedConst + { + consts.insert(lvalue.place.identifier); + } + // Visit operands + for place in + react_compiler_hir::visitors::each_instruction_value_operand(iv, env) + { + promote_interposed_place(&place, state, inter_state, consts, env); + } + } + InstructionValue::LoadContext { + place: load_place, .. + } + | InstructionValue::LoadLocal { + place: load_place, .. + } => { + if let Some(lvalue) = &instr.lvalue { + let identifier = &env.identifiers[lvalue.identifier.0 as usize]; + if identifier.name.is_none() { + if consts.contains(&load_place.identifier) { + consts.insert(lvalue.identifier); + } + inter_state.insert(lvalue.identifier, (lvalue.identifier, false)); + } + } + // Visit operands + for place in + react_compiler_hir::visitors::each_instruction_value_operand(iv, env) + { + promote_interposed_place(&place, state, inter_state, consts, env); + } + } + InstructionValue::PropertyLoad { object, .. } + | InstructionValue::ComputedLoad { object, .. 
} => { + if let Some(lvalue) = &instr.lvalue { + if globals.contains(&object.identifier) { + globals.insert(lvalue.identifier); + consts.insert(lvalue.identifier); + } + let identifier = &env.identifiers[lvalue.identifier.0 as usize]; + if identifier.name.is_none() { + inter_state.insert(lvalue.identifier, (lvalue.identifier, false)); + } + } + // Visit operands + for place in + react_compiler_hir::visitors::each_instruction_value_operand(iv, env) + { + promote_interposed_place(&place, state, inter_state, consts, env); + } + } + InstructionValue::LoadGlobal { .. } => { + if let Some(lvalue) = &instr.lvalue { + globals.insert(lvalue.identifier); + } + // Visit operands + for place in + react_compiler_hir::visitors::each_instruction_value_operand(iv, env) + { + promote_interposed_place(&place, state, inter_state, consts, env); + } + } + _ => { + // Default: visit operands + for place in + react_compiler_hir::visitors::each_instruction_value_operand(iv, env) + { + promote_interposed_place(&place, state, inter_state, consts, env); + } + } + } + } + ReactiveValue::SequenceExpression { + instructions, + value: inner, + .. + } => { + for sub_instr in instructions { + promote_interposed_instruction(sub_instr, state, inter_state, consts, globals, env); + } + promote_interposed_value(inner, state, inter_state, consts, globals, env); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + promote_interposed_value(test, state, inter_state, consts, globals, env); + promote_interposed_value(consequent, state, inter_state, consts, globals, env); + promote_interposed_value(alternate, state, inter_state, consts, globals, env); + } + ReactiveValue::LogicalExpression { left, right, .. } => { + promote_interposed_value(left, state, inter_state, consts, globals, env); + promote_interposed_value(right, state, inter_state, consts, globals, env); + } + ReactiveValue::OptionalExpression { value: inner, .. 
} => { + promote_interposed_value(inner, state, inter_state, consts, globals, env); + } + } +} + +fn promote_interposed_value( + value: &ReactiveValue, + state: &mut State, + inter_state: &mut HashMap, + consts: &mut HashSet, + globals: &mut HashSet, + env: &mut Environment, +) { + match value { + ReactiveValue::Instruction(iv) => { + for place in react_compiler_hir::visitors::each_instruction_value_operand(iv, env) { + promote_interposed_place(&place, state, inter_state, consts, env); + } + } + ReactiveValue::SequenceExpression { + instructions, + value: inner, + .. + } => { + for instr in instructions { + promote_interposed_instruction(instr, state, inter_state, consts, globals, env); + } + promote_interposed_value(inner, state, inter_state, consts, globals, env); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + promote_interposed_value(test, state, inter_state, consts, globals, env); + promote_interposed_value(consequent, state, inter_state, consts, globals, env); + promote_interposed_value(alternate, state, inter_state, consts, globals, env); + } + ReactiveValue::LogicalExpression { left, right, .. } => { + promote_interposed_value(left, state, inter_state, consts, globals, env); + promote_interposed_value(right, state, inter_state, consts, globals, env); + } + ReactiveValue::OptionalExpression { value: inner, .. } => { + promote_interposed_value(inner, state, inter_state, consts, globals, env); + } + } +} + +fn promote_interposed_terminal( + stmt: &ReactiveTerminalStatement, + state: &mut State, + inter_state: &mut HashMap, + consts: &mut HashSet, + globals: &mut HashSet, + env: &mut Environment, +) { + match &stmt.terminal { + ReactiveTerminal::Break { .. } | ReactiveTerminal::Continue { .. } => {} + ReactiveTerminal::Return { value, .. } | ReactiveTerminal::Throw { value, .. 
} => { + promote_interposed_place(value, state, inter_state, consts, env); + } + ReactiveTerminal::For { + init, + test, + update, + loop_block, + .. + } => { + promote_interposed_value(init, state, inter_state, consts, globals, env); + promote_interposed_value(test, state, inter_state, consts, globals, env); + promote_interposed_block(loop_block, state, inter_state, consts, globals, env); + if let Some(update) = update { + promote_interposed_value(update, state, inter_state, consts, globals, env); + } + } + ReactiveTerminal::ForOf { + init, + test, + loop_block, + .. + } => { + promote_interposed_value(init, state, inter_state, consts, globals, env); + promote_interposed_value(test, state, inter_state, consts, globals, env); + promote_interposed_block(loop_block, state, inter_state, consts, globals, env); + } + ReactiveTerminal::ForIn { + init, loop_block, .. + } => { + promote_interposed_value(init, state, inter_state, consts, globals, env); + promote_interposed_block(loop_block, state, inter_state, consts, globals, env); + } + ReactiveTerminal::DoWhile { + loop_block, test, .. + } => { + promote_interposed_block(loop_block, state, inter_state, consts, globals, env); + promote_interposed_value(test, state, inter_state, consts, globals, env); + } + ReactiveTerminal::While { + test, loop_block, .. + } => { + promote_interposed_value(test, state, inter_state, consts, globals, env); + promote_interposed_block(loop_block, state, inter_state, consts, globals, env); + } + ReactiveTerminal::If { + test, + consequent, + alternate, + .. + } => { + promote_interposed_place(test, state, inter_state, consts, env); + promote_interposed_block(consequent, state, inter_state, consts, globals, env); + if let Some(alt) = alternate { + promote_interposed_block(alt, state, inter_state, consts, globals, env); + } + } + ReactiveTerminal::Switch { test, cases, .. 
} => { + promote_interposed_place(test, state, inter_state, consts, env); + for case in cases { + if let Some(t) = &case.test { + promote_interposed_place(t, state, inter_state, consts, env); + } + if let Some(block) = &case.block { + promote_interposed_block(block, state, inter_state, consts, globals, env); + } + } + } + ReactiveTerminal::Label { block, .. } => { + promote_interposed_block(block, state, inter_state, consts, globals, env); + } + ReactiveTerminal::Try { + block, + handler_binding, + handler, + .. + } => { + promote_interposed_block(block, state, inter_state, consts, globals, env); + if let Some(binding) = handler_binding { + promote_interposed_place(binding, state, inter_state, consts, env); + } + promote_interposed_block(handler, state, inter_state, consts, globals, env); + } + } +} + +// ============================================================================= +// Phase 4: PromoteAllInstancesOfPromotedTemporaries +// ============================================================================= + +fn promote_all_instances_params(func: &ReactiveFunction, state: &mut State, env: &mut Environment) { + for param in &func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + let identifier = &env.identifiers[place.identifier.0 as usize]; + if identifier.name.is_none() && state.promoted.contains(&identifier.declaration_id) { + promote_identifier(place.identifier, state, env); + } + } +} + +fn promote_all_instances_block(block: &ReactiveBlock, state: &mut State, env: &mut Environment) { + for stmt in block { + match stmt { + ReactiveStatement::Instruction(instr) => { + promote_all_instances_instruction(instr, state, env); + } + ReactiveStatement::Scope(scope) => { + promote_all_instances_block(&scope.instructions, state, env); + promote_all_instances_scope_identifiers(scope.scope, state, env); + } + ReactiveStatement::PrunedScope(scope) => { + promote_all_instances_block(&scope.instructions, 
state, env); + promote_all_instances_scope_identifiers(scope.scope, state, env); + } + ReactiveStatement::Terminal(terminal) => { + promote_all_instances_terminal(terminal, state, env); + } + } + } +} + +fn promote_all_instances_scope_identifiers( + scope_id: ScopeId, + state: &mut State, + env: &mut Environment, +) { + let scope_data = &env.scopes[scope_id.0 as usize]; + + // Collect identifiers to promote + let decl_ids: Vec = scope_data + .declarations + .iter() + .map(|(_, d)| d.identifier) + .collect(); + let dep_ids: Vec = scope_data + .dependencies + .iter() + .map(|d| d.identifier) + .collect(); + let reassign_ids: Vec = scope_data.reassignments.clone(); + + for id in decl_ids { + let identifier = &env.identifiers[id.0 as usize]; + if identifier.name.is_none() && state.promoted.contains(&identifier.declaration_id) { + promote_identifier(id, state, env); + } + } + for id in dep_ids { + let identifier = &env.identifiers[id.0 as usize]; + if identifier.name.is_none() && state.promoted.contains(&identifier.declaration_id) { + promote_identifier(id, state, env); + } + } + for id in reassign_ids { + let identifier = &env.identifiers[id.0 as usize]; + if identifier.name.is_none() && state.promoted.contains(&identifier.declaration_id) { + promote_identifier(id, state, env); + } + } +} + +fn promote_all_instances_place(place: &Place, state: &mut State, env: &mut Environment) { + let identifier = &env.identifiers[place.identifier.0 as usize]; + if identifier.name.is_none() && state.promoted.contains(&identifier.declaration_id) { + promote_identifier(place.identifier, state, env); + } +} + +fn promote_all_instances_instruction( + instr: &ReactiveInstruction, + state: &mut State, + env: &mut Environment, +) { + if let Some(lvalue) = &instr.lvalue { + promote_all_instances_place(lvalue, state, env); + } + promote_all_instances_value(&instr.value, state, env); +} + +fn promote_all_instances_value(value: &ReactiveValue, state: &mut State, env: &mut Environment) { + match 
value { + ReactiveValue::Instruction(iv) => { + for place in react_compiler_hir::visitors::each_instruction_value_operand(iv, env) { + promote_all_instances_place(&place, state, env); + } + // Visit inner functions + match iv { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + let func_id = lowered_func.func; + let inner_func = &env.functions[func_id.0 as usize]; + let param_ids: Vec = inner_func + .params + .iter() + .map(|p| match p { + ParamPattern::Place(p) => p.identifier, + ParamPattern::Spread(s) => s.place.identifier, + }) + .collect(); + for id in param_ids { + let identifier = &env.identifiers[id.0 as usize]; + if identifier.name.is_none() + && state.promoted.contains(&identifier.declaration_id) + { + promote_identifier(id, state, env); + } + } + } + _ => {} + } + } + ReactiveValue::SequenceExpression { + instructions, + value: inner, + .. + } => { + for instr in instructions { + promote_all_instances_instruction(instr, state, env); + } + promote_all_instances_value(inner, state, env); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + promote_all_instances_value(test, state, env); + promote_all_instances_value(consequent, state, env); + promote_all_instances_value(alternate, state, env); + } + ReactiveValue::LogicalExpression { left, right, .. } => { + promote_all_instances_value(left, state, env); + promote_all_instances_value(right, state, env); + } + ReactiveValue::OptionalExpression { value: inner, .. } => { + promote_all_instances_value(inner, state, env); + } + } +} + +fn promote_all_instances_terminal( + stmt: &ReactiveTerminalStatement, + state: &mut State, + env: &mut Environment, +) { + match &stmt.terminal { + ReactiveTerminal::Break { .. } | ReactiveTerminal::Continue { .. } => {} + ReactiveTerminal::Return { value, .. } | ReactiveTerminal::Throw { value, .. 
} => { + promote_all_instances_place(value, state, env); + } + ReactiveTerminal::For { + init, + test, + update, + loop_block, + .. + } => { + promote_all_instances_value(init, state, env); + promote_all_instances_value(test, state, env); + promote_all_instances_block(loop_block, state, env); + if let Some(update) = update { + promote_all_instances_value(update, state, env); + } + } + ReactiveTerminal::ForOf { + init, + test, + loop_block, + .. + } => { + promote_all_instances_value(init, state, env); + promote_all_instances_value(test, state, env); + promote_all_instances_block(loop_block, state, env); + } + ReactiveTerminal::ForIn { + init, loop_block, .. + } => { + promote_all_instances_value(init, state, env); + promote_all_instances_block(loop_block, state, env); + } + ReactiveTerminal::DoWhile { + loop_block, test, .. + } => { + promote_all_instances_block(loop_block, state, env); + promote_all_instances_value(test, state, env); + } + ReactiveTerminal::While { + test, loop_block, .. + } => { + promote_all_instances_value(test, state, env); + promote_all_instances_block(loop_block, state, env); + } + ReactiveTerminal::If { + test, + consequent, + alternate, + .. + } => { + promote_all_instances_place(test, state, env); + promote_all_instances_block(consequent, state, env); + if let Some(alt) = alternate { + promote_all_instances_block(alt, state, env); + } + } + ReactiveTerminal::Switch { test, cases, .. } => { + promote_all_instances_place(test, state, env); + for case in cases { + if let Some(t) = &case.test { + promote_all_instances_place(t, state, env); + } + if let Some(block) = &case.block { + promote_all_instances_block(block, state, env); + } + } + } + ReactiveTerminal::Label { block, .. } => { + promote_all_instances_block(block, state, env); + } + ReactiveTerminal::Try { + block, + handler_binding, + handler, + .. 
+ } => { + promote_all_instances_block(block, state, env); + if let Some(binding) = handler_binding { + promote_all_instances_place(binding, state, env); + } + promote_all_instances_block(handler, state, env); + } + } +} + +// ============================================================================= +// Helpers +// ============================================================================= + +fn promote_identifier(identifier_id: IdentifierId, state: &mut State, env: &mut Environment) { + let identifier = &env.identifiers[identifier_id.0 as usize]; + assert!( + identifier.name.is_none(), + "promoteTemporary: Expected to be called only for temporary variables" + ); + let decl_id = identifier.declaration_id; + if state.tags.contains(&decl_id) { + // JSX tag temporary: use capitalized name + env.identifiers[identifier_id.0 as usize].name = + Some(IdentifierName::Promoted(format!("#T{}", decl_id.0))); + } else { + env.identifiers[identifier_id.0 as usize].name = + Some(IdentifierName::Promoted(format!("#t{}", decl_id.0))); + } + state.promoted.insert(decl_id); +} diff --git a/crates/react_compiler_reactive_scopes/src/propagate_early_returns.rs b/crates/react_compiler_reactive_scopes/src/propagate_early_returns.rs new file mode 100644 index 000000000000..2ee1c039f70c --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/propagate_early_returns.rs @@ -0,0 +1,366 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! PropagateEarlyReturns — ensures reactive blocks honor early return +//! semantics. +//! +//! When a scope contains an early return, creates a sentinel-based check so +//! that cached scopes can properly replay the early return behavior. +//! +//! Corresponds to `src/ReactiveScopes/PropagateEarlyReturns.ts`. 
+ +use react_compiler_hir::{ + environment::Environment, BlockId, Effect, EvaluationOrder, IdentifierId, IdentifierName, + InstructionKind, InstructionValue, LValue, NonLocalBinding, Place, PlaceOrSpread, + PrimitiveValue, PropertyLiteral, ReactiveFunction, ReactiveInstruction, ReactiveLabel, + ReactiveScopeBlock, ReactiveScopeDeclaration, ReactiveScopeEarlyReturn, ReactiveStatement, + ReactiveTerminal, ReactiveTerminalStatement, ReactiveTerminalTargetKind, ReactiveValue, +}; + +use crate::visitors::{transform_reactive_function, ReactiveFunctionTransform, Transformed}; + +/// The sentinel string used to detect early returns. +/// TS: `EARLY_RETURN_SENTINEL` from CodegenReactiveFunction. +const EARLY_RETURN_SENTINEL: &str = "react.early_return_sentinel"; + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Propagate early return semantics through reactive scopes. +/// TS: `propagateEarlyReturns` +pub fn propagate_early_returns(func: &mut ReactiveFunction, env: &mut Environment) { + let mut transform = Transform { env }; + let mut state = State { + within_reactive_scope: false, + early_return_value: None, + }; + // The TS version doesn't produce errors from this pass, so we ignore the + // Result. 
+ let _ = transform_reactive_function(func, &mut transform, &mut state); +} + +// ============================================================================= +// State +// ============================================================================= + +#[derive(Debug, Clone)] +struct EarlyReturnInfo { + value: IdentifierId, + loc: Option, + label: BlockId, +} + +struct State { + within_reactive_scope: bool, + early_return_value: Option, +} + +// ============================================================================= +// Transform implementation (ReactiveFunctionTransform) +// ============================================================================= + +/// TS: `class Transform extends ReactiveFunctionTransform` +struct Transform<'a> { + env: &'a mut Environment, +} + +impl<'a> ReactiveFunctionTransform for Transform<'a> { + type State = State; + + fn env(&self) -> &Environment { + self.env + } + + /// TS: `override visitScope` + fn visit_scope( + &mut self, + scope_block: &mut ReactiveScopeBlock, + parent_state: &mut State, + ) -> Result<(), react_compiler_diagnostics::CompilerError> { + let scope_id = scope_block.scope; + + // Exit early if an earlier pass has already created an early return + if self.env.scopes[scope_id.0 as usize] + .early_return_value + .is_some() + { + return Ok(()); + } + + let mut inner_state = State { + within_reactive_scope: true, + early_return_value: parent_state.early_return_value.clone(), + }; + self.traverse_scope(scope_block, &mut inner_state)?; + + if let Some(early_return_value) = inner_state.early_return_value { + if !parent_state.within_reactive_scope { + // This is the outermost scope wrapping an early return + apply_early_return_to_scope(scope_block, self.env, &early_return_value); + } else { + // Not outermost — bubble up + parent_state.early_return_value = Some(early_return_value); + } + } + + Ok(()) + } + + /// TS: `override transformTerminal` + fn transform_terminal( + &mut self, + stmt: &mut 
ReactiveTerminalStatement, + state: &mut State, + ) -> Result, react_compiler_diagnostics::CompilerError> { + if state.within_reactive_scope { + if let ReactiveTerminal::Return { value, .. } = &stmt.terminal { + let loc = value.loc; + + let early_return_value = if let Some(ref existing) = state.early_return_value { + existing.clone() + } else { + // Create a new early return identifier + let identifier_id = create_temporary_place_id(self.env, loc); + promote_temporary(self.env, identifier_id); + let label = self.env.next_block_id(); + EarlyReturnInfo { + value: identifier_id, + loc, + label, + } + }; + + state.early_return_value = Some(early_return_value.clone()); + + let return_value = value.clone(); + + return Ok(Transformed::ReplaceMany(vec![ + // StoreLocal: reassign the early return value + ReactiveStatement::Instruction(ReactiveInstruction { + id: EvaluationOrder(0), + lvalue: None, + value: ReactiveValue::Instruction(InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Reassign, + place: Place { + identifier: early_return_value.value, + effect: Effect::Capture, + reactive: true, + loc, + }, + }, + value: return_value, + type_annotation: None, + loc, + }), + effects: None, + loc, + }), + // Break to the label + ReactiveStatement::Terminal(ReactiveTerminalStatement { + terminal: ReactiveTerminal::Break { + target: early_return_value.label, + id: EvaluationOrder(0), + target_kind: ReactiveTerminalTargetKind::Labeled, + loc, + }, + label: None, + }), + ])); + } + } + + // Default: traverse into the terminal's sub-blocks + self.visit_terminal(stmt, state)?; + Ok(Transformed::Keep) + } +} + +// ============================================================================= +// Apply early return transformation to the outermost scope +// ============================================================================= + +fn apply_early_return_to_scope( + scope_block: &mut ReactiveScopeBlock, + env: &mut Environment, + early_return: &EarlyReturnInfo, 
+) { + let scope_id = scope_block.scope; + let loc = early_return.loc; + + // Set early return value on the scope + env.scopes[scope_id.0 as usize].early_return_value = Some(ReactiveScopeEarlyReturn { + value: early_return.value, + loc: early_return.loc, + label: early_return.label, + }); + + // Add the early return identifier as a scope declaration + env.scopes[scope_id.0 as usize].declarations.push(( + early_return.value, + ReactiveScopeDeclaration { + identifier: early_return.value, + scope: scope_id, + }, + )); + + // Create temporary places for the sentinel initialization + let sentinel_temp = create_temporary_place_id(env, loc); + let symbol_temp = create_temporary_place_id(env, loc); + let for_temp = create_temporary_place_id(env, loc); + let arg_temp = create_temporary_place_id(env, loc); + + let original_instructions = std::mem::take(&mut scope_block.instructions); + + scope_block.instructions = vec![ + // LoadGlobal Symbol + ReactiveStatement::Instruction(ReactiveInstruction { + id: EvaluationOrder(0), + lvalue: Some(Place { + identifier: symbol_temp, + effect: Effect::Unknown, + reactive: false, + loc: None, // GeneratedSource + }), + value: ReactiveValue::Instruction(InstructionValue::LoadGlobal { + binding: NonLocalBinding::Global { + name: "Symbol".to_string(), + }, + loc, + }), + effects: None, + loc, + }), + // PropertyLoad Symbol.for + ReactiveStatement::Instruction(ReactiveInstruction { + id: EvaluationOrder(0), + lvalue: Some(Place { + identifier: for_temp, + effect: Effect::Unknown, + reactive: false, + loc: None, // GeneratedSource + }), + value: ReactiveValue::Instruction(InstructionValue::PropertyLoad { + object: Place { + identifier: symbol_temp, + effect: Effect::Unknown, + reactive: false, + loc: None, // GeneratedSource + }, + property: PropertyLiteral::String("for".to_string()), + loc, + }), + effects: None, + loc, + }), + // Primitive: the sentinel string + ReactiveStatement::Instruction(ReactiveInstruction { + id: EvaluationOrder(0), + 
lvalue: Some(Place { + identifier: arg_temp, + effect: Effect::Unknown, + reactive: false, + loc: None, // GeneratedSource + }), + value: ReactiveValue::Instruction(InstructionValue::Primitive { + value: PrimitiveValue::String(EARLY_RETURN_SENTINEL.to_string()), + loc, + }), + effects: None, + loc, + }), + // MethodCall: Symbol.for("react.early_return_sentinel") + ReactiveStatement::Instruction(ReactiveInstruction { + id: EvaluationOrder(0), + lvalue: Some(Place { + identifier: sentinel_temp, + effect: Effect::Unknown, + reactive: false, + loc: None, // GeneratedSource + }), + value: ReactiveValue::Instruction(InstructionValue::MethodCall { + receiver: Place { + identifier: symbol_temp, + effect: Effect::Unknown, + reactive: false, + loc: None, // GeneratedSource + }, + property: Place { + identifier: for_temp, + effect: Effect::Unknown, + reactive: false, + loc: None, // GeneratedSource + }, + args: vec![PlaceOrSpread::Place(Place { + identifier: arg_temp, + effect: Effect::Unknown, + reactive: false, + loc: None, // GeneratedSource + })], + loc, + }), + effects: None, + loc, + }), + // StoreLocal: let earlyReturnValue = sentinel + ReactiveStatement::Instruction(ReactiveInstruction { + id: EvaluationOrder(0), + lvalue: None, + value: ReactiveValue::Instruction(InstructionValue::StoreLocal { + lvalue: LValue { + kind: InstructionKind::Let, + place: Place { + identifier: early_return.value, + effect: Effect::ConditionallyMutate, + reactive: true, + loc, + }, + }, + value: Place { + identifier: sentinel_temp, + effect: Effect::Unknown, + reactive: false, + loc: None, // GeneratedSource + }, + type_annotation: None, + loc, + }), + effects: None, + loc, + }), + // Label terminal wrapping the original instructions + ReactiveStatement::Terminal(ReactiveTerminalStatement { + label: Some(ReactiveLabel { + id: early_return.label, + implicit: false, + }), + terminal: ReactiveTerminal::Label { + block: original_instructions, + id: EvaluationOrder(0), + loc: None, // 
GeneratedSource + }, + }), + ]; +} + +// ============================================================================= +// Helper: create a temporary place identifier +// ============================================================================= + +fn create_temporary_place_id( + env: &mut Environment, + loc: Option, +) -> IdentifierId { + let id = env.next_identifier_id(); + env.identifiers[id.0 as usize].loc = loc; + id +} + +fn promote_temporary(env: &mut Environment, identifier_id: IdentifierId) { + let decl_id = env.identifiers[identifier_id.0 as usize].declaration_id; + env.identifiers[identifier_id.0 as usize].name = + Some(IdentifierName::Promoted(format!("#t{}", decl_id.0))); +} diff --git a/crates/react_compiler_reactive_scopes/src/prune_always_invalidating_scopes.rs b/crates/react_compiler_reactive_scopes/src/prune_always_invalidating_scopes.rs new file mode 100644 index 000000000000..d2431747722d --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/prune_always_invalidating_scopes.rs @@ -0,0 +1,151 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! PruneAlwaysInvalidatingScopes +//! +//! Some instructions will *always* produce a new value, and unless memoized +//! will *always* invalidate downstream reactive scopes. This pass finds such +//! values and prunes downstream memoization. +//! +//! Corresponds to `src/ReactiveScopes/PruneAlwaysInvalidatingScopes.ts`. 
+ +use std::collections::HashSet; + +use react_compiler_hir::{ + environment::Environment, IdentifierId, InstructionValue, PrunedReactiveScopeBlock, + ReactiveFunction, ReactiveInstruction, ReactiveScopeBlock, ReactiveStatement, ReactiveValue, +}; + +use crate::visitors::{transform_reactive_function, ReactiveFunctionTransform, Transformed}; + +/// Prunes scopes that always invalidate because they depend on unmemoized +/// always-invalidating values. +/// TS: `pruneAlwaysInvalidatingScopes` +pub fn prune_always_invalidating_scopes( + func: &mut ReactiveFunction, + env: &Environment, +) -> Result<(), react_compiler_diagnostics::CompilerError> { + let mut transform = Transform { + env, + always_invalidating_values: HashSet::new(), + unmemoized_values: HashSet::new(), + }; + let mut state = false; // withinScope + transform_reactive_function(func, &mut transform, &mut state) +} + +struct Transform<'a> { + env: &'a Environment, + always_invalidating_values: HashSet, + unmemoized_values: HashSet, +} + +impl<'a> ReactiveFunctionTransform for Transform<'a> { + type State = bool; + + // withinScope + + fn env(&self) -> &Environment { + self.env + } + + fn transform_instruction( + &mut self, + instruction: &mut ReactiveInstruction, + within_scope: &mut bool, + ) -> Result, react_compiler_diagnostics::CompilerError> { + self.visit_instruction(instruction, within_scope)?; + + let lvalue = &instruction.lvalue; + match &instruction.value { + ReactiveValue::Instruction( + InstructionValue::ArrayExpression { .. } + | InstructionValue::ObjectExpression { .. } + | InstructionValue::JsxExpression { .. } + | InstructionValue::JsxFragment { .. } + | InstructionValue::NewExpression { .. }, + ) => { + if let Some(lv) = lvalue { + self.always_invalidating_values.insert(lv.identifier); + if !*within_scope { + self.unmemoized_values.insert(lv.identifier); + } + } + } + ReactiveValue::Instruction(InstructionValue::StoreLocal { + value: store_value, + lvalue: store_lvalue, + .. 
+ }) => { + if self + .always_invalidating_values + .contains(&store_value.identifier) + { + self.always_invalidating_values + .insert(store_lvalue.place.identifier); + } + if self.unmemoized_values.contains(&store_value.identifier) { + self.unmemoized_values.insert(store_lvalue.place.identifier); + } + } + ReactiveValue::Instruction(InstructionValue::LoadLocal { place, .. }) => { + if let Some(lv) = lvalue { + if self.always_invalidating_values.contains(&place.identifier) { + self.always_invalidating_values.insert(lv.identifier); + } + if self.unmemoized_values.contains(&place.identifier) { + self.unmemoized_values.insert(lv.identifier); + } + } + } + _ => {} + } + Ok(Transformed::Keep) + } + + fn transform_scope( + &mut self, + scope: &mut ReactiveScopeBlock, + _within_scope: &mut bool, + ) -> Result, react_compiler_diagnostics::CompilerError> { + let mut within_scope = true; + self.visit_scope(scope, &mut within_scope)?; + + let scope_id = scope.scope; + let scope_data = &self.env.scopes[scope_id.0 as usize]; + + for dep in &scope_data.dependencies { + if self.unmemoized_values.contains(&dep.identifier) { + // This scope depends on an always-invalidating value, prune it + // Propagate always-invalidating and unmemoized to declarations/reassignments + let decl_ids: Vec = scope_data + .declarations + .iter() + .map(|(_, decl)| decl.identifier) + .collect(); + let reassign_ids: Vec = scope_data.reassignments.clone(); + + for id in &decl_ids { + if self.always_invalidating_values.contains(id) { + self.unmemoized_values.insert(*id); + } + } + for id in &reassign_ids { + if self.always_invalidating_values.contains(id) { + self.unmemoized_values.insert(*id); + } + } + + return Ok(Transformed::Replace(ReactiveStatement::PrunedScope( + PrunedReactiveScopeBlock { + scope: scope.scope, + instructions: std::mem::take(&mut scope.instructions), + }, + ))); + } + } + Ok(Transformed::Keep) + } +} diff --git a/crates/react_compiler_reactive_scopes/src/prune_hoisted_contexts.rs 
b/crates/react_compiler_reactive_scopes/src/prune_hoisted_contexts.rs new file mode 100644 index 000000000000..fadec2049f10 --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/prune_hoisted_contexts.rs @@ -0,0 +1,203 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! PruneHoistedContexts — removes hoisted context variable declarations +//! and transforms references to their original instruction kinds. +//! +//! Corresponds to `src/ReactiveScopes/PruneHoistedContexts.ts`. + +use std::collections::HashMap; + +use react_compiler_diagnostics::{CompilerError, CompilerErrorDetail, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, EvaluationOrder, IdentifierId, InstructionKind, InstructionValue, + Place, ReactiveFunction, ReactiveInstruction, ReactiveScopeBlock, ReactiveStatement, + ReactiveValue, +}; + +use crate::visitors::{transform_reactive_function, ReactiveFunctionTransform, Transformed}; + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Prunes DeclareContexts lowered for HoistedConsts and transforms any +/// references back to their original instruction kind. 
+/// TS: `pruneHoistedContexts` +pub fn prune_hoisted_contexts( + func: &mut ReactiveFunction, + env: &Environment, +) -> Result<(), CompilerError> { + let mut transform = Transform { env }; + let mut state = VisitorState { + active_scopes: Vec::new(), + uninitialized: HashMap::new(), + }; + transform_reactive_function(func, &mut transform, &mut state) +} + +// ============================================================================= +// State +// ============================================================================= + +#[derive(Debug, Clone)] +enum UninitializedKind { + UnknownKind, + Func { definition: Option }, +} + +struct VisitorState { + active_scopes: Vec>, + uninitialized: HashMap, +} + +impl VisitorState { + fn find_in_active_scopes(&self, id: IdentifierId) -> bool { + for scope in &self.active_scopes { + if scope.contains(&id) { + return true; + } + } + false + } +} + +struct Transform<'a> { + env: &'a Environment, +} + +impl<'a> ReactiveFunctionTransform for Transform<'a> { + type State = VisitorState; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_scope( + &mut self, + scope: &mut ReactiveScopeBlock, + state: &mut VisitorState, + ) -> Result<(), CompilerError> { + let scope_data = &self.env.scopes[scope.scope.0 as usize]; + let decl_ids: std::collections::HashSet = + scope_data.declarations.iter().map(|(id, _)| *id).collect(); + + // Add declared but not initialized variables + for (_, decl) in &scope_data.declarations { + state + .uninitialized + .insert(decl.identifier, UninitializedKind::UnknownKind); + } + + state.active_scopes.push(decl_ids); + self.traverse_scope(scope, state)?; + state.active_scopes.pop(); + + // Clean up uninitialized after scope + let scope_data = &self.env.scopes[scope.scope.0 as usize]; + for (_, decl) in &scope_data.declarations { + state.uninitialized.remove(&decl.identifier); + } + Ok(()) + } + + fn visit_place( + &mut self, + _id: EvaluationOrder, + place: &Place, + state: &mut VisitorState, + 
) -> Result<(), CompilerError> { + if let Some(kind) = state.uninitialized.get(&place.identifier) { + if let UninitializedKind::Func { definition } = kind { + if *definition != Some(place.identifier) { + let mut err = CompilerError::new(); + err.push_error_detail( + CompilerErrorDetail::new( + ErrorCategory::Todo, + "[PruneHoistedContexts] Rewrite hoisted function references" + .to_string(), + ) + .with_loc(place.loc), + ); + return Err(err); + } + } + } + Ok(()) + } + + fn transform_instruction( + &mut self, + instruction: &mut ReactiveInstruction, + state: &mut VisitorState, + ) -> Result, CompilerError> { + // Remove hoisted declarations to preserve TDZ + if let ReactiveValue::Instruction(InstructionValue::DeclareContext { lvalue, .. }) = + &instruction.value + { + let maybe_non_hoisted = convert_hoisted_lvalue_kind(lvalue.kind); + if let Some(non_hoisted) = maybe_non_hoisted { + if non_hoisted == InstructionKind::Function + && state.uninitialized.contains_key(&lvalue.place.identifier) + { + state.uninitialized.insert( + lvalue.place.identifier, + UninitializedKind::Func { definition: None }, + ); + } + return Ok(Transformed::Remove); + } + } + + if let ReactiveValue::Instruction(InstructionValue::StoreContext { lvalue, .. }) = + &mut instruction.value + { + if lvalue.kind != InstructionKind::Reassign { + let lvalue_id = lvalue.place.identifier; + let is_declared_by_scope = state.find_in_active_scopes(lvalue_id); + if is_declared_by_scope { + if lvalue.kind == InstructionKind::Let || lvalue.kind == InstructionKind::Const + { + lvalue.kind = InstructionKind::Reassign; + } else if lvalue.kind == InstructionKind::Function { + if let Some(kind) = state.uninitialized.get(&lvalue_id) { + assert!( + matches!(kind, UninitializedKind::Func { .. }), + "[PruneHoistedContexts] Unexpected hoisted function" + ); + // References to hoisted functions are now "safe" as + // variable assignments have finished. 
+ state.uninitialized.remove(&lvalue_id); + } + } else { + let mut err = CompilerError::new(); + err.push_error_detail( + CompilerErrorDetail::new( + ErrorCategory::Todo, + "[PruneHoistedContexts] Unexpected kind".to_string(), + ) + .with_loc(instruction.loc), + ); + return Err(err); + } + } + } + } + + self.visit_instruction(instruction, state)?; + Ok(Transformed::Keep) + } +} + +/// Corresponds to TS `convertHoistedLValueKind` — returns None for non-hoisted +/// kinds. +fn convert_hoisted_lvalue_kind(kind: InstructionKind) -> Option { + match kind { + InstructionKind::HoistedLet => Some(InstructionKind::Let), + InstructionKind::HoistedConst => Some(InstructionKind::Const), + InstructionKind::HoistedFunction => Some(InstructionKind::Function), + _ => None, + } +} diff --git a/crates/react_compiler_reactive_scopes/src/prune_non_escaping_scopes.rs b/crates/react_compiler_reactive_scopes/src/prune_non_escaping_scopes.rs new file mode 100644 index 000000000000..5481e81cfbcb --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/prune_non_escaping_scopes.rs @@ -0,0 +1,1315 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! PruneNonEscapingScopes — prunes reactive scopes that are not necessary +//! to bound downstream computation. +//! +//! Corresponds to `src/ReactiveScopes/PruneNonEscapingScopes.ts`. 
+ +use std::collections::{HashMap, HashSet}; + +use react_compiler_hir::{ + environment::Environment, visitors::each_instruction_value_operand, ArrayPatternElement, + DeclarationId, Effect, EvaluationOrder, IdentifierId, InstructionKind, InstructionValue, + JsxAttribute, JsxTag, ObjectPropertyOrSpread, Pattern, Place, PlaceOrSpread, ReactiveFunction, + ReactiveInstruction, ReactiveScopeBlock, ReactiveStatement, ReactiveTerminal, + ReactiveTerminalStatement, ReactiveValue, ScopeId, +}; + +use crate::visitors::{ + transform_reactive_function, visit_reactive_function, ReactiveFunctionTransform, + ReactiveFunctionVisitor, Transformed, +}; + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Prunes reactive scopes whose outputs don't escape. +/// TS: `pruneNonEscapingScopes` +pub fn prune_non_escaping_scopes( + func: &mut ReactiveFunction, + env: &mut Environment, +) -> Result<(), react_compiler_diagnostics::CompilerError> { + // First build up a map of which instructions are involved in creating which + // values, and which values are returned. + let mut state = CollectState::new(); + for param in &func.params { + let place = match param { + react_compiler_hir::ParamPattern::Place(p) => p, + react_compiler_hir::ParamPattern::Spread(s) => &s.place, + }; + let identifier = &env.identifiers[place.identifier.0 as usize]; + state.declare(identifier.declaration_id); + } + let visitor = CollectDependenciesVisitor::new(env); + let mut visitor_state = (state, Vec::::new()); + visit_reactive_function(func, &visitor, &mut visitor_state); + let (state, _) = visitor_state; + + // Then walk outward from the returned values and find all captured operands. 
+ let memoized = compute_memoized_identifiers(&state); + + // Prune scopes that do not declare/reassign any escaping values + let mut transform = PruneScopesTransform { + env, + pruned_scopes: HashSet::new(), + reassignments: HashMap::new(), + }; + let mut memoized_state = memoized; + transform_reactive_function(func, &mut transform, &mut memoized_state) +} + +// ============================================================================= +// MemoizationLevel +// ============================================================================= + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum MemoizationLevel { + /// The value should be memoized if it escapes + Memoized, + /// Values that are memoized if their dependencies are memoized + Conditional, + /// Values that cannot be compared with Object.is, but which by default + /// don't need to be memoized + Unmemoized, + /// The value will never be memoized: used for values that can be cheaply + /// compared w Object.is + Never, +} + +/// Given an identifier that appears as an lvalue multiple times with different +/// memoization levels, determines the final memoization level. +fn join_aliases(kind1: MemoizationLevel, kind2: MemoizationLevel) -> MemoizationLevel { + if kind1 == MemoizationLevel::Memoized || kind2 == MemoizationLevel::Memoized { + MemoizationLevel::Memoized + } else if kind1 == MemoizationLevel::Conditional || kind2 == MemoizationLevel::Conditional { + MemoizationLevel::Conditional + } else if kind1 == MemoizationLevel::Unmemoized || kind2 == MemoizationLevel::Unmemoized { + MemoizationLevel::Unmemoized + } else { + MemoizationLevel::Never + } +} + +// ============================================================================= +// Graph nodes +// ============================================================================= + +/// A node in the graph describing the memoization level of a given identifier +/// as well as its dependencies and scopes. 
+struct IdentifierNode { + level: MemoizationLevel, + memoized: bool, + dependencies: HashSet, + scopes: HashSet, + seen: bool, +} + +/// A scope node describing its dependencies. +struct ScopeNode { + dependencies: Vec, + seen: bool, +} + +// ============================================================================= +// CollectState (TS: State class) +// ============================================================================= + +struct CollectState { + /// Maps lvalues for LoadLocal to the identifier being loaded, to resolve + /// indirections. + definitions: HashMap, + identifiers: HashMap, + scopes: HashMap, + escaping_values: HashSet, +} + +impl CollectState { + fn new() -> Self { + CollectState { + definitions: HashMap::new(), + identifiers: HashMap::new(), + scopes: HashMap::new(), + escaping_values: HashSet::new(), + } + } + + /// Declare a new identifier, used for function id and params. + fn declare(&mut self, id: DeclarationId) { + self.identifiers.insert( + id, + IdentifierNode { + level: MemoizationLevel::Never, + memoized: false, + dependencies: HashSet::new(), + scopes: HashSet::new(), + seen: false, + }, + ); + } + + /// Associates the identifier with its scope, if there is one and it is + /// active for the given instruction id. 
+ fn visit_operand( + &mut self, + env: &Environment, + id: EvaluationOrder, + place: &Place, + identifier: DeclarationId, + ) { + if let Some(scope_id) = get_place_scope(env, id, place.identifier) { + let node = self.scopes.entry(scope_id).or_insert_with(|| { + let scope_data = &env.scopes[scope_id.0 as usize]; + let dependencies = scope_data + .dependencies + .iter() + .map(|dep| env.identifiers[dep.identifier.0 as usize].declaration_id) + .collect(); + ScopeNode { + dependencies, + seen: false, + } + }); + // Avoid unused variable warning — we needed the entry to exist + let _ = node; + let identifier_node = self + .identifiers + .get_mut(&identifier) + .expect("Expected identifier to be initialized"); + identifier_node.scopes.insert(scope_id); + } + } + + /// Resolve an identifier through definitions (LoadLocal indirections). + fn resolve(&self, id: DeclarationId) -> DeclarationId { + self.definitions.get(&id).copied().unwrap_or(id) + } +} + +// ============================================================================= +// MemoizationOptions +// ============================================================================= + +struct MemoizationOptions { + memoize_jsx_elements: bool, + force_memoize_primitives: bool, +} + +// ============================================================================= +// LValueMemoization +// ============================================================================= + +struct LValueMemoization { + place_identifier: IdentifierId, + level: MemoizationLevel, +} + +// ============================================================================= +// Helper: get_place_scope +// ============================================================================= + +fn get_place_scope( + env: &Environment, + id: EvaluationOrder, + identifier_id: IdentifierId, +) -> Option { + let scope_id = env.identifiers[identifier_id.0 as usize].scope?; + if env.scopes[scope_id.0 as usize].range.contains(id) { + Some(scope_id) + } else { + None + } 
+} + +// ============================================================================= +// Helper: get_function_call_signature (for noAlias check) +// ============================================================================= + +// ============================================================================= +// Helper: compute pattern lvalues +// ============================================================================= + +fn compute_pattern_lvalues(pattern: &Pattern) -> Vec { + let mut lvalues = Vec::new(); + match pattern { + Pattern::Array(array_pattern) => { + for item in &array_pattern.items { + match item { + ArrayPatternElement::Place(place) => { + lvalues.push(LValueMemoization { + place_identifier: place.identifier, + level: MemoizationLevel::Conditional, + }); + } + ArrayPatternElement::Spread(spread) => { + lvalues.push(LValueMemoization { + place_identifier: spread.place.identifier, + level: MemoizationLevel::Memoized, + }); + } + ArrayPatternElement::Hole => {} + } + } + } + Pattern::Object(object_pattern) => { + for property in &object_pattern.properties { + match property { + ObjectPropertyOrSpread::Property(prop) => { + lvalues.push(LValueMemoization { + place_identifier: prop.place.identifier, + level: MemoizationLevel::Conditional, + }); + } + ObjectPropertyOrSpread::Spread(spread) => { + lvalues.push(LValueMemoization { + place_identifier: spread.place.identifier, + level: MemoizationLevel::Memoized, + }); + } + } + } + } + } + lvalues +} + +// ============================================================================= +// CollectDependenciesVisitor +// ============================================================================= + +struct CollectDependenciesVisitor<'a> { + env: &'a Environment, + options: MemoizationOptions, +} + +impl<'a> CollectDependenciesVisitor<'a> { + fn new(env: &'a Environment) -> Self { + CollectDependenciesVisitor { + env, + options: MemoizationOptions { + memoize_jsx_elements: !env.config.enable_forest, + 
force_memoize_primitives: env.config.enable_forest + || env.enable_preserve_existing_memoization_guarantees, + }, + } + } + + /// Given a value, returns a description of how it should be memoized. + fn compute_memoization_inputs( + &self, + id: EvaluationOrder, + value: &ReactiveValue, + lvalue: Option, + state: &mut CollectState, + ) -> (Vec, Vec<(IdentifierId, EvaluationOrder)>) { + match value { + ReactiveValue::ConditionalExpression { + consequent, + alternate, + .. + } => { + let (_, cons_rvalues) = + self.compute_memoization_inputs(id, consequent, None, state); + let (_, alt_rvalues) = self.compute_memoization_inputs(id, alternate, None, state); + let mut rvalues = cons_rvalues; + rvalues.extend(alt_rvalues); + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }] + } else { + vec![] + }; + (lvalues, rvalues) + } + ReactiveValue::LogicalExpression { left, right, .. } => { + let (_, left_rvalues) = self.compute_memoization_inputs(id, left, None, state); + let (_, right_rvalues) = self.compute_memoization_inputs(id, right, None, state); + let mut rvalues = left_rvalues; + rvalues.extend(right_rvalues); + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }] + } else { + vec![] + }; + (lvalues, rvalues) + } + ReactiveValue::SequenceExpression { + instructions, + value: inner, + .. + } => { + for instr in instructions { + self.visit_value_for_memoization( + instr.id, + &instr.value, + instr.lvalue.as_ref().map(|lv| lv.identifier), + state, + ); + } + let (_, rvalues) = self.compute_memoization_inputs(id, inner, None, state); + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }] + } else { + vec![] + }; + (lvalues, rvalues) + } + ReactiveValue::OptionalExpression { value: inner, .. 
} => { + let (_, rvalues) = self.compute_memoization_inputs(id, inner, None, state); + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }] + } else { + vec![] + }; + (lvalues, rvalues) + } + ReactiveValue::Instruction(instr_value) => { + self.compute_instruction_memoization_inputs(id, instr_value, lvalue) + } + } + } + + /// Compute memoization inputs for an InstructionValue. + fn compute_instruction_memoization_inputs( + &self, + id: EvaluationOrder, + value: &InstructionValue, + lvalue: Option, + ) -> (Vec, Vec<(IdentifierId, EvaluationOrder)>) { + let env = self.env; + let options = &self.options; + + match value { + InstructionValue::JsxExpression { + tag, + props, + children, + .. + } => { + let mut rvalues: Vec<(IdentifierId, EvaluationOrder)> = Vec::new(); + if let JsxTag::Place(place) = tag { + rvalues.push((place.identifier, id)); + } + for prop in props { + match prop { + JsxAttribute::Attribute { place, .. } => { + rvalues.push((place.identifier, id)); + } + JsxAttribute::SpreadAttribute { argument, .. } => { + rvalues.push((argument.identifier, id)); + } + } + } + if let Some(children) = children { + for child in children { + rvalues.push((child.identifier, id)); + } + } + let level = if options.memoize_jsx_elements { + MemoizationLevel::Memoized + } else { + MemoizationLevel::Unmemoized + }; + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level, + }] + } else { + vec![] + }; + (lvalues, rvalues) + } + InstructionValue::JsxFragment { children, .. 
} => { + let level = if options.memoize_jsx_elements { + MemoizationLevel::Memoized + } else { + MemoizationLevel::Unmemoized + }; + let rvalues: Vec<(IdentifierId, EvaluationOrder)> = + children.iter().map(|c| (c.identifier, id)).collect(); + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level, + }] + } else { + vec![] + }; + (lvalues, rvalues) + } + InstructionValue::NextPropertyOf { .. } + | InstructionValue::StartMemoize { .. } + | InstructionValue::FinishMemoize { .. } + | InstructionValue::Debugger { .. } + | InstructionValue::ComputedDelete { .. } + | InstructionValue::PropertyDelete { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::TemplateLiteral { .. } + | InstructionValue::Primitive { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::BinaryExpression { .. } + | InstructionValue::UnaryExpression { .. } => { + if options.force_memoize_primitives { + let level = MemoizationLevel::Conditional; + let operands = each_instruction_value_operand(value, env); + let rvalues: Vec<(IdentifierId, EvaluationOrder)> = + operands.iter().map(|p| (p.identifier, id)).collect(); + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level, + }] + } else { + vec![] + }; + (lvalues, rvalues) + } else { + let level = MemoizationLevel::Never; + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level, + }] + } else { + vec![] + }; + (lvalues, vec![]) + } + } + InstructionValue::Await { value: inner, .. } + | InstructionValue::TypeCastExpression { value: inner, .. } => { + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }] + } else { + vec![] + }; + (lvalues, vec![(inner.identifier, id)]) + } + InstructionValue::IteratorNext { + iterator, + collection, + .. 
+ } => { + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }] + } else { + vec![] + }; + ( + lvalues, + vec![(iterator.identifier, id), (collection.identifier, id)], + ) + } + InstructionValue::GetIterator { collection, .. } => { + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }] + } else { + vec![] + }; + (lvalues, vec![(collection.identifier, id)]) + } + InstructionValue::LoadLocal { place, .. } => { + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }] + } else { + vec![] + }; + (lvalues, vec![(place.identifier, id)]) + } + InstructionValue::LoadContext { place, .. } => { + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }] + } else { + vec![] + }; + (lvalues, vec![(place.identifier, id)]) + } + InstructionValue::DeclareContext { + lvalue: decl_lvalue, + .. + } => { + let mut lvalues = vec![LValueMemoization { + place_identifier: decl_lvalue.place.identifier, + level: MemoizationLevel::Memoized, + }]; + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Unmemoized, + }); + } + (lvalues, vec![]) + } + InstructionValue::DeclareLocal { + lvalue: decl_lvalue, + .. + } => { + let mut lvalues = vec![LValueMemoization { + place_identifier: decl_lvalue.place.identifier, + level: MemoizationLevel::Unmemoized, + }]; + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Unmemoized, + }); + } + (lvalues, vec![]) + } + InstructionValue::PrefixUpdate { + lvalue: upd_lvalue, + value: upd_value, + .. + } + | InstructionValue::PostfixUpdate { + lvalue: upd_lvalue, + value: upd_value, + .. 
+ } => { + let mut lvalues = vec![LValueMemoization { + place_identifier: upd_lvalue.identifier, + level: MemoizationLevel::Conditional, + }]; + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }); + } + (lvalues, vec![(upd_value.identifier, id)]) + } + InstructionValue::StoreLocal { + lvalue: store_lvalue, + value: store_value, + .. + } => { + let mut lvalues = vec![LValueMemoization { + place_identifier: store_lvalue.place.identifier, + level: MemoizationLevel::Conditional, + }]; + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }); + } + (lvalues, vec![(store_value.identifier, id)]) + } + InstructionValue::StoreContext { + lvalue: store_lvalue, + value: store_value, + .. + } => { + let mut lvalues = vec![LValueMemoization { + place_identifier: store_lvalue.place.identifier, + level: MemoizationLevel::Memoized, + }]; + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }); + } + (lvalues, vec![(store_value.identifier, id)]) + } + InstructionValue::StoreGlobal { + value: store_value, .. + } => { + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Unmemoized, + }] + } else { + vec![] + }; + (lvalues, vec![(store_value.identifier, id)]) + } + InstructionValue::Destructure { + lvalue: dest_lvalue, + value: dest_value, + .. + } => { + let mut lvalues = Vec::new(); + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }); + } + lvalues.extend(compute_pattern_lvalues(&dest_lvalue.pattern)); + (lvalues, vec![(dest_value.identifier, id)]) + } + InstructionValue::ComputedLoad { object, .. } + | InstructionValue::PropertyLoad { object, .. 
} => { + let level = MemoizationLevel::Conditional; + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level, + }] + } else { + vec![] + }; + (lvalues, vec![(object.identifier, id)]) + } + InstructionValue::ComputedStore { + object, + value: store_value, + .. + } => { + let mut lvalues = vec![LValueMemoization { + place_identifier: object.identifier, + level: MemoizationLevel::Conditional, + }]; + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Conditional, + }); + } + (lvalues, vec![(store_value.identifier, id)]) + } + InstructionValue::TaggedTemplateExpression { tag, .. } => { + let no_alias = env.has_no_alias_signature(tag.identifier); + let mut lvalues = Vec::new(); + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Memoized, + }); + } + if no_alias { + return (lvalues, vec![]); + } + let operands = each_instruction_value_operand(value, env); + for op in &operands { + if op.effect.is_mutable() { + lvalues.push(LValueMemoization { + place_identifier: op.identifier, + level: MemoizationLevel::Memoized, + }); + } + } + let rvalues: Vec<(IdentifierId, EvaluationOrder)> = + operands.iter().map(|p| (p.identifier, id)).collect(); + (lvalues, rvalues) + } + InstructionValue::CallExpression { callee, .. 
} => { + let no_alias = env.has_no_alias_signature(callee.identifier); + let mut lvalues = Vec::new(); + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Memoized, + }); + } + if no_alias { + return (lvalues, vec![]); + } + let operands = each_instruction_value_operand(value, env); + for op in &operands { + if op.effect.is_mutable() { + lvalues.push(LValueMemoization { + place_identifier: op.identifier, + level: MemoizationLevel::Memoized, + }); + } + } + let rvalues: Vec<(IdentifierId, EvaluationOrder)> = + operands.iter().map(|p| (p.identifier, id)).collect(); + (lvalues, rvalues) + } + InstructionValue::MethodCall { property, .. } => { + let no_alias = env.has_no_alias_signature(property.identifier); + let mut lvalues = Vec::new(); + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Memoized, + }); + } + if no_alias { + return (lvalues, vec![]); + } + let operands = each_instruction_value_operand(value, env); + for op in &operands { + if op.effect.is_mutable() { + lvalues.push(LValueMemoization { + place_identifier: op.identifier, + level: MemoizationLevel::Memoized, + }); + } + } + let rvalues: Vec<(IdentifierId, EvaluationOrder)> = + operands.iter().map(|p| (p.identifier, id)).collect(); + (lvalues, rvalues) + } + InstructionValue::RegExpLiteral { .. } + | InstructionValue::ArrayExpression { .. } + | InstructionValue::NewExpression { .. } + | InstructionValue::ObjectExpression { .. } + | InstructionValue::PropertyStore { .. 
} => { + let operands = each_instruction_value_operand(value, env); + let mut lvalues: Vec = operands + .iter() + .filter(|op| op.effect.is_mutable()) + .map(|op| LValueMemoization { + place_identifier: op.identifier, + level: MemoizationLevel::Memoized, + }) + .collect(); + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Memoized, + }); + } + let rvalues: Vec<(IdentifierId, EvaluationOrder)> = + operands.iter().map(|p| (p.identifier, id)).collect(); + (lvalues, rvalues) + } + InstructionValue::ObjectMethod { .. } | InstructionValue::FunctionExpression { .. } => { + // The canonical each_instruction_value_operand already includes context + // (captured variables) for FunctionExpression/ObjectMethod. + let operands = each_instruction_value_operand(value, env); + let mut lvalues: Vec = operands + .iter() + .filter(|op| op.effect.is_mutable()) + .map(|op| LValueMemoization { + place_identifier: op.identifier, + level: MemoizationLevel::Memoized, + }) + .collect(); + if let Some(lv) = lvalue { + lvalues.push(LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Memoized, + }); + } + let rvalues: Vec<(IdentifierId, EvaluationOrder)> = + operands.iter().map(|p| (p.identifier, id)).collect(); + (lvalues, rvalues) + } + InstructionValue::UnsupportedNode { .. 
} => { + let lvalues = if let Some(lv) = lvalue { + vec![LValueMemoization { + place_identifier: lv, + level: MemoizationLevel::Never, + }] + } else { + vec![] + }; + (lvalues, vec![]) + } + } + } + + fn visit_value_for_memoization( + &self, + id: EvaluationOrder, + value: &ReactiveValue, + lvalue: Option, + state: &mut CollectState, + ) { + let env = self.env; + // Determine the level of memoization for this value and the lvalues/rvalues + let (aliasing_lvalues, aliasing_rvalues) = + self.compute_memoization_inputs(id, value, lvalue, state); + + // Associate all the rvalues with the instruction's scope if it has one + // We need to collect rvalue data first to avoid borrow issues + let rvalue_data: Vec<(IdentifierId, DeclarationId)> = aliasing_rvalues + .iter() + .map(|(identifier_id, _)| { + let decl_id = env.identifiers[identifier_id.0 as usize].declaration_id; + let operand_id = state.resolve(decl_id); + (*identifier_id, operand_id) + }) + .collect(); + + for (identifier_id, operand_id) in &rvalue_data { + // Build the Place data needed for get_place_scope + state.visit_operand( + env, + id, + &Place { + identifier: *identifier_id, + effect: Effect::Read, + reactive: false, + loc: None, + }, + *operand_id, + ); + } + + // Add the operands as dependencies of all lvalues + for lv in &aliasing_lvalues { + let lvalue_decl_id = env.identifiers[lv.place_identifier.0 as usize].declaration_id; + let lvalue_id = state.resolve(lvalue_decl_id); + let node = state + .identifiers + .entry(lvalue_id) + .or_insert_with(|| IdentifierNode { + level: MemoizationLevel::Never, + memoized: false, + dependencies: HashSet::new(), + scopes: HashSet::new(), + seen: false, + }); + node.level = join_aliases(node.level, lv.level); + for (_, operand_id) in &rvalue_data { + if *operand_id == lvalue_id { + continue; + } + node.dependencies.insert(*operand_id); + } + + state.visit_operand( + env, + id, + &Place { + identifier: lv.place_identifier, + effect: Effect::Read, + reactive: false, + 
loc: None, + }, + lvalue_id, + ); + } + + // Handle LoadLocal definitions and hook calls + if let ReactiveValue::Instruction(instr_value) = value { + if let InstructionValue::LoadLocal { place, .. } = instr_value { + if let Some(lv_id) = lvalue { + let lv_decl = env.identifiers[lv_id.0 as usize].declaration_id; + let place_decl = env.identifiers[place.identifier.0 as usize].declaration_id; + state.definitions.insert(lv_decl, place_decl); + } + } else if let InstructionValue::CallExpression { callee, args, .. } = instr_value { + if env + .get_hook_kind_for_id(callee.identifier) + .ok() + .flatten() + .is_some() + { + let no_alias = env.has_no_alias_signature(callee.identifier); + if !no_alias { + for arg in args { + let place = match arg { + PlaceOrSpread::Spread(spread) => &spread.place, + PlaceOrSpread::Place(place) => place, + }; + let decl = env.identifiers[place.identifier.0 as usize].declaration_id; + state.escaping_values.insert(decl); + } + } + } + } else if let InstructionValue::MethodCall { property, args, .. 
} = instr_value { + if env + .get_hook_kind_for_id(property.identifier) + .ok() + .flatten() + .is_some() + { + let no_alias = env.has_no_alias_signature(property.identifier); + if !no_alias { + for arg in args { + let place = match arg { + PlaceOrSpread::Spread(spread) => &spread.place, + PlaceOrSpread::Place(place) => place, + }; + let decl = env.identifiers[place.identifier.0 as usize].declaration_id; + state.escaping_values.insert(decl); + } + } + } + } + } + } +} + +// ============================================================================= +// ReactiveFunctionVisitor impl for CollectDependenciesVisitor +// ============================================================================= + +impl<'a> ReactiveFunctionVisitor for CollectDependenciesVisitor<'a> { + type State = (CollectState, Vec); + + fn env(&self) -> &Environment { + self.env + } + + fn visit_instruction(&self, instruction: &ReactiveInstruction, state: &mut Self::State) { + self.visit_value_for_memoization( + instruction.id, + &instruction.value, + instruction.lvalue.as_ref().map(|lv| lv.identifier), + &mut state.0, + ); + } + + fn visit_terminal(&self, stmt: &ReactiveTerminalStatement, state: &mut Self::State) { + // Traverse terminal blocks first (TS: this.traverseTerminal(stmt, scopes)) + self.traverse_terminal(stmt, state); + + // Handle return terminals + if let ReactiveTerminal::Return { value, .. 
} = &stmt.terminal { + let env = self.env; + let decl = env.identifiers[value.identifier.0 as usize].declaration_id; + state.0.escaping_values.insert(decl); + + // If the return is within a scope, associate those scopes with the returned + // value + let identifier_node = state + .0 + .identifiers + .get_mut(&decl) + .expect("Expected identifier to be initialized"); + for scope_id in &state.1 { + identifier_node.scopes.insert(*scope_id); + } + } + } + + fn visit_scope(&self, scope: &ReactiveScopeBlock, state: &mut Self::State) { + let env = self.env; + let scope_id = scope.scope; + let scope_data = &env.scopes[scope_id.0 as usize]; + + // If a scope reassigns any variables, set the chain of active scopes as a + // dependency of those variables. + for reassignment_id in &scope_data.reassignments { + let decl = env.identifiers[reassignment_id.0 as usize].declaration_id; + let identifier_node = state + .0 + .identifiers + .get_mut(&decl) + .expect("Expected identifier to be initialized"); + for s in &state.1 { + identifier_node.scopes.insert(*s); + } + identifier_node.scopes.insert(scope_id); + } + + // TS: this.traverseScope(scope, [...scopes, scope.scope]) + state.1.push(scope_id); + self.traverse_scope(scope, state); + state.1.pop(); + } +} + +// ============================================================================= +// computeMemoizedIdentifiers +// ============================================================================= + +fn compute_memoized_identifiers(state: &CollectState) -> HashSet { + let mut memoized = HashSet::new(); + + // We need mutable access to the nodes, so we clone the state into mutable + // structures + let mut identifier_nodes: HashMap< + DeclarationId, + ( + MemoizationLevel, + bool, + HashSet, + HashSet, + bool, + ), + > = state + .identifiers + .iter() + .map(|(id, node)| { + ( + *id, + ( + node.level, + node.memoized, + node.dependencies.clone(), + node.scopes.clone(), + node.seen, + ), + ) + }) + .collect(); + + let mut 
scope_nodes: HashMap, bool)> = state + .scopes + .iter() + .map(|(id, node)| (*id, (node.dependencies.clone(), node.seen))) + .collect(); + + fn visit( + id: DeclarationId, + force_memoize: bool, + identifier_nodes: &mut HashMap< + DeclarationId, + ( + MemoizationLevel, + bool, + HashSet, + HashSet, + bool, + ), + >, + scope_nodes: &mut HashMap, bool)>, + memoized: &mut HashSet, + ) -> bool { + let (level, _, _, _, seen) = *identifier_nodes + .get(&id) + .expect("Expected a node for all identifiers"); + if seen { + return identifier_nodes.get(&id).unwrap().1; + } + + // Mark as seen, temporarily mark as non-memoized + identifier_nodes.get_mut(&id).unwrap().4 = true; // seen = true + identifier_nodes.get_mut(&id).unwrap().1 = false; // memoized = false + + // Visit dependencies + let deps: Vec = identifier_nodes + .get(&id) + .unwrap() + .2 + .iter() + .copied() + .collect(); + let mut has_memoized_dependency = false; + for dep in deps { + let is_dep_memoized = visit(dep, false, identifier_nodes, scope_nodes, memoized); + has_memoized_dependency |= is_dep_memoized; + } + + if level == MemoizationLevel::Memoized + || (level == MemoizationLevel::Conditional + && (has_memoized_dependency || force_memoize)) + || (level == MemoizationLevel::Unmemoized && force_memoize) + { + identifier_nodes.get_mut(&id).unwrap().1 = true; // memoized = true + memoized.insert(id); + let scopes: Vec = identifier_nodes + .get(&id) + .unwrap() + .3 + .iter() + .copied() + .collect(); + for scope_id in scopes { + force_memoize_scope_dependencies(scope_id, identifier_nodes, scope_nodes, memoized); + } + } + identifier_nodes.get(&id).unwrap().1 + } + + fn force_memoize_scope_dependencies( + id: ScopeId, + identifier_nodes: &mut HashMap< + DeclarationId, + ( + MemoizationLevel, + bool, + HashSet, + HashSet, + bool, + ), + >, + scope_nodes: &mut HashMap, bool)>, + memoized: &mut HashSet, + ) { + let seen = scope_nodes + .get(&id) + .expect("Expected a node for all scopes") + .1; + if seen { + 
return; + } + scope_nodes.get_mut(&id).unwrap().1 = true; // seen = true + + let deps: Vec = scope_nodes.get(&id).unwrap().0.clone(); + for dep in deps { + visit(dep, true, identifier_nodes, scope_nodes, memoized); + } + } + + // Walk from the "roots" aka returned/escaping identifiers + let escaping: Vec = state.escaping_values.iter().copied().collect(); + for value in escaping { + visit( + value, + false, + &mut identifier_nodes, + &mut scope_nodes, + &mut memoized, + ); + } + + memoized +} + +// ============================================================================= +// PruneScopesTransform +// ============================================================================= + +struct PruneScopesTransform<'a> { + env: &'a Environment, + pruned_scopes: HashSet, + reassignments: HashMap>, +} + +impl<'a> ReactiveFunctionTransform for PruneScopesTransform<'a> { + type State = HashSet; + + fn env(&self) -> &Environment { + self.env + } + + fn transform_scope( + &mut self, + scope: &mut ReactiveScopeBlock, + state: &mut HashSet, + ) -> Result, react_compiler_diagnostics::CompilerError> { + self.visit_scope(scope, state)?; + + let scope_id = scope.scope; + let scope_data = &self.env.scopes[scope_id.0 as usize]; + + // Keep scopes that appear empty (value being memoized may be early-returned) + // or have early return values + if (scope_data.declarations.is_empty() && scope_data.reassignments.is_empty()) + || scope_data.early_return_value.is_some() + { + return Ok(Transformed::Keep); + } + + let has_memoized_output = scope_data.declarations.iter().any(|(_, decl)| { + let decl_id = self.env.identifiers[decl.identifier.0 as usize].declaration_id; + state.contains(&decl_id) + }) || scope_data.reassignments.iter().any(|reassign_id| { + let decl_id = self.env.identifiers[reassign_id.0 as usize].declaration_id; + state.contains(&decl_id) + }); + + if has_memoized_output { + Ok(Transformed::Keep) + } else { + self.pruned_scopes.insert(scope_id); + 
Ok(Transformed::ReplaceMany(std::mem::take( + &mut scope.instructions, + ))) + } + } + + fn transform_instruction( + &mut self, + instruction: &mut ReactiveInstruction, + state: &mut HashSet, + ) -> Result, react_compiler_diagnostics::CompilerError> { + self.traverse_instruction(instruction, state)?; + + match &mut instruction.value { + ReactiveValue::Instruction(InstructionValue::StoreLocal { + value: store_value, + lvalue: store_lvalue, + .. + }) if store_lvalue.kind == InstructionKind::Reassign => { + let decl_id = + self.env.identifiers[store_lvalue.place.identifier.0 as usize].declaration_id; + let ids = self + .reassignments + .entry(decl_id) + .or_insert_with(HashSet::new); + ids.insert(store_value.identifier); + } + ReactiveValue::Instruction(InstructionValue::LoadLocal { place, .. }) => { + let has_scope = self.env.identifiers[place.identifier.0 as usize] + .scope + .is_some(); + let lvalue_no_scope = instruction + .lvalue + .as_ref() + .map(|lv| { + self.env.identifiers[lv.identifier.0 as usize] + .scope + .is_none() + }) + .unwrap_or(false); + if has_scope && lvalue_no_scope { + if let Some(lv) = &instruction.lvalue { + let decl_id = self.env.identifiers[lv.identifier.0 as usize].declaration_id; + let ids = self + .reassignments + .entry(decl_id) + .or_insert_with(HashSet::new); + ids.insert(place.identifier); + } + } + } + ReactiveValue::Instruction(InstructionValue::FinishMemoize { + decl, pruned, .. + }) => { + let decl_has_scope = self.env.identifiers[decl.identifier.0 as usize] + .scope + .is_some(); + if !decl_has_scope { + // If the manual memo was a useMemo that got inlined, iterate through + // all reassignments to the iife temporary to ensure they're memoized. 
+ let decl_id = self.env.identifiers[decl.identifier.0 as usize].declaration_id; + let decls: Vec = self + .reassignments + .get(&decl_id) + .map(|ids| ids.iter().copied().collect()) + .unwrap_or_else(|| vec![decl.identifier]); + + if decls.iter().all(|d| { + let scope = self.env.identifiers[d.0 as usize].scope; + scope.is_none() || self.pruned_scopes.contains(&scope.unwrap()) + }) { + *pruned = true; + } + } else { + let scope = self.env.identifiers[decl.identifier.0 as usize].scope; + if let Some(scope_id) = scope { + if self.pruned_scopes.contains(&scope_id) { + *pruned = true; + } + } + } + } + _ => {} + } + + Ok(Transformed::Keep) + } +} diff --git a/crates/react_compiler_reactive_scopes/src/prune_non_reactive_dependencies.rs b/crates/react_compiler_reactive_scopes/src/prune_non_reactive_dependencies.rs new file mode 100644 index 000000000000..b6717ee67756 --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/prune_non_reactive_dependencies.rs @@ -0,0 +1,245 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! PruneNonReactiveDependencies + CollectReactiveIdentifiers +//! +//! Corresponds to `src/ReactiveScopes/PruneNonReactiveDependencies.ts` +//! and `src/ReactiveScopes/CollectReactiveIdentifiers.ts`. 
+ +use std::collections::HashSet; + +use react_compiler_hir::{ + environment::Environment, is_primitive_type, is_use_ref_type, object_shape, + visitors as hir_visitors, EvaluationOrder, IdentifierId, InstructionValue, Place, + PrunedReactiveScopeBlock, ReactiveFunction, ReactiveInstruction, ReactiveScopeBlock, + ReactiveValue, +}; + +use crate::visitors::{self, ReactiveFunctionTransform, ReactiveFunctionVisitor}; + +// ============================================================================= +// CollectReactiveIdentifiers +// ============================================================================= + +/// Collects identifiers that are reactive. +/// TS: `collectReactiveIdentifiers` +pub fn collect_reactive_identifiers( + func: &ReactiveFunction, + env: &Environment, +) -> HashSet { + let visitor = CollectVisitor { env }; + let mut state = HashSet::new(); + crate::visitors::visit_reactive_function(func, &visitor, &mut state); + state +} + +struct CollectVisitor<'a> { + env: &'a Environment, +} + +impl<'a> ReactiveFunctionVisitor for CollectVisitor<'a> { + type State = HashSet; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_lvalue(&self, id: EvaluationOrder, lvalue: &Place, state: &mut Self::State) { + // Visitors don't visit lvalues as places by default, but we want to visit all + // places + self.visit_place(id, lvalue, state); + } + + fn visit_place(&self, _id: EvaluationOrder, place: &Place, state: &mut Self::State) { + if place.reactive { + state.insert(place.identifier); + } + } + + fn visit_pruned_scope(&self, scope: &PrunedReactiveScopeBlock, state: &mut Self::State) { + self.traverse_pruned_scope(scope, state); + + let scope_data = &self.env.scopes[scope.scope.0 as usize]; + for (_id, decl) in &scope_data.declarations { + let identifier = &self.env.identifiers[decl.identifier.0 as usize]; + let ty = &self.env.types[identifier.type_.0 as usize]; + if !is_primitive_type(ty) && !is_stable_ref_type(ty, state, identifier.id) { + 
state.insert(*_id); + } + } + } +} + +/// TS: `isStableRefType` +fn is_stable_ref_type( + ty: &react_compiler_hir::Type, + reactive_identifiers: &HashSet, + id: IdentifierId, +) -> bool { + is_use_ref_type(ty) && !reactive_identifiers.contains(&id) +} + +// ============================================================================= +// isStableType (ported from HIR.ts) +// ============================================================================= + +/// TS: `isStableType` +fn is_stable_type(ty: &react_compiler_hir::Type) -> bool { + is_set_state_type(ty) + || is_set_action_state_type(ty) + || is_dispatcher_type(ty) + || is_use_ref_type(ty) + || is_start_transition_type(ty) + || is_set_optimistic_type(ty) +} + +fn is_set_state_type(ty: &react_compiler_hir::Type) -> bool { + matches!(ty, react_compiler_hir::Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_SET_STATE_ID) +} + +fn is_set_action_state_type(ty: &react_compiler_hir::Type) -> bool { + matches!(ty, react_compiler_hir::Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_SET_ACTION_STATE_ID) +} + +fn is_dispatcher_type(ty: &react_compiler_hir::Type) -> bool { + matches!(ty, react_compiler_hir::Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_DISPATCH_ID) +} + +fn is_start_transition_type(ty: &react_compiler_hir::Type) -> bool { + matches!(ty, react_compiler_hir::Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_START_TRANSITION_ID) +} + +fn is_set_optimistic_type(ty: &react_compiler_hir::Type) -> bool { + matches!(ty, react_compiler_hir::Type::Function { shape_id: Some(id), .. } if id == object_shape::BUILT_IN_SET_OPTIMISTIC_ID) +} + +// ============================================================================= +// PruneNonReactiveDependencies +// ============================================================================= + +/// Prunes dependencies that are guaranteed to be non-reactive. 
+/// TS: `pruneNonReactiveDependencies` +pub fn prune_non_reactive_dependencies(func: &mut ReactiveFunction, env: &mut Environment) { + let reactive_ids = collect_reactive_identifiers(func, env); + let mut visitor = PruneVisitor { env }; + let mut state = reactive_ids; + visitors::transform_reactive_function(func, &mut visitor, &mut state) + .expect("PruneNonReactiveDependencies should not fail"); +} + +struct PruneVisitor<'a> { + env: &'a mut Environment, +} + +impl<'a> ReactiveFunctionTransform for PruneVisitor<'a> { + type State = HashSet; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_instruction( + &mut self, + instruction: &mut ReactiveInstruction, + state: &mut Self::State, + ) -> Result<(), react_compiler_diagnostics::CompilerError> { + self.traverse_instruction(instruction, state)?; + + let lvalue = &instruction.lvalue; + match &instruction.value { + ReactiveValue::Instruction(InstructionValue::LoadLocal { place, .. }) => { + if let Some(lv) = lvalue { + if state.contains(&place.identifier) { + state.insert(lv.identifier); + } + } + } + ReactiveValue::Instruction(InstructionValue::StoreLocal { + value: store_value, + lvalue: store_lvalue, + .. + }) => { + if state.contains(&store_value.identifier) { + state.insert(store_lvalue.place.identifier); + if let Some(lv) = lvalue { + state.insert(lv.identifier); + } + } + } + ReactiveValue::Instruction(InstructionValue::Destructure { + value: destr_value, + lvalue: destr_lvalue, + .. + }) => { + if state.contains(&destr_value.identifier) { + for operand in hir_visitors::each_pattern_operand(&destr_lvalue.pattern) { + let ident = &self.env.identifiers[operand.identifier.0 as usize]; + let ty = &self.env.types[ident.type_.0 as usize]; + if is_stable_type(ty) { + continue; + } + state.insert(operand.identifier); + } + if let Some(lv) = lvalue { + state.insert(lv.identifier); + } + } + } + ReactiveValue::Instruction(InstructionValue::PropertyLoad { object, .. 
}) => { + if let Some(lv) = lvalue { + let ident = &self.env.identifiers[lv.identifier.0 as usize]; + let ty = &self.env.types[ident.type_.0 as usize]; + if state.contains(&object.identifier) && !is_stable_type(ty) { + state.insert(lv.identifier); + } + } + } + ReactiveValue::Instruction(InstructionValue::ComputedLoad { + object, property, .. + }) => { + if let Some(lv) = lvalue { + if state.contains(&object.identifier) || state.contains(&property.identifier) { + state.insert(lv.identifier); + } + } + } + _ => {} + } + Ok(()) + } + + fn visit_scope( + &mut self, + scope: &mut ReactiveScopeBlock, + state: &mut Self::State, + ) -> Result<(), react_compiler_diagnostics::CompilerError> { + self.traverse_scope(scope, state)?; + + let scope_id = scope.scope; + let scope_data = &mut self.env.scopes[scope_id.0 as usize]; + + // Remove non-reactive dependencies + scope_data + .dependencies + .retain(|dep| state.contains(&dep.identifier)); + + // If any deps remain, mark all declarations and reassignments as reactive + if !scope_data.dependencies.is_empty() { + let decl_ids: Vec = scope_data + .declarations + .iter() + .map(|(_, decl)| decl.identifier) + .collect(); + for id in decl_ids { + state.insert(id); + } + let reassign_ids: Vec = scope_data.reassignments.clone(); + for id in reassign_ids { + state.insert(id); + } + } + Ok(()) + } +} diff --git a/crates/react_compiler_reactive_scopes/src/prune_unused_labels.rs b/crates/react_compiler_reactive_scopes/src/prune_unused_labels.rs new file mode 100644 index 000000000000..06cc02a6fb2e --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/prune_unused_labels.rs @@ -0,0 +1,91 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Flattens labeled terminals where the label is not reachable, and +//! nulls out labels for other terminals where the label is unused. +//! +//! 
Corresponds to `src/ReactiveScopes/PruneUnusedLabels.ts`. + +use std::collections::HashSet; + +use react_compiler_hir::{ + environment::Environment, BlockId, ReactiveFunction, ReactiveStatement, ReactiveTerminal, + ReactiveTerminalStatement, ReactiveTerminalTargetKind, +}; + +use crate::visitors::{transform_reactive_function, ReactiveFunctionTransform, Transformed}; + +/// Prune unused labels from a reactive function. +pub fn prune_unused_labels( + func: &mut ReactiveFunction, + env: &Environment, +) -> Result<(), react_compiler_diagnostics::CompilerError> { + let mut transform = Transform { env }; + let mut labels: HashSet = HashSet::new(); + transform_reactive_function(func, &mut transform, &mut labels) +} + +struct Transform<'a> { + env: &'a Environment, +} + +impl<'a> ReactiveFunctionTransform for Transform<'a> { + type State = HashSet; + + fn env(&self) -> &Environment { + self.env + } + + fn transform_terminal( + &mut self, + stmt: &mut ReactiveTerminalStatement, + state: &mut HashSet, + ) -> Result, react_compiler_diagnostics::CompilerError> { + // Traverse children first + self.traverse_terminal(stmt, state)?; + + // Collect labeled break/continue targets + match &stmt.terminal { + ReactiveTerminal::Break { + target, + target_kind: ReactiveTerminalTargetKind::Labeled, + .. + } + | ReactiveTerminal::Continue { + target, + target_kind: ReactiveTerminalTargetKind::Labeled, + .. + } => { + state.insert(*target); + } + _ => {} + } + + // Is this terminal reachable via a break/continue to its label? + let is_reachable_label = stmt + .label + .as_ref() + .map_or(false, |label| state.contains(&label.id)); + + if let ReactiveTerminal::Label { block, .. } = &mut stmt.terminal { + if !is_reachable_label { + // Flatten labeled terminals where the label isn't necessary. 
+ // Note: In TS, there's a check for `last.terminal.target === null` + // to pop a trailing break, but since target is always a BlockId (number), + // that check is always false, so the trailing break is never removed. + let flattened = std::mem::take(block); + return Ok(Transformed::ReplaceMany(flattened)); + } + } + + if !is_reachable_label { + if let Some(label) = &mut stmt.label { + label.implicit = true; + } + } + + Ok(Transformed::Keep) + } +} diff --git a/crates/react_compiler_reactive_scopes/src/prune_unused_lvalues.rs b/crates/react_compiler_reactive_scopes/src/prune_unused_lvalues.rs new file mode 100644 index 000000000000..9c37f58e12e6 --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/prune_unused_lvalues.rs @@ -0,0 +1,238 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! PruneUnusedLValues (PruneTemporaryLValues) +//! +//! Nulls out lvalues for temporary variables that are never accessed later. +//! +//! Corresponds to `src/ReactiveScopes/PruneTemporaryLValues.ts`. + +use std::collections::HashSet; + +use react_compiler_hir::{ + environment::Environment, DeclarationId, EvaluationOrder, Place, ReactiveFunction, + ReactiveInstruction, ReactiveStatement, ReactiveValue, +}; + +use crate::visitors::{self, ReactiveFunctionVisitor}; + +/// Nulls out lvalues for unnamed temporaries that are never used. +/// TS: `pruneUnusedLValues` +/// +/// Uses ReactiveFunctionVisitor to collect unnamed lvalue DeclarationIds, +/// removing them when referenced as operands. After the visitor pass, +/// a second pass nulls out the remaining unused lvalues. +/// +/// This uses a two-phase approach because Rust's ReactiveFunctionVisitor +/// takes immutable references, so we cannot modify lvalues during the visit. +/// The TS version stores mutable instruction references and modifies them +/// after the visitor completes. 
+pub fn prune_unused_lvalues(func: &mut ReactiveFunction, env: &Environment) { + // Phase 1: Use ReactiveFunctionVisitor to identify unused unnamed lvalues. + // When we see an unnamed lvalue on an instruction, we add its DeclarationId. + // When we see a place reference (operand), we remove its DeclarationId. + let visitor = Visitor { env }; + let mut lvalues: HashSet = HashSet::new(); + visitors::visit_reactive_function(func, &visitor, &mut lvalues); + + // Phase 2: Null out lvalues whose DeclarationId remains in the map. + // In the TS, this is done by iterating the stored instruction references. + // In Rust, we walk the tree to find instructions with matching DeclarationIds. + if !lvalues.is_empty() { + null_unused_lvalues(&mut func.body, env, &lvalues); + } +} + +/// TS: `type LValues = Map` +/// In Rust, we only need the set of DeclarationIds (not the instruction refs) +/// because we apply changes in a separate pass. +type LValues = HashSet; + +/// TS: `class Visitor extends ReactiveFunctionVisitor` +struct Visitor<'a> { + env: &'a Environment, +} + +impl ReactiveFunctionVisitor for Visitor<'_> { + type State = LValues; + + fn env(&self) -> &Environment { + self.env + } + + /// TS: `visitPlace(_id, place, state) { + /// state.delete(place.identifier.declarationId) }` + fn visit_place(&self, _id: EvaluationOrder, place: &Place, state: &mut LValues) { + let ident = &self.env.identifiers[place.identifier.0 as usize]; + state.remove(&ident.declaration_id); + } + + /// TS: `visitInstruction(instruction, state)` + /// Calls traverseInstruction first (visits operands via visitPlace), + /// then checks if the lvalue is unnamed and adds to map. 
+ fn visit_instruction(&self, instruction: &ReactiveInstruction, state: &mut LValues) { + self.traverse_instruction(instruction, state); + if let Some(lv) = &instruction.lvalue { + let ident = &self.env.identifiers[lv.identifier.0 as usize]; + if ident.name.is_none() { + state.insert(ident.declaration_id); + } + } + } +} + +/// Phase 2: Walk the tree and null out lvalues whose DeclarationId is unused. +/// This is necessary because Rust's visitor takes immutable references. +fn null_unused_lvalues( + block: &mut Vec, + env: &Environment, + unused: &HashSet, +) { + for stmt in block.iter_mut() { + match stmt { + ReactiveStatement::Instruction(instr) => { + null_unused_in_instruction(instr, env, unused); + } + ReactiveStatement::Scope(scope) => { + null_unused_lvalues(&mut scope.instructions, env, unused); + } + ReactiveStatement::PrunedScope(scope) => { + null_unused_lvalues(&mut scope.instructions, env, unused); + } + ReactiveStatement::Terminal(stmt) => { + null_unused_in_terminal(&mut stmt.terminal, env, unused); + } + } + } +} + +fn null_unused_in_instruction( + instr: &mut ReactiveInstruction, + env: &Environment, + unused: &HashSet, +) { + if let Some(lv) = &instr.lvalue { + let ident = &env.identifiers[lv.identifier.0 as usize]; + if unused.contains(&ident.declaration_id) { + instr.lvalue = None; + } + } + null_unused_in_value(&mut instr.value, env, unused); +} + +fn null_unused_in_value( + value: &mut ReactiveValue, + env: &Environment, + unused: &HashSet, +) { + match value { + ReactiveValue::SequenceExpression { + instructions, + value: inner, + .. + } => { + for instr in instructions.iter_mut() { + null_unused_in_instruction(instr, env, unused); + } + null_unused_in_value(inner, env, unused); + } + ReactiveValue::LogicalExpression { left, right, .. } => { + null_unused_in_value(left, env, unused); + null_unused_in_value(right, env, unused); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. 
+ } => { + null_unused_in_value(test, env, unused); + null_unused_in_value(consequent, env, unused); + null_unused_in_value(alternate, env, unused); + } + ReactiveValue::OptionalExpression { value: inner, .. } => { + null_unused_in_value(inner, env, unused); + } + ReactiveValue::Instruction(_) => {} + } +} + +fn null_unused_in_terminal( + terminal: &mut react_compiler_hir::ReactiveTerminal, + env: &Environment, + unused: &HashSet, +) { + use react_compiler_hir::ReactiveTerminal; + match terminal { + ReactiveTerminal::Break { .. } | ReactiveTerminal::Continue { .. } => {} + ReactiveTerminal::Return { .. } | ReactiveTerminal::Throw { .. } => {} + ReactiveTerminal::For { + init, + test, + update, + loop_block, + .. + } => { + null_unused_in_value(init, env, unused); + null_unused_in_value(test, env, unused); + null_unused_lvalues(loop_block, env, unused); + if let Some(update) = update { + null_unused_in_value(update, env, unused); + } + } + ReactiveTerminal::ForOf { + init, + test, + loop_block, + .. + } => { + null_unused_in_value(init, env, unused); + null_unused_in_value(test, env, unused); + null_unused_lvalues(loop_block, env, unused); + } + ReactiveTerminal::ForIn { + init, loop_block, .. + } => { + null_unused_in_value(init, env, unused); + null_unused_lvalues(loop_block, env, unused); + } + ReactiveTerminal::DoWhile { + loop_block, test, .. + } => { + null_unused_lvalues(loop_block, env, unused); + null_unused_in_value(test, env, unused); + } + ReactiveTerminal::While { + test, loop_block, .. + } => { + null_unused_in_value(test, env, unused); + null_unused_lvalues(loop_block, env, unused); + } + ReactiveTerminal::If { + consequent, + alternate, + .. + } => { + null_unused_lvalues(consequent, env, unused); + if let Some(alt) = alternate { + null_unused_lvalues(alt, env, unused); + } + } + ReactiveTerminal::Switch { cases, .. 
} => { + for case in cases.iter_mut() { + if let Some(block) = &mut case.block { + null_unused_lvalues(block, env, unused); + } + } + } + ReactiveTerminal::Label { block, .. } => { + null_unused_lvalues(block, env, unused); + } + ReactiveTerminal::Try { block, handler, .. } => { + null_unused_lvalues(block, env, unused); + null_unused_lvalues(handler, env, unused); + } + } +} diff --git a/crates/react_compiler_reactive_scopes/src/prune_unused_scopes.rs b/crates/react_compiler_reactive_scopes/src/prune_unused_scopes.rs new file mode 100644 index 000000000000..d1bb4244f07c --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/prune_unused_scopes.rs @@ -0,0 +1,100 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! PruneUnusedScopes — converts scopes without outputs into regular blocks. +//! +//! Corresponds to `src/ReactiveScopes/PruneUnusedScopes.ts`. + +use react_compiler_hir::{ + environment::Environment, PrunedReactiveScopeBlock, ReactiveFunction, ReactiveScopeBlock, + ReactiveStatement, ReactiveTerminal, ReactiveTerminalStatement, +}; + +use crate::visitors::{transform_reactive_function, ReactiveFunctionTransform, Transformed}; + +struct State { + has_return_statement: bool, +} + +/// Converts scopes without outputs into pruned-scopes (regular blocks). 
+/// TS: `pruneUnusedScopes` +pub fn prune_unused_scopes( + func: &mut ReactiveFunction, + env: &Environment, +) -> Result<(), react_compiler_diagnostics::CompilerError> { + let mut transform = Transform { env }; + let mut state = State { + has_return_statement: false, + }; + transform_reactive_function(func, &mut transform, &mut state) +} + +struct Transform<'a> { + env: &'a Environment, +} + +impl<'a> ReactiveFunctionTransform for Transform<'a> { + type State = State; + + fn env(&self) -> &Environment { + self.env + } + + fn visit_terminal( + &mut self, + stmt: &mut ReactiveTerminalStatement, + state: &mut State, + ) -> Result<(), react_compiler_diagnostics::CompilerError> { + self.traverse_terminal(stmt, state)?; + if matches!(stmt.terminal, ReactiveTerminal::Return { .. }) { + state.has_return_statement = true; + } + Ok(()) + } + + fn transform_scope( + &mut self, + scope: &mut ReactiveScopeBlock, + _state: &mut State, + ) -> Result, react_compiler_diagnostics::CompilerError> { + let mut scope_state = State { + has_return_statement: false, + }; + self.visit_scope(scope, &mut scope_state)?; + + let scope_id = scope.scope; + let scope_data = &self.env.scopes[scope_id.0 as usize]; + + if !scope_state.has_return_statement + && scope_data.reassignments.is_empty() + && (scope_data.declarations.is_empty() || !has_own_declaration(scope_data, scope_id)) + { + // Replace with pruned scope + Ok(Transformed::Replace(ReactiveStatement::PrunedScope( + PrunedReactiveScopeBlock { + scope: scope.scope, + instructions: std::mem::take(&mut scope.instructions), + }, + ))) + } else { + Ok(Transformed::Keep) + } + } +} + +/// Does the scope block declare any values of its own? +/// Returns false if all declarations are propagated from nested scopes. 
+/// TS: `hasOwnDeclaration` +fn has_own_declaration( + scope_data: &react_compiler_hir::ReactiveScope, + scope_id: react_compiler_hir::ScopeId, +) -> bool { + for (_, decl) in &scope_data.declarations { + if decl.scope == scope_id { + return true; + } + } + false +} diff --git a/crates/react_compiler_reactive_scopes/src/rename_variables.rs b/crates/react_compiler_reactive_scopes/src/rename_variables.rs new file mode 100644 index 000000000000..dd3e10b58a13 --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/rename_variables.rs @@ -0,0 +1,418 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! RenameVariables — renames variables for output, assigns unique names, +//! handles SSA renames. +//! +//! Corresponds to `src/ReactiveScopes/RenameVariables.ts`. + +use std::collections::{HashMap, HashSet}; + +use react_compiler_hir::{ + environment::Environment, DeclarationId, EvaluationOrder, FunctionId, IdentifierName, + InstructionValue, ParamPattern, Place, PrunedReactiveScopeBlock, ReactiveBlock, + ReactiveFunction, ReactiveScopeBlock, ReactiveValue, +}; + +use crate::visitors::{self, ReactiveFunctionVisitor}; + +// ============================================================================= +// Scopes +// ============================================================================= + +struct Scopes { + seen: HashMap, + stack: Vec>, + globals: HashSet, + names: HashSet, +} + +impl Scopes { + fn new(globals: HashSet) -> Self { + Self { + seen: HashMap::new(), + stack: vec![HashMap::new()], + globals, + names: HashSet::new(), + } + } + + fn visit_identifier( + &mut self, + identifier_id: react_compiler_hir::IdentifierId, + env: &Environment, + ) { + let identifier = &env.identifiers[identifier_id.0 as usize]; + let original_name = match &identifier.name { + Some(name) => name.clone(), + None => return, + }; + let 
declaration_id = identifier.declaration_id; + + if self.seen.contains_key(&declaration_id) { + return; + } + + let original_value = original_name.value().to_string(); + let is_promoted = matches!(original_name, IdentifierName::Promoted(_)); + let is_promoted_temp = is_promoted && original_value.starts_with("#t"); + let is_promoted_jsx = is_promoted && original_value.starts_with("#T"); + + let mut name: String; + let mut id: u32 = 0; + if is_promoted_temp { + name = format!("t{}", id); + id += 1; + } else if is_promoted_jsx { + name = format!("T{}", id); + id += 1; + } else { + name = original_value.clone(); + } + + while self.lookup(&name).is_some() || self.globals.contains(&name) { + if is_promoted_temp { + name = format!("t{}", id); + id += 1; + } else if is_promoted_jsx { + name = format!("T{}", id); + id += 1; + } else { + name = format!("{}${}", original_value, id); + id += 1; + } + } + + let identifier_name = IdentifierName::Named(name.clone()); + self.seen.insert(declaration_id, identifier_name); + self.stack + .last_mut() + .unwrap() + .insert(name.clone(), declaration_id); + self.names.insert(name); + } + + fn lookup(&self, name: &str) -> Option { + for scope in self.stack.iter().rev() { + if let Some(id) = scope.get(name) { + return Some(*id); + } + } + None + } + + fn enter(&mut self) { + self.stack.push(HashMap::new()); + } + + fn leave(&mut self) { + self.stack.pop(); + } +} + +// ============================================================================= +// Visitor — TS: `class Visitor extends ReactiveFunctionVisitor` +// ============================================================================= + +struct Visitor<'a> { + env: &'a Environment, +} + +impl ReactiveFunctionVisitor for Visitor<'_> { + type State = Scopes; + + fn env(&self) -> &Environment { + self.env + } + + /// TS: `visitParam(place, state) { state.visit(place.identifier) }` + fn visit_param(&self, place: &Place, state: &mut Scopes) { + state.visit_identifier(place.identifier, 
self.env); + } + + /// TS: `visitLValue(_id, lvalue, state) { state.visit(lvalue.identifier) }` + fn visit_lvalue(&self, _id: EvaluationOrder, lvalue: &Place, state: &mut Scopes) { + state.visit_identifier(lvalue.identifier, self.env); + } + + /// TS: `visitPlace(_id, place, state) { state.visit(place.identifier) }` + fn visit_place(&self, _id: EvaluationOrder, place: &Place, state: &mut Scopes) { + state.visit_identifier(place.identifier, self.env); + } + + /// TS: `visitBlock(block, state) { state.enter(() => { + /// this.traverseBlock(block, state) }) }` + fn visit_block(&self, block: &ReactiveBlock, state: &mut Scopes) { + state.enter(); + self.traverse_block(block, state); + state.leave(); + } + + /// TS: `visitPrunedScope(scopeBlock, state) { + /// this.traverseBlock(scopeBlock.instructions, state) }` No enter/leave + /// — names assigned inside pruned scopes remain visible in + /// the enclosing scope, preventing name reuse. + fn visit_pruned_scope(&self, scope: &PrunedReactiveScopeBlock, state: &mut Scopes) { + self.traverse_block(&scope.instructions, state); + } + + /// TS: `visitScope(scope, state) { for (const [_, decl] of + /// scope.scope.declarations) state.visit(decl.identifier); + /// this.traverseScope(scope, state) }` + fn visit_scope(&self, scope: &ReactiveScopeBlock, state: &mut Scopes) { + let scope_data = &self.env.scopes[scope.scope.0 as usize]; + let decl_ids: Vec = scope_data + .declarations + .iter() + .map(|(_, d)| d.identifier) + .collect(); + for id in decl_ids { + state.visit_identifier(id, self.env); + } + self.traverse_scope(scope, state); + } + + /// TS: `visitValue(id, value, state) { this.traverseValue(id, value, + /// state); if (value.kind === 'FunctionExpression' || value.kind === + /// 'ObjectMethod') this.visitHirFunction(value.loweredFunc.func, state) }` + fn visit_value(&self, id: EvaluationOrder, value: &ReactiveValue, state: &mut Scopes) { + self.traverse_value(id, value, state); + if let ReactiveValue::Instruction(iv) = 
value { + match iv { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + self.visit_hir_function(lowered_func.func, state); + } + _ => {} + } + } + } +} + +// ============================================================================= +// Public entry point +// ============================================================================= + +/// Renames variables for output — assigns unique names, handles SSA renames. +/// Returns a Set of all unique variable names used. +/// TS: `renameVariables` +pub fn rename_variables(func: &mut ReactiveFunction, env: &mut Environment) -> HashSet { + let globals = collect_referenced_globals(&func.body, env); + + // Phase 1: Use ReactiveFunctionVisitor to compute the rename mapping. + // This collects DeclarationId -> IdentifierName without mutating env. + let mut scopes = Scopes::new(globals.clone()); + rename_variables_impl(func, &Visitor { env }, &mut scopes); + + // Phase 2: Apply the computed renames to all identifiers in env. 
+ for identifier in env.identifiers.iter_mut() { + if let Some(mapped_name) = scopes.seen.get(&identifier.declaration_id) { + if identifier.name.is_some() { + identifier.name = Some(mapped_name.clone()); + } + } + } + + let mut result: HashSet = scopes.names; + result.extend(globals); + result +} + +/// TS: `renameVariablesImpl` +fn rename_variables_impl(func: &ReactiveFunction, visitor: &Visitor, scopes: &mut Scopes) { + scopes.enter(); + for param in &func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + visitor.visit_param(place, scopes); + } + visitors::visit_reactive_function(func, visitor, scopes); + scopes.leave(); +} + +// ============================================================================= +// CollectReferencedGlobals +// ============================================================================= + +/// Collects all globally referenced names from the reactive function. +/// TS: `collectReferencedGlobals` +fn collect_referenced_globals(block: &ReactiveBlock, env: &Environment) -> HashSet { + let mut globals = HashSet::new(); + collect_globals_block(block, &mut globals, env); + globals +} + +fn collect_globals_block(block: &ReactiveBlock, globals: &mut HashSet, env: &Environment) { + for stmt in block { + match stmt { + react_compiler_hir::ReactiveStatement::Instruction(instr) => { + collect_globals_value(&instr.value, globals, env); + } + react_compiler_hir::ReactiveStatement::Scope(scope) => { + collect_globals_block(&scope.instructions, globals, env); + } + react_compiler_hir::ReactiveStatement::PrunedScope(scope) => { + collect_globals_block(&scope.instructions, globals, env); + } + react_compiler_hir::ReactiveStatement::Terminal(terminal) => { + collect_globals_terminal(terminal, globals, env); + } + } + } +} + +fn collect_globals_value(value: &ReactiveValue, globals: &mut HashSet, env: &Environment) { + match value { + ReactiveValue::Instruction(iv) => { + if let 
InstructionValue::LoadGlobal { binding, .. } = iv { + globals.insert(binding.name().to_string()); + } + // Visit inner functions + match iv { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + collect_globals_hir_function(lowered_func.func, globals, env); + } + _ => {} + } + } + ReactiveValue::SequenceExpression { + instructions, + value: inner, + .. + } => { + for instr in instructions { + collect_globals_value(&instr.value, globals, env); + } + collect_globals_value(inner, globals, env); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + collect_globals_value(test, globals, env); + collect_globals_value(consequent, globals, env); + collect_globals_value(alternate, globals, env); + } + ReactiveValue::LogicalExpression { left, right, .. } => { + collect_globals_value(left, globals, env); + collect_globals_value(right, globals, env); + } + ReactiveValue::OptionalExpression { value: inner, .. } => { + collect_globals_value(inner, globals, env); + } + } +} + +/// Recursively collects LoadGlobal names from an inner HIR function. +fn collect_globals_hir_function( + func_id: FunctionId, + globals: &mut HashSet, + env: &Environment, +) { + let inner_func = &env.functions[func_id.0 as usize]; + let block_ids: Vec<_> = inner_func.body.blocks.keys().copied().collect(); + for block_id in block_ids { + let inner_func = &env.functions[func_id.0 as usize]; + let block = &inner_func.body.blocks[&block_id]; + for instr_id in &block.instructions { + let instr = &inner_func.instructions[instr_id.0 as usize]; + if let InstructionValue::LoadGlobal { binding, .. } = &instr.value { + globals.insert(binding.name().to_string()); + } + // Recurse into nested function expressions + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. 
} => { + collect_globals_hir_function(lowered_func.func, globals, env); + } + _ => {} + } + } + } +} + +fn collect_globals_terminal( + stmt: &react_compiler_hir::ReactiveTerminalStatement, + globals: &mut HashSet, + env: &Environment, +) { + match &stmt.terminal { + react_compiler_hir::ReactiveTerminal::Break { .. } + | react_compiler_hir::ReactiveTerminal::Continue { .. } => {} + react_compiler_hir::ReactiveTerminal::Return { .. } + | react_compiler_hir::ReactiveTerminal::Throw { .. } => {} + react_compiler_hir::ReactiveTerminal::For { + init, + test, + update, + loop_block, + .. + } => { + collect_globals_value(init, globals, env); + collect_globals_value(test, globals, env); + collect_globals_block(loop_block, globals, env); + if let Some(update) = update { + collect_globals_value(update, globals, env); + } + } + react_compiler_hir::ReactiveTerminal::ForOf { + init, + test, + loop_block, + .. + } => { + collect_globals_value(init, globals, env); + collect_globals_value(test, globals, env); + collect_globals_block(loop_block, globals, env); + } + react_compiler_hir::ReactiveTerminal::ForIn { + init, loop_block, .. + } => { + collect_globals_value(init, globals, env); + collect_globals_block(loop_block, globals, env); + } + react_compiler_hir::ReactiveTerminal::DoWhile { + loop_block, test, .. + } => { + collect_globals_block(loop_block, globals, env); + collect_globals_value(test, globals, env); + } + react_compiler_hir::ReactiveTerminal::While { + test, loop_block, .. + } => { + collect_globals_value(test, globals, env); + collect_globals_block(loop_block, globals, env); + } + react_compiler_hir::ReactiveTerminal::If { + consequent, + alternate, + .. + } => { + collect_globals_block(consequent, globals, env); + if let Some(alt) = alternate { + collect_globals_block(alt, globals, env); + } + } + react_compiler_hir::ReactiveTerminal::Switch { cases, .. 
} => { + for case in cases { + if let Some(block) = &case.block { + collect_globals_block(block, globals, env); + } + } + } + react_compiler_hir::ReactiveTerminal::Label { block, .. } => { + collect_globals_block(block, globals, env); + } + react_compiler_hir::ReactiveTerminal::Try { block, handler, .. } => { + collect_globals_block(block, globals, env); + collect_globals_block(handler, globals, env); + } + } +} diff --git a/crates/react_compiler_reactive_scopes/src/stabilize_block_ids.rs b/crates/react_compiler_reactive_scopes/src/stabilize_block_ids.rs new file mode 100644 index 000000000000..9f6f64d80c1d --- /dev/null +++ b/crates/react_compiler_reactive_scopes/src/stabilize_block_ids.rs @@ -0,0 +1,133 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! StabilizeBlockIds +//! +//! Rewrites block IDs to sequential values so that the output is deterministic +//! regardless of the order in which blocks were created. +//! +//! Corresponds to `src/ReactiveScopes/StabilizeBlockIds.ts`. + +use std::collections::HashMap; + +use indexmap::IndexSet; +use react_compiler_hir::{ + environment::Environment, BlockId, ReactiveFunction, ReactiveScopeBlock, ReactiveTerminal, + ReactiveTerminalStatement, +}; + +use crate::visitors::{ + transform_reactive_function, visit_reactive_function, ReactiveFunctionTransform, + ReactiveFunctionVisitor, +}; + +/// Rewrites block IDs to sequential values. 
+/// TS: `stabilizeBlockIds`
+pub fn stabilize_block_ids(func: &mut ReactiveFunction, env: &mut Environment) {
+    // Pass 1: Collect referenced labels (preserving insertion order to match TS Set
+    // behavior)
+    let mut referenced: IndexSet<BlockId> = IndexSet::new();
+    let collector = CollectReferencedLabels { env: &*env };
+    visit_reactive_function(func, &collector, &mut referenced);
+
+    // Build mappings: referenced block IDs -> sequential IDs (insertion-order
+    // deterministic)
+    let mut mappings: HashMap<BlockId, BlockId> = HashMap::new();
+    for block_id in &referenced {
+        let len = mappings.len() as u32;
+        mappings.entry(*block_id).or_insert(BlockId(len));
+    }
+
+    // Pass 2: Rewrite block IDs using ReactiveFunctionTransform (its Result is discarded)
+    let mut rewriter = RewriteBlockIds { env };
+    let _ = transform_reactive_function(func, &mut rewriter, &mut mappings);
+}
+
+// =============================================================================
+// Pass 1: CollectReferencedLabels
+// =============================================================================
+
+struct CollectReferencedLabels<'a> {
+    env: &'a Environment,
+}
+
+impl<'a> ReactiveFunctionVisitor for CollectReferencedLabels<'a> {
+    type State = IndexSet<BlockId>;
+
+    fn env(&self) -> &Environment {
+        self.env
+    }
+
+    fn visit_scope(&self, scope: &ReactiveScopeBlock, state: &mut Self::State) {
+        let scope_data = &self.env.scopes[scope.scope.0 as usize];
+        if let Some(ref early_return) = scope_data.early_return_value {
+            state.insert(early_return.label);
+        }
+        self.traverse_scope(scope, state);
+    }
+
+    fn visit_terminal(&self, stmt: &ReactiveTerminalStatement, state: &mut Self::State) {
+        if let Some(ref label) = stmt.label {
+            if !label.implicit {
+                state.insert(label.id);
+            }
+        }
+        self.traverse_terminal(stmt, state);
+    }
+}
+
+// =============================================================================
+// Pass 2: RewriteBlockIds
+// =============================================================================
+
+fn get_or_insert_mapping(mappings: &mut HashMap<BlockId, BlockId>, id: BlockId) -> BlockId {
+    let len = mappings.len() as u32;
+    *mappings.entry(id).or_insert(BlockId(len))
+}
+
+/// TS: `class RewriteBlockIds extends ReactiveFunctionVisitor<Map<BlockId, BlockId>>`
+struct RewriteBlockIds<'a> {
+    env: &'a mut Environment,
+}
+
+impl<'a> ReactiveFunctionTransform for RewriteBlockIds<'a> {
+    type State = HashMap<BlockId, BlockId>;
+
+    fn env(&self) -> &Environment {
+        self.env
+    }
+
+    fn visit_scope(
+        &mut self,
+        scope: &mut ReactiveScopeBlock,
+        state: &mut Self::State,
+    ) -> Result<(), react_compiler_diagnostics::CompilerError> {
+        let scope_data = &mut self.env.scopes[scope.scope.0 as usize];
+        if let Some(ref mut early_return) = scope_data.early_return_value {
+            early_return.label = get_or_insert_mapping(state, early_return.label);
+        }
+        self.traverse_scope(scope, state)
+    }
+
+    fn visit_terminal(
+        &mut self,
+        stmt: &mut ReactiveTerminalStatement,
+        state: &mut Self::State,
+    ) -> Result<(), react_compiler_diagnostics::CompilerError> {
+        if let Some(ref mut label) = stmt.label {
+            label.id = get_or_insert_mapping(state, label.id);
+        }
+
+        match &mut stmt.terminal {
+            ReactiveTerminal::Break { target, .. } | ReactiveTerminal::Continue { target, .. } => {
+                *target = get_or_insert_mapping(state, *target);
+            }
+            _ => {}
+        }
+
+        self.traverse_terminal(stmt, state)
+    }
+}
diff --git a/crates/react_compiler_reactive_scopes/src/visitors.rs b/crates/react_compiler_reactive_scopes/src/visitors.rs
new file mode 100644
index 000000000000..8bb726ef7a15
--- /dev/null
+++ b/crates/react_compiler_reactive_scopes/src/visitors.rs
@@ -0,0 +1,838 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+//
+// This source code is licensed under the MIT license found in the
+// LICENSE file in the root directory of this source tree.
+
+//! Visitor and transform traits for ReactiveFunction.
+//!
+//! Corresponds to `src/ReactiveScopes/visitors.ts` in the TypeScript compiler.
+ +use react_compiler_diagnostics::CompilerError; +use react_compiler_hir::{ + environment::Environment, EvaluationOrder, FunctionId, InstructionValue, ParamPattern, Place, + PrunedReactiveScopeBlock, ReactiveBlock, ReactiveFunction, ReactiveInstruction, + ReactiveScopeBlock, ReactiveStatement, ReactiveTerminal, ReactiveTerminalStatement, + ReactiveValue, +}; + +// ============================================================================= +// ReactiveFunctionVisitor trait +// ============================================================================= + +/// Visitor trait for walking a ReactiveFunction tree. +/// +/// Override individual `visit_*` methods to customize behavior; call the +/// corresponding `traverse_*` to continue the default recursion. +/// +/// TS: `class ReactiveFunctionVisitor` +pub trait ReactiveFunctionVisitor { + type State; + + /// Provide Environment access. The default traversal uses this to include + /// FunctionExpression/ObjectMethod context places as operands (matching the + /// TS `eachInstructionValueOperand` behavior). + fn env(&self) -> &Environment; + + fn visit_id(&self, _id: EvaluationOrder, _state: &mut Self::State) {} + + fn visit_place(&self, _id: EvaluationOrder, _place: &Place, _state: &mut Self::State) {} + + fn visit_lvalue(&self, _id: EvaluationOrder, _lvalue: &Place, _state: &mut Self::State) {} + + fn visit_param(&self, _place: &Place, _state: &mut Self::State) {} + + /// Walk an inner HIR function, visiting params, instructions (with lvalues, + /// value-lvalues, operands, and nested functions), and terminal operands. 
+ /// TS: `visitHirFunction` + fn visit_hir_function(&self, func_id: FunctionId, state: &mut Self::State) { + let inner_func = &self.env().functions[func_id.0 as usize]; + for param in &inner_func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + self.visit_param(place, state); + } + let block_ids: Vec<_> = inner_func.body.blocks.keys().copied().collect(); + for block_id in block_ids { + let inner_func = &self.env().functions[func_id.0 as usize]; + let block = &inner_func.body.blocks[&block_id]; + let instr_ids: Vec<_> = block.instructions.clone(); + let terminal_operands: Vec = + react_compiler_hir::visitors::each_terminal_operand(&block.terminal); + let terminal_id = block.terminal.evaluation_order(); + + for instr_id in &instr_ids { + let inner_func = &self.env().functions[func_id.0 as usize]; + let instr = &inner_func.instructions[instr_id.0 as usize]; + // Build a temporary ReactiveInstruction for the visitor + let reactive_instr = ReactiveInstruction { + id: instr.id, + lvalue: Some(instr.lvalue.clone()), + value: ReactiveValue::Instruction(instr.value.clone()), + effects: None, + loc: instr.loc, + }; + self.visit_instruction(&reactive_instr, state); + // Recurse into nested functions + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + self.visit_hir_function(lowered_func.func, state); + } + _ => {} + } + } + for operand in &terminal_operands { + self.visit_place(terminal_id, operand, state); + } + } + } + + fn visit_value(&self, id: EvaluationOrder, value: &ReactiveValue, state: &mut Self::State) { + self.traverse_value(id, value, state); + } + + fn traverse_value(&self, id: EvaluationOrder, value: &ReactiveValue, state: &mut Self::State) { + match value { + ReactiveValue::OptionalExpression { value: inner, .. 
} => { + self.visit_value(id, inner, state); + } + ReactiveValue::LogicalExpression { left, right, .. } => { + self.visit_value(id, left, state); + self.visit_value(id, right, state); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + self.visit_value(id, test, state); + self.visit_value(id, consequent, state); + self.visit_value(id, alternate, state); + } + ReactiveValue::SequenceExpression { + instructions, + id: seq_id, + value: inner, + .. + } => { + for instr in instructions { + self.visit_instruction(instr, state); + } + self.visit_value(*seq_id, inner, state); + } + ReactiveValue::Instruction(instr_value) => { + let operands = react_compiler_hir::visitors::each_instruction_value_operand( + instr_value, + self.env(), + ); + for place in &operands { + self.visit_place(id, place, state); + } + } + } + } + + fn visit_instruction(&self, instruction: &ReactiveInstruction, state: &mut Self::State) { + self.traverse_instruction(instruction, state); + } + + fn traverse_instruction(&self, instruction: &ReactiveInstruction, state: &mut Self::State) { + self.visit_id(instruction.id, state); + // Visit instruction-level lvalue + if let Some(lvalue) = &instruction.lvalue { + self.visit_lvalue(instruction.id, lvalue, state); + } + // Visit value-level lvalues (TS: eachInstructionValueLValue) + if let ReactiveValue::Instruction(iv) = &instruction.value { + for place in react_compiler_hir::visitors::each_instruction_value_lvalue(iv) { + self.visit_lvalue(instruction.id, &place, state); + } + } + self.visit_value(instruction.id, &instruction.value, state); + } + + fn visit_terminal(&self, stmt: &ReactiveTerminalStatement, state: &mut Self::State) { + self.traverse_terminal(stmt, state); + } + + fn traverse_terminal(&self, stmt: &ReactiveTerminalStatement, state: &mut Self::State) { + let terminal = &stmt.terminal; + let id = terminal_id(terminal); + self.visit_id(id, state); + match terminal { + ReactiveTerminal::Break { .. 
} | ReactiveTerminal::Continue { .. } => {} + ReactiveTerminal::Return { value, id, .. } => { + self.visit_place(*id, value, state); + } + ReactiveTerminal::Throw { value, id, .. } => { + self.visit_place(*id, value, state); + } + ReactiveTerminal::For { + init, + test, + update, + loop_block, + id, + .. + } => { + self.visit_value(*id, init, state); + self.visit_value(*id, test, state); + self.visit_block(loop_block, state); + if let Some(update) = update { + self.visit_value(*id, update, state); + } + } + ReactiveTerminal::ForOf { + init, + test, + loop_block, + id, + .. + } => { + self.visit_value(*id, init, state); + self.visit_value(*id, test, state); + self.visit_block(loop_block, state); + } + ReactiveTerminal::ForIn { + init, + loop_block, + id, + .. + } => { + self.visit_value(*id, init, state); + self.visit_block(loop_block, state); + } + ReactiveTerminal::DoWhile { + loop_block, + test, + id, + .. + } => { + self.visit_block(loop_block, state); + self.visit_value(*id, test, state); + } + ReactiveTerminal::While { + test, + loop_block, + id, + .. + } => { + self.visit_value(*id, test, state); + self.visit_block(loop_block, state); + } + ReactiveTerminal::If { + test, + consequent, + alternate, + id, + .. + } => { + self.visit_place(*id, test, state); + self.visit_block(consequent, state); + if let Some(alt) = alternate { + self.visit_block(alt, state); + } + } + ReactiveTerminal::Switch { + test, cases, id, .. + } => { + self.visit_place(*id, test, state); + for case in cases { + if let Some(t) = &case.test { + self.visit_place(*id, t, state); + } + if let Some(block) = &case.block { + self.visit_block(block, state); + } + } + } + ReactiveTerminal::Label { block, .. } => { + self.visit_block(block, state); + } + ReactiveTerminal::Try { + block, + handler_binding, + handler, + id, + .. 
+ } => { + self.visit_block(block, state); + if let Some(binding) = handler_binding { + self.visit_place(*id, binding, state); + } + self.visit_block(handler, state); + } + } + } + + fn visit_scope(&self, scope: &ReactiveScopeBlock, state: &mut Self::State) { + self.traverse_scope(scope, state); + } + + fn traverse_scope(&self, scope: &ReactiveScopeBlock, state: &mut Self::State) { + self.visit_block(&scope.instructions, state); + } + + fn visit_pruned_scope(&self, scope: &PrunedReactiveScopeBlock, state: &mut Self::State) { + self.traverse_pruned_scope(scope, state); + } + + fn traverse_pruned_scope(&self, scope: &PrunedReactiveScopeBlock, state: &mut Self::State) { + self.visit_block(&scope.instructions, state); + } + + fn visit_block(&self, block: &ReactiveBlock, state: &mut Self::State) { + self.traverse_block(block, state); + } + + fn traverse_block(&self, block: &ReactiveBlock, state: &mut Self::State) { + for stmt in block { + match stmt { + ReactiveStatement::Instruction(instr) => { + self.visit_instruction(instr, state); + } + ReactiveStatement::Scope(scope) => { + self.visit_scope(scope, state); + } + ReactiveStatement::PrunedScope(scope) => { + self.visit_pruned_scope(scope, state); + } + ReactiveStatement::Terminal(terminal) => { + self.visit_terminal(terminal, state); + } + } + } + } +} + +/// Entry point for visiting a reactive function. +/// TS: `visitReactiveFunction` +pub fn visit_reactive_function( + func: &ReactiveFunction, + visitor: &V, + state: &mut V::State, +) { + visitor.visit_block(&func.body, state); +} + +// ============================================================================= +// Transformed / TransformedValue enums +// ============================================================================= + +/// Result of transforming a ReactiveStatement. +/// TS: `Transformed` +pub enum Transformed { + Keep, + Remove, + Replace(T), + ReplaceMany(Vec), +} + +/// Result of transforming a ReactiveValue. 
+/// TS: `TransformedValue` +#[allow(dead_code)] +pub enum TransformedValue { + Keep, + Replace(ReactiveValue), +} + +// ============================================================================= +// ReactiveFunctionTransform trait +// ============================================================================= + +/// Transform trait for modifying a ReactiveFunction tree in-place. +/// +/// Extends the visitor pattern with `transform_*` methods that can modify +/// or remove statements. The `traverse_block` implementation handles applying +/// transform results to the block. +/// +/// TS: `class ReactiveFunctionTransform` +pub trait ReactiveFunctionTransform { + type State; + + /// Provide Environment access. The default traversal uses this to include + /// FunctionExpression/ObjectMethod context places as operands (matching the + /// TS `eachInstructionValueOperand` behavior). + fn env(&self) -> &Environment; + + fn visit_id( + &mut self, + _id: EvaluationOrder, + _state: &mut Self::State, + ) -> Result<(), CompilerError> { + Ok(()) + } + + fn visit_place( + &mut self, + _id: EvaluationOrder, + _place: &Place, + _state: &mut Self::State, + ) -> Result<(), CompilerError> { + Ok(()) + } + + fn visit_lvalue( + &mut self, + _id: EvaluationOrder, + _lvalue: &Place, + _state: &mut Self::State, + ) -> Result<(), CompilerError> { + Ok(()) + } + + fn visit_value( + &mut self, + id: EvaluationOrder, + value: &mut ReactiveValue, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + self.traverse_value(id, value, state) + } + + fn traverse_value( + &mut self, + id: EvaluationOrder, + value: &mut ReactiveValue, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + match value { + ReactiveValue::OptionalExpression { value: inner, .. } => { + let next = self.transform_value(id, inner, state)?; + if let TransformedValue::Replace(new_value) = next { + **inner = new_value; + } + } + ReactiveValue::LogicalExpression { left, right, .. 
} => { + let next_left = self.transform_value(id, left, state)?; + if let TransformedValue::Replace(new_value) = next_left { + **left = new_value; + } + let next_right = self.transform_value(id, right, state)?; + if let TransformedValue::Replace(new_value) = next_right { + **right = new_value; + } + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + let next_test = self.transform_value(id, test, state)?; + if let TransformedValue::Replace(new_value) = next_test { + **test = new_value; + } + let next_cons = self.transform_value(id, consequent, state)?; + if let TransformedValue::Replace(new_value) = next_cons { + **consequent = new_value; + } + let next_alt = self.transform_value(id, alternate, state)?; + if let TransformedValue::Replace(new_value) = next_alt { + **alternate = new_value; + } + } + ReactiveValue::SequenceExpression { + instructions, + id: seq_id, + value: inner, + .. + } => { + let seq_id = *seq_id; + for instr in instructions.iter_mut() { + self.visit_instruction(instr, state)?; + } + let next = self.transform_value(seq_id, inner, state)?; + if let TransformedValue::Replace(new_value) = next { + **inner = new_value; + } + } + ReactiveValue::Instruction(instr_value) => { + // Collect operands before visiting to avoid borrow conflict + // (self.env() borrows self immutably, self.visit_place() needs &mut self). 
+ let operands = react_compiler_hir::visitors::each_instruction_value_operand( + instr_value, + self.env(), + ); + for place in &operands { + self.visit_place(id, place, state)?; + } + } + } + Ok(()) + } + + fn visit_instruction( + &mut self, + instruction: &mut ReactiveInstruction, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + self.traverse_instruction(instruction, state) + } + + fn transform_value( + &mut self, + id: EvaluationOrder, + value: &mut ReactiveValue, + state: &mut Self::State, + ) -> Result { + self.visit_value(id, value, state)?; + Ok(TransformedValue::Keep) + } + + fn traverse_instruction( + &mut self, + instruction: &mut ReactiveInstruction, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + self.visit_id(instruction.id, state)?; + // Visit instruction-level lvalue + if let Some(lvalue) = &instruction.lvalue { + self.visit_lvalue(instruction.id, lvalue, state)?; + } + // Visit value-level lvalues (TS: eachInstructionValueLValue) + if let ReactiveValue::Instruction(iv) = &instruction.value { + for place in react_compiler_hir::visitors::each_instruction_value_lvalue(iv) { + self.visit_lvalue(instruction.id, &place, state)?; + } + } + let next_value = self.transform_value(instruction.id, &mut instruction.value, state)?; + if let TransformedValue::Replace(new_value) = next_value { + instruction.value = new_value; + } + Ok(()) + } + + fn visit_terminal( + &mut self, + stmt: &mut ReactiveTerminalStatement, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + self.traverse_terminal(stmt, state) + } + + fn traverse_terminal( + &mut self, + stmt: &mut ReactiveTerminalStatement, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + let terminal = &mut stmt.terminal; + let id = terminal_id(terminal); + self.visit_id(id, state)?; + match terminal { + ReactiveTerminal::Break { .. } | ReactiveTerminal::Continue { .. } => {} + ReactiveTerminal::Return { value, id, .. 
} => { + self.visit_place(*id, value, state)?; + } + ReactiveTerminal::Throw { value, id, .. } => { + self.visit_place(*id, value, state)?; + } + ReactiveTerminal::For { + init, + test, + update, + loop_block, + id, + .. + } => { + let id = *id; + let next_init = self.transform_value(id, init, state)?; + if let TransformedValue::Replace(new_value) = next_init { + *init = new_value; + } + let next_test = self.transform_value(id, test, state)?; + if let TransformedValue::Replace(new_value) = next_test { + *test = new_value; + } + if let Some(update) = update { + let next_update = self.transform_value(id, update, state)?; + if let TransformedValue::Replace(new_value) = next_update { + *update = new_value; + } + } + self.visit_block(loop_block, state)?; + } + ReactiveTerminal::ForOf { + init, + test, + loop_block, + id, + .. + } => { + let id = *id; + let next_init = self.transform_value(id, init, state)?; + if let TransformedValue::Replace(new_value) = next_init { + *init = new_value; + } + let next_test = self.transform_value(id, test, state)?; + if let TransformedValue::Replace(new_value) = next_test { + *test = new_value; + } + self.visit_block(loop_block, state)?; + } + ReactiveTerminal::ForIn { + init, + loop_block, + id, + .. + } => { + let id = *id; + let next_init = self.transform_value(id, init, state)?; + if let TransformedValue::Replace(new_value) = next_init { + *init = new_value; + } + self.visit_block(loop_block, state)?; + } + ReactiveTerminal::DoWhile { + loop_block, + test, + id, + .. + } => { + let id = *id; + self.visit_block(loop_block, state)?; + let next_test = self.transform_value(id, test, state)?; + if let TransformedValue::Replace(new_value) = next_test { + *test = new_value; + } + } + ReactiveTerminal::While { + test, + loop_block, + id, + .. 
+ } => { + let id = *id; + let next_test = self.transform_value(id, test, state)?; + if let TransformedValue::Replace(new_value) = next_test { + *test = new_value; + } + self.visit_block(loop_block, state)?; + } + ReactiveTerminal::If { + test, + consequent, + alternate, + id, + .. + } => { + self.visit_place(*id, test, state)?; + self.visit_block(consequent, state)?; + if let Some(alt) = alternate { + self.visit_block(alt, state)?; + } + } + ReactiveTerminal::Switch { + test, cases, id, .. + } => { + let id = *id; + self.visit_place(id, test, state)?; + for case in cases.iter_mut() { + if let Some(t) = &case.test { + self.visit_place(id, t, state)?; + } + if let Some(block) = &mut case.block { + self.visit_block(block, state)?; + } + } + } + ReactiveTerminal::Label { block, .. } => { + self.visit_block(block, state)?; + } + ReactiveTerminal::Try { + block, + handler_binding, + handler, + id, + .. + } => { + let id = *id; + self.visit_block(block, state)?; + if let Some(binding) = handler_binding { + self.visit_place(id, binding, state)?; + } + self.visit_block(handler, state)?; + } + } + Ok(()) + } + + fn visit_scope( + &mut self, + scope: &mut ReactiveScopeBlock, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + self.traverse_scope(scope, state) + } + + fn traverse_scope( + &mut self, + scope: &mut ReactiveScopeBlock, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + self.visit_block(&mut scope.instructions, state) + } + + fn visit_pruned_scope( + &mut self, + scope: &mut PrunedReactiveScopeBlock, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + self.traverse_pruned_scope(scope, state) + } + + fn traverse_pruned_scope( + &mut self, + scope: &mut PrunedReactiveScopeBlock, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + self.visit_block(&mut scope.instructions, state) + } + + fn visit_block( + &mut self, + block: &mut ReactiveBlock, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + 
self.traverse_block(block, state) + } + + fn transform_instruction( + &mut self, + instruction: &mut ReactiveInstruction, + state: &mut Self::State, + ) -> Result, CompilerError> { + self.visit_instruction(instruction, state)?; + Ok(Transformed::Keep) + } + + fn transform_terminal( + &mut self, + stmt: &mut ReactiveTerminalStatement, + state: &mut Self::State, + ) -> Result, CompilerError> { + self.visit_terminal(stmt, state)?; + Ok(Transformed::Keep) + } + + fn transform_scope( + &mut self, + scope: &mut ReactiveScopeBlock, + state: &mut Self::State, + ) -> Result, CompilerError> { + self.visit_scope(scope, state)?; + Ok(Transformed::Keep) + } + + fn transform_pruned_scope( + &mut self, + scope: &mut PrunedReactiveScopeBlock, + state: &mut Self::State, + ) -> Result, CompilerError> { + self.visit_pruned_scope(scope, state)?; + Ok(Transformed::Keep) + } + + fn traverse_block( + &mut self, + block: &mut ReactiveBlock, + state: &mut Self::State, + ) -> Result<(), CompilerError> { + let mut next_block: Option> = None; + let len = block.len(); + for i in 0..len { + // Take the statement out temporarily + let mut stmt = std::mem::replace( + &mut block[i], + // Placeholder — will be overwritten or discarded + ReactiveStatement::Instruction(ReactiveInstruction { + id: EvaluationOrder(0), + lvalue: None, + value: ReactiveValue::Instruction( + react_compiler_hir::InstructionValue::Debugger { loc: None }, + ), + effects: None, + loc: None, + }), + ); + let transformed = match &mut stmt { + ReactiveStatement::Instruction(instr) => { + self.transform_instruction(instr, state)? + } + ReactiveStatement::Scope(scope) => self.transform_scope(scope, state)?, + ReactiveStatement::PrunedScope(scope) => { + self.transform_pruned_scope(scope, state)? + } + ReactiveStatement::Terminal(terminal) => { + self.transform_terminal(terminal, state)? 
+ } + }; + match transformed { + Transformed::Keep => { + if let Some(ref mut nb) = next_block { + nb.push(stmt); + } else { + // Put it back + block[i] = stmt; + } + } + Transformed::Remove => { + if next_block.is_none() { + next_block = Some(block[..i].to_vec()); + } + } + Transformed::Replace(replacement) => { + if next_block.is_none() { + next_block = Some(block[..i].to_vec()); + } + next_block.as_mut().unwrap().push(replacement); + } + Transformed::ReplaceMany(replacements) => { + if next_block.is_none() { + next_block = Some(block[..i].to_vec()); + } + next_block.as_mut().unwrap().extend(replacements); + } + } + } + if let Some(nb) = next_block { + *block = nb; + } + Ok(()) + } +} + +/// Entry point for transforming a reactive function. +/// TS: `visitReactiveFunction` (used with transforms too) +pub fn transform_reactive_function( + func: &mut ReactiveFunction, + transform: &mut T, + state: &mut T::State, +) -> Result<(), CompilerError> { + transform.visit_block(&mut func.body, state) +} + +// ============================================================================= +// Helper: extract terminal ID +// ============================================================================= + +fn terminal_id(terminal: &ReactiveTerminal) -> EvaluationOrder { + match terminal { + ReactiveTerminal::Break { id, .. } + | ReactiveTerminal::Continue { id, .. } + | ReactiveTerminal::Return { id, .. } + | ReactiveTerminal::Throw { id, .. } + | ReactiveTerminal::Switch { id, .. } + | ReactiveTerminal::DoWhile { id, .. } + | ReactiveTerminal::While { id, .. } + | ReactiveTerminal::For { id, .. } + | ReactiveTerminal::ForOf { id, .. } + | ReactiveTerminal::ForIn { id, .. } + | ReactiveTerminal::If { id, .. } + | ReactiveTerminal::Label { id, .. } + | ReactiveTerminal::Try { id, .. 
} => *id, + } +} diff --git a/crates/react_compiler_ssa/Cargo.toml b/crates/react_compiler_ssa/Cargo.toml new file mode 100644 index 000000000000..906e333b20e2 --- /dev/null +++ b/crates/react_compiler_ssa/Cargo.toml @@ -0,0 +1,13 @@ +[package] +description = "Vendored React Compiler SSA passes from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_ssa" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +react_compiler_diagnostics = { path = "../react_compiler_diagnostics" } +react_compiler_hir = { path = "../react_compiler_hir" } +react_compiler_lowering = { path = "../react_compiler_lowering" } +indexmap = { workspace = true } diff --git a/crates/react_compiler_ssa/src/eliminate_redundant_phi.rs b/crates/react_compiler_ssa/src/eliminate_redundant_phi.rs new file mode 100644 index 000000000000..4ed8bf9643dd --- /dev/null +++ b/crates/react_compiler_ssa/src/eliminate_redundant_phi.rs @@ -0,0 +1,155 @@ +use std::collections::{HashMap, HashSet}; + +use react_compiler_hir::{environment::Environment, visitors, *}; + +use crate::enter_ssa::placeholder_function; + +// ============================================================================= +// Helper: rewrite_place +// ============================================================================= + +fn rewrite_place(place: &mut Place, rewrites: &HashMap) { + if let Some(&rewrite) = rewrites.get(&place.identifier) { + place.identifier = rewrite; + } +} + +// ============================================================================= +// Public entry point +// ============================================================================= + +pub fn eliminate_redundant_phi(func: &mut HirFunction, env: &mut Environment) { + let mut rewrites: HashMap = HashMap::new(); + eliminate_redundant_phi_impl(func, env, &mut rewrites); +} + +// ============================================================================= +// Inner implementation +// 
============================================================================= + +fn eliminate_redundant_phi_impl( + func: &mut HirFunction, + env: &mut Environment, + rewrites: &mut HashMap, +) { + let ir = &mut func.body; + + let mut has_back_edge = false; + let mut visited: HashSet = HashSet::new(); + + let mut size; + loop { + size = rewrites.len(); + + let block_ids: Vec = ir.blocks.keys().copied().collect(); + for block_id in &block_ids { + let block_id = *block_id; + + if !has_back_edge { + let block = ir.blocks.get(&block_id).unwrap(); + for pred_id in &block.preds { + if !visited.contains(pred_id) { + has_back_edge = true; + } + } + } + visited.insert(block_id); + + // Find any redundant phis: rewrite operands, identify redundant phis, remove + // them. Matches TS behavior: each phi's operands are rewritten + // before checking redundancy, so that rewrites from earlier phis in + // the same block are visible to later phis. + let block = ir.blocks.get_mut(&block_id).unwrap(); + block.phis.retain_mut(|phi| { + // Remap phis in case operands are from eliminated phis + for (_, operand) in phi.operands.iter_mut() { + rewrite_place(operand, rewrites); + } + + // Find if the phi can be eliminated + let mut same: Option = None; + let mut is_redundant = true; + for (_, operand) in &phi.operands { + if (same.is_some() && operand.identifier == same.unwrap()) + || operand.identifier == phi.place.identifier + { + continue; + } else if same.is_some() { + is_redundant = false; + break; + } else { + same = Some(operand.identifier); + } + } + if is_redundant { + let same = same.expect("Expected phis to be non-empty"); + rewrites.insert(phi.place.identifier, same); + false // remove this phi + } else { + true // keep this phi + } + }); + + // Rewrite instructions + let instruction_ids: Vec = + ir.blocks.get(&block_id).unwrap().instructions.clone(); + + for instr_id in &instruction_ids { + let instr_idx = instr_id.0 as usize; + let instr = &mut func.instructions[instr_idx]; + 
+ // Rewrite all lvalues (matches TS eachInstructionLValue) + rewrite_place(&mut instr.lvalue, rewrites); + visitors::for_each_instruction_value_lvalue_mut(&mut instr.value, &mut |place| { + rewrite_place(place, rewrites); + }); + + // Rewrite operands using canonical visitor + visitors::for_each_instruction_value_operand_mut( + &mut func.instructions[instr_idx].value, + &mut |place| { + rewrite_place(place, rewrites); + }, + ); + + // Handle FunctionExpression/ObjectMethod context and recursion + let instr = &func.instructions[instr_idx]; + let func_expr_id = match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + Some(lowered_func.func) + } + _ => None, + }; + + if let Some(fid) = func_expr_id { + // Rewrite context places + let context = &mut env.functions[fid.0 as usize].context; + for place in context.iter_mut() { + rewrite_place(place, rewrites); + } + + // Take inner function out, process it, put it back + let mut inner_func = std::mem::replace( + &mut env.functions[fid.0 as usize], + placeholder_function(), + ); + + eliminate_redundant_phi_impl(&mut inner_func, env, rewrites); + + env.functions[fid.0 as usize] = inner_func; + } + } + + // Rewrite terminal operands using canonical visitor + let terminal = &mut ir.blocks.get_mut(&block_id).unwrap().terminal; + visitors::for_each_terminal_operand_mut(terminal, &mut |place| { + rewrite_place(place, rewrites); + }); + } + + if !(rewrites.len() > size && has_back_edge) { + break; + } + } +} diff --git a/crates/react_compiler_ssa/src/enter_ssa.rs b/crates/react_compiler_ssa/src/enter_ssa.rs new file mode 100644 index 000000000000..e145d50ec7e6 --- /dev/null +++ b/crates/react_compiler_ssa/src/enter_ssa.rs @@ -0,0 +1,509 @@ +use std::collections::{HashMap, HashSet}; + +use indexmap::IndexMap; +use react_compiler_diagnostics::{CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory}; +use 
react_compiler_hir::{environment::Environment, visitors, *}; + +// ============================================================================= +// SSABuilder +// ============================================================================= + +struct IncompletePhi { + old_place: Place, + new_place: Place, +} + +struct State { + defs: HashMap, + incomplete_phis: Vec, +} + +struct SSABuilder { + states: HashMap, + current: Option, + unsealed_preds: HashMap, + block_preds: HashMap>, + unknown: HashSet, + context: HashSet, + pending_phis: HashMap>, + processed_functions: Vec, +} + +impl SSABuilder { + fn new(blocks: &IndexMap) -> Self { + let mut block_preds = HashMap::new(); + for (id, block) in blocks { + block_preds.insert(*id, block.preds.iter().copied().collect()); + } + SSABuilder { + states: HashMap::new(), + current: None, + unsealed_preds: HashMap::new(), + block_preds, + unknown: HashSet::new(), + context: HashSet::new(), + pending_phis: HashMap::new(), + processed_functions: Vec::new(), + } + } + + fn define_function(&mut self, func: &HirFunction) { + for (id, block) in &func.body.blocks { + self.block_preds + .insert(*id, block.preds.iter().copied().collect()); + } + } + + fn state_mut(&mut self) -> &mut State { + let current = self + .current + .expect("we need to be in a block to access state!"); + self.states + .get_mut(¤t) + .expect("state not found for current block") + } + + fn make_id(&mut self, old_id: IdentifierId, env: &mut Environment) -> IdentifierId { + let new_id = env.next_identifier_id(); + let old = &env.identifiers[old_id.0 as usize]; + let declaration_id = old.declaration_id; + let name = old.name.clone(); + let loc = old.loc; + let new_ident = &mut env.identifiers[new_id.0 as usize]; + new_ident.declaration_id = declaration_id; + new_ident.name = name; + new_ident.loc = loc; + new_id + } + + fn define_place( + &mut self, + old_place: &Place, + env: &mut Environment, + ) -> Result { + let old_id = old_place.identifier; + + if 
self.unknown.contains(&old_id) { + let ident = &env.identifiers[old_id.0 as usize]; + let name = match &ident.name { + Some(name) => format!("{}${}", name.value(), old_id.0), + None => format!("${}", old_id.0), + }; + return Err(CompilerDiagnostic::new( + ErrorCategory::Todo, + "[hoisting] EnterSSA: Expected identifier to be defined before being used", + Some(format!("Identifier {} is undefined", name)), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: old_place.loc, + message: None, + identifier_name: None, + })); + } + + // Do not redefine context references. + if self.context.contains(&old_id) { + return Ok(self.get_place(old_place, env)); + } + + let new_id = self.make_id(old_id, env); + self.state_mut().defs.insert(old_id, new_id); + Ok(Place { + identifier: new_id, + effect: old_place.effect, + reactive: old_place.reactive, + loc: old_place.loc, + }) + } + + #[allow(dead_code)] + fn define_context( + &mut self, + old_place: &Place, + env: &mut Environment, + ) -> Result { + let old_id = old_place.identifier; + let new_place = self.define_place(old_place, env)?; + self.context.insert(old_id); + Ok(new_place) + } + + fn get_place(&mut self, old_place: &Place, env: &mut Environment) -> Place { + let current_id = self.current.expect("must be in a block"); + let new_id = self.get_id_at(old_place, current_id, env); + Place { + identifier: new_id, + effect: old_place.effect, + reactive: old_place.reactive, + loc: old_place.loc, + } + } + + fn get_id_at( + &mut self, + old_place: &Place, + block_id: BlockId, + env: &mut Environment, + ) -> IdentifierId { + if let Some(state) = self.states.get(&block_id) { + if let Some(&new_id) = state.defs.get(&old_place.identifier) { + return new_id; + } + } + + let preds = self.block_preds.get(&block_id).cloned().unwrap_or_default(); + + if preds.is_empty() { + self.unknown.insert(old_place.identifier); + return old_place.identifier; + } + + let unsealed = self.unsealed_preds.get(&block_id).copied().unwrap_or(0); + if 
unsealed > 0 { + let new_id = self.make_id(old_place.identifier, env); + let new_place = Place { + identifier: new_id, + effect: old_place.effect, + reactive: old_place.reactive, + loc: old_place.loc, + }; + let state = self.states.get_mut(&block_id).unwrap(); + state.incomplete_phis.push(IncompletePhi { + old_place: old_place.clone(), + new_place, + }); + state.defs.insert(old_place.identifier, new_id); + return new_id; + } + + if preds.len() == 1 { + let pred = preds[0]; + let new_id = self.get_id_at(old_place, pred, env); + self.states + .get_mut(&block_id) + .unwrap() + .defs + .insert(old_place.identifier, new_id); + return new_id; + } + + let new_id = self.make_id(old_place.identifier, env); + self.states + .get_mut(&block_id) + .unwrap() + .defs + .insert(old_place.identifier, new_id); + let new_place = Place { + identifier: new_id, + effect: old_place.effect, + reactive: old_place.reactive, + loc: old_place.loc, + }; + self.add_phi(block_id, old_place, &new_place, env); + new_id + } + + fn add_phi( + &mut self, + block_id: BlockId, + old_place: &Place, + new_place: &Place, + env: &mut Environment, + ) { + let preds = self.block_preds.get(&block_id).cloned().unwrap_or_default(); + + let mut pred_defs: IndexMap = IndexMap::new(); + for pred_block_id in &preds { + let pred_id = self.get_id_at(old_place, *pred_block_id, env); + pred_defs.insert( + *pred_block_id, + Place { + identifier: pred_id, + effect: old_place.effect, + reactive: old_place.reactive, + loc: old_place.loc, + }, + ); + } + + let phi = Phi { + place: new_place.clone(), + operands: pred_defs, + }; + + self.pending_phis.entry(block_id).or_default().push(phi); + } + + fn fix_incomplete_phis(&mut self, block_id: BlockId, env: &mut Environment) { + let incomplete_phis: Vec = self + .states + .get_mut(&block_id) + .unwrap() + .incomplete_phis + .drain(..) 
+ .collect(); + for phi in &incomplete_phis { + self.add_phi(block_id, &phi.old_place, &phi.new_place, env); + } + } + + fn start_block(&mut self, block_id: BlockId) { + self.current = Some(block_id); + self.states.insert( + block_id, + State { + defs: HashMap::new(), + incomplete_phis: Vec::new(), + }, + ); + } +} + +// ============================================================================= +// Public entry point +// ============================================================================= + +pub fn enter_ssa(func: &mut HirFunction, env: &mut Environment) -> Result<(), CompilerDiagnostic> { + let mut builder = SSABuilder::new(&func.body.blocks); + let root_entry = func.body.entry; + enter_ssa_impl(func, &mut builder, env, root_entry)?; + + // Apply all pending phis to the actual blocks + apply_pending_phis(func, env, &mut builder); + + Ok(()) +} + +fn apply_pending_phis(func: &mut HirFunction, env: &mut Environment, builder: &mut SSABuilder) { + for (block_id, block) in func.body.blocks.iter_mut() { + if let Some(phis) = builder.pending_phis.remove(block_id) { + block.phis.extend(phis); + } + } + for fid in &builder.processed_functions.clone() { + let inner_func = &mut env.functions[fid.0 as usize]; + for (block_id, block) in inner_func.body.blocks.iter_mut() { + if let Some(phis) = builder.pending_phis.remove(block_id) { + block.phis.extend(phis); + } + } + } +} + +fn enter_ssa_impl( + func: &mut HirFunction, + builder: &mut SSABuilder, + env: &mut Environment, + root_entry: BlockId, +) -> Result<(), CompilerDiagnostic> { + let mut visited_blocks: HashSet = HashSet::new(); + let block_ids: Vec = func.body.blocks.keys().copied().collect(); + + for block_id in &block_ids { + let block_id = *block_id; + + if visited_blocks.contains(&block_id) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!("found a cycle! 
visiting bb{} again", block_id.0), + None, + )); + } + + visited_blocks.insert(block_id); + builder.start_block(block_id); + + // Handle params at the root entry + if block_id == root_entry { + if !func.context.is_empty() { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected function context to be empty for outer function declarations", + None, + )); + } + let params = std::mem::take(&mut func.params); + let mut new_params = Vec::with_capacity(params.len()); + for param in params { + new_params.push(match param { + ParamPattern::Place(p) => ParamPattern::Place(builder.define_place(&p, env)?), + ParamPattern::Spread(s) => ParamPattern::Spread(SpreadPattern { + place: builder.define_place(&s.place, env)?, + }), + }); + } + func.params = new_params; + } + + // Process instructions + let instruction_ids: Vec = func + .body + .blocks + .get(&block_id) + .unwrap() + .instructions + .clone(); + + for instr_id in &instruction_ids { + let instr_idx = instr_id.0 as usize; + let instr = &mut func.instructions[instr_idx]; + + // For FunctionExpression/ObjectMethod, we need to handle context + // mapping specially because env.functions is borrowed by the closure. + // First, check if this is a FunctionExpression/ObjectMethod and handle + // context mapping separately. + let func_expr_id = match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. 
} => Some(lowered_func.func), + _ => None, + }; + + // Map context places for function expressions before other operands + if let Some(fid) = func_expr_id { + let context = std::mem::take(&mut env.functions[fid.0 as usize].context); + env.functions[fid.0 as usize].context = context + .into_iter() + .map(|place| builder.get_place(&place, env)) + .collect(); + } + + // Map non-context operands + visitors::for_each_instruction_value_operand_mut(&mut instr.value, &mut |place| { + *place = builder.get_place(place, env); + }); + + // Map lvalues (skip DeclareContext/StoreContext — context variables + // don't participate in SSA renaming) + let instr = &mut func.instructions[instr_idx]; + let mut lvalue_err: Option = None; + visitors::for_each_instruction_lvalue_mut(instr, &mut |place| { + if lvalue_err.is_none() { + match builder.define_place(place, env) { + Ok(new_place) => *place = new_place, + Err(e) => lvalue_err = Some(e), + } + } + }); + if let Some(e) = lvalue_err { + return Err(e); + } + + // Handle inner function SSA + if let Some(fid) = func_expr_id { + builder.processed_functions.push(fid); + let inner_func = &mut env.functions[fid.0 as usize]; + let inner_entry = inner_func.body.entry; + let entry_block = inner_func.body.blocks.get_mut(&inner_entry).unwrap(); + + if !entry_block.preds.is_empty() { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected function expression entry block to have zero predecessors", + None, + )); + } + entry_block.preds.insert(block_id); + + builder.define_function(inner_func); + + let saved_current = builder.current; + + // Map inner function params + let inner_params = std::mem::take(&mut env.functions[fid.0 as usize].params); + let mut new_inner_params = Vec::with_capacity(inner_params.len()); + for param in inner_params { + new_inner_params.push(match param { + ParamPattern::Place(p) => { + ParamPattern::Place(builder.define_place(&p, env)?) 
+ } + ParamPattern::Spread(s) => ParamPattern::Spread(SpreadPattern { + place: builder.define_place(&s.place, env)?, + }), + }); + } + env.functions[fid.0 as usize].params = new_inner_params; + + // Take the inner function out of the arena to process it + let mut inner_func = + std::mem::replace(&mut env.functions[fid.0 as usize], placeholder_function()); + + enter_ssa_impl(&mut inner_func, builder, env, root_entry)?; + + // Put it back + env.functions[fid.0 as usize] = inner_func; + + builder.current = saved_current; + + // Clear entry preds + env.functions[fid.0 as usize] + .body + .blocks + .get_mut(&inner_entry) + .unwrap() + .preds + .clear(); + builder.block_preds.insert(inner_entry, Vec::new()); + } + } + + // Map terminal operands + let terminal = &mut func.body.blocks.get_mut(&block_id).unwrap().terminal; + visitors::for_each_terminal_operand_mut(terminal, &mut |place| { + *place = builder.get_place(place, env); + }); + + // Handle successors + let terminal_ref = &func.body.blocks.get(&block_id).unwrap().terminal; + let successors = visitors::each_terminal_successor(terminal_ref); + for output_id in successors { + let output_preds_len = builder + .block_preds + .get(&output_id) + .map(|p| p.len() as u32) + .unwrap_or(0); + + let count = if builder.unsealed_preds.contains_key(&output_id) { + builder.unsealed_preds[&output_id] - 1 + } else { + output_preds_len - 1 + }; + builder.unsealed_preds.insert(output_id, count); + + if count == 0 && visited_blocks.contains(&output_id) { + builder.fix_incomplete_phis(output_id, env); + } + } + } + + Ok(()) +} + +/// Create a placeholder HirFunction for temporarily swapping an inner function +/// out of `env.functions` via `std::mem::replace`. The placeholder is never +/// read — the real function is swapped back immediately after processing. 
+pub fn placeholder_function() -> HirFunction { + HirFunction { + loc: None, + id: None, + name_hint: None, + fn_type: ReactFunctionType::Other, + params: Vec::new(), + return_type_annotation: None, + returns: Place { + identifier: IdentifierId(0), + effect: Effect::Unknown, + reactive: false, + loc: None, + }, + context: Vec::new(), + body: HIR { + entry: BlockId(0), + blocks: IndexMap::new(), + }, + instructions: Vec::new(), + generator: false, + is_async: false, + directives: Vec::new(), + aliasing_effects: None, + } +} diff --git a/crates/react_compiler_ssa/src/lib.rs b/crates/react_compiler_ssa/src/lib.rs new file mode 100644 index 000000000000..e35b18dd1d61 --- /dev/null +++ b/crates/react_compiler_ssa/src/lib.rs @@ -0,0 +1,11 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. +#![allow(clippy::all)] + +mod eliminate_redundant_phi; +pub mod enter_ssa; +mod rewrite_instruction_kinds_based_on_reassignment; + +pub use eliminate_redundant_phi::eliminate_redundant_phi; +pub use enter_ssa::enter_ssa; +pub use rewrite_instruction_kinds_based_on_reassignment::rewrite_instruction_kinds_based_on_reassignment; diff --git a/crates/react_compiler_ssa/src/rewrite_instruction_kinds_based_on_reassignment.rs b/crates/react_compiler_ssa/src/rewrite_instruction_kinds_based_on_reassignment.rs new file mode 100644 index 000000000000..c485e951e3fb --- /dev/null +++ b/crates/react_compiler_ssa/src/rewrite_instruction_kinds_based_on_reassignment.rs @@ -0,0 +1,397 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Rewrites InstructionKind of instructions which declare/assign variables, +//! converting the first declaration to Const/Let depending on whether it is +//! subsequently reassigned, and ensuring that subsequent reassignments are +//! marked as Reassign. +//! +//! 
Ported from TypeScript
+//! `src/SSA/RewriteInstructionKindsBasedOnReassignment.ts`.
+//!
+//! Note that declarations which were const in the original program cannot
+//! become `let`, but the inverse is not true: a `let` which was reassigned in
+//! the source may be converted to a `const` if the reassignment is not used and
+//! was removed by dead code elimination.
+
+use std::collections::HashMap;
+
+use react_compiler_diagnostics::{
+    CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, ErrorCategory, SourceLocation,
+};
+use react_compiler_hir::{
+    environment::Environment, visitors::each_pattern_operand, BlockKind, DeclarationId,
+    HirFunction, InstructionKind, InstructionValue, ParamPattern, Place,
+};
+
+/// Create an invariant CompilerError (matches TS CompilerError.invariant) for
+/// a failure that has no source location to report.
+///
+/// Equivalent to `invariant_error_with_loc(reason, description, None)`.
+fn invariant_error(reason: &str, description: Option<String>) -> CompilerError {
+    invariant_error_with_loc(reason, description, None)
+}
+
+/// Create an invariant CompilerError (matches TS CompilerError.invariant,
+/// which uses .withDetails()).
+///
+/// The returned error holds exactly one `ErrorCategory::Invariant` diagnostic
+/// built from `reason` and `description`. An error detail item is always
+/// attached, whether or not `loc` is `Some`; its message repeats `reason` and
+/// its `identifier_name` is left unset.
+fn invariant_error_with_loc(
+    reason: &str,
+    description: Option<String>,
+    loc: Option<SourceLocation>,
+) -> CompilerError {
+    let mut err = CompilerError::new();
+    let diagnostic = CompilerDiagnostic::new(ErrorCategory::Invariant, reason, description)
+        .with_detail(CompilerDiagnosticDetail::Error {
+            loc,
+            message: Some(reason.to_string()),
+            identifier_name: None,
+        });
+    err.push_diagnostic(diagnostic);
+    err
+}
+
+/// Format an InstructionKind variant name (matches TS `${kind}` interpolation).
+fn format_kind(kind: Option<InstructionKind>) -> String {
+    // Every variant is listed explicitly (no wildcard arm), so adding a new
+    // `InstructionKind` fails to compile here instead of printing a fallback.
+    match kind {
+        Some(InstructionKind::Const) => "Const".to_string(),
+        Some(InstructionKind::Let) => "Let".to_string(),
+        Some(InstructionKind::Reassign) => "Reassign".to_string(),
+        Some(InstructionKind::Catch) => "Catch".to_string(),
+        Some(InstructionKind::HoistedConst) => "HoistedConst".to_string(),
+        Some(InstructionKind::HoistedLet) => "HoistedLet".to_string(),
+        Some(InstructionKind::HoistedFunction) => "HoistedFunction".to_string(),
+        Some(InstructionKind::Function) => "Function".to_string(),
+        // A kind that has not been determined yet is printed as "null".
+        None => "null".to_string(),
+    }
+}
+
+/// Format a Place like TS `printPlace()`:
+/// `<effect> <name>$<id><scope><mutable_range>{reactive}`
+///
+/// - `<name>` is empty when the identifier has no name.
+/// - `<scope>` is `_@<scope id>` when the identifier has a scope, else empty.
+/// - `<mutable_range>` is `[<start>:<end>]` and is printed only when
+///   `end > start + 1`.
+/// - `{reactive}` is appended only when `place.reactive` is set.
+fn format_place(place: &Place, env: &Environment) -> String {
+    let ident = &env.identifiers[place.identifier.0 as usize];
+    let name = match &ident.name {
+        Some(n) => n.value().to_string(),
+        None => String::new(),
+    };
+    let scope = match ident.scope {
+        Some(scope_id) => format!("_@{}", scope_id.0),
+        None => String::new(),
+    };
+    let mutable_range = if ident.mutable_range.end.0 > ident.mutable_range.start.0 + 1 {
+        format!(
+            "[{}:{}]",
+            ident.mutable_range.start.0, ident.mutable_range.end.0
+        )
+    } else {
+        String::new()
+    };
+    let reactive = if place.reactive { "{reactive}" } else { "" };
+    format!(
+        "{} {}${}{}{}{}",
+        place.effect, name, place.identifier.0, scope, mutable_range, reactive
+    )
+}
+
+/// Index into a collected list of declaration mutations to apply.
+///
+/// We use a two-phase approach: first collect which declarations exist,
+/// then apply mutations. This is because in the TS code, `declarations`
+/// map stores references to LValue/LValuePattern and mutates `kind` through
+/// them. In Rust, we track instruction indices and apply changes in a second
+/// pass.
+enum DeclarationLoc { + /// An LValue from DeclareLocal or StoreLocal — identified by (block_index, + /// instr_index_in_block) + Instruction { + block_index: usize, + instr_local_index: usize, + }, + /// A parameter or context variable (seeded as Let, may be upgraded to Let + /// on reassignment — already Let) + ParamOrContext, +} + +pub fn rewrite_instruction_kinds_based_on_reassignment( + func: &mut HirFunction, + env: &Environment, +) -> Result<(), CompilerError> { + // Phase 1: Collect all information about which declarations need updates. + // + // Track: for each DeclarationId, the location of its first declaration, + // and whether it needs to be changed to Let (because of reassignment). + let mut declarations: HashMap = HashMap::new(); + // Track which (block_index, instr_local_index) should have their lvalue.kind + // set to Reassign + let mut reassign_locs: Vec<(usize, usize)> = Vec::new(); + // Track which declaration locations need to be set to Let + let mut let_locs: Vec<(usize, usize)> = Vec::new(); + // Track which (block_index, instr_local_index) should have their lvalue.kind + // set to Const + let mut const_locs: Vec<(usize, usize)> = Vec::new(); + // Track which (block_index, instr_local_index) Destructure instructions get a + // specific kind + let mut destructure_kind_locs: Vec<(usize, usize, InstructionKind)> = Vec::new(); + + // Seed with parameters + for param in &func.params { + let place: &Place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + let ident = &env.identifiers[place.identifier.0 as usize]; + if ident.name.is_some() { + declarations.insert(ident.declaration_id, DeclarationLoc::ParamOrContext); + } + } + + // Seed with context variables + for place in &func.context { + let ident = &env.identifiers[place.identifier.0 as usize]; + if ident.name.is_some() { + declarations.insert(ident.declaration_id, DeclarationLoc::ParamOrContext); + } + } + + // Process all blocks + let block_keys: 
Vec<_> = func.body.blocks.keys().cloned().collect(); + for (block_index, block_id) in block_keys.iter().enumerate() { + let block = &func.body.blocks[block_id]; + let block_kind = block.kind; + for (local_idx, instr_id) in block.instructions.iter().enumerate() { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::DeclareLocal { lvalue, .. } => { + let decl_id = + env.identifiers[lvalue.place.identifier.0 as usize].declaration_id; + if declarations.contains_key(&decl_id) { + return Err(invariant_error_with_loc( + "Expected variable not to be defined prior to declaration", + Some(format!( + "{} was already defined", + format_place(&lvalue.place, env), + )), + lvalue.place.loc, + )); + } + declarations.insert( + decl_id, + DeclarationLoc::Instruction { + block_index, + instr_local_index: local_idx, + }, + ); + } + InstructionValue::StoreLocal { lvalue, .. } => { + let ident = &env.identifiers[lvalue.place.identifier.0 as usize]; + if ident.name.is_some() { + let decl_id = ident.declaration_id; + if let Some(existing) = declarations.get(&decl_id) { + // Reassignment: mark existing declaration as Let, current as Reassign + match existing { + DeclarationLoc::Instruction { + block_index: bi, + instr_local_index: ili, + } => { + let_locs.push((*bi, *ili)); + } + DeclarationLoc::ParamOrContext => { + // Already Let, no-op + } + } + reassign_locs.push((block_index, local_idx)); + } else { + // First store — mark as Const + // Mirrors TS: CompilerError.invariant(!declarations.has(...)) + if declarations.contains_key(&decl_id) { + return Err(invariant_error_with_loc( + "Expected variable not to be defined prior to declaration", + Some(format!( + "{} was already defined", + format_place(&lvalue.place, env), + )), + lvalue.place.loc, + )); + } + declarations.insert( + decl_id, + DeclarationLoc::Instruction { + block_index, + instr_local_index: local_idx, + }, + ); + const_locs.push((block_index, local_idx)); + } + } + } + 
InstructionValue::Destructure { lvalue, .. } => { + let mut kind: Option = None; + for place in each_pattern_operand(&lvalue.pattern) { + let ident = &env.identifiers[place.identifier.0 as usize]; + if ident.name.is_none() { + if !(kind.is_none() || kind == Some(InstructionKind::Const)) { + return Err(invariant_error_with_loc( + "Expected consistent kind for destructuring", + Some(format!( + "other places were `{}` but '{}' is const", + format_kind(kind), + format_place(&place, env), + )), + place.loc, + )); + } + kind = Some(InstructionKind::Const); + } else { + let decl_id = ident.declaration_id; + if let Some(existing) = declarations.get(&decl_id) { + // Reassignment + if !(kind.is_none() || kind == Some(InstructionKind::Reassign)) { + return Err(invariant_error_with_loc( + "Expected consistent kind for destructuring", + Some(format!( + "Other places were `{}` but '{}' is reassigned", + format_kind(kind), + format_place(&place, env), + )), + place.loc, + )); + } + kind = Some(InstructionKind::Reassign); + match existing { + DeclarationLoc::Instruction { + block_index: bi, + instr_local_index: ili, + } => { + let_locs.push((*bi, *ili)); + } + DeclarationLoc::ParamOrContext => { + // Already Let + } + } + } else { + // New declaration + if block_kind == BlockKind::Value { + return Err(invariant_error_with_loc( + "TODO: Handle reassignment in a value block where the \ + original declaration was removed by dead code \ + elimination (DCE)", + None, + place.loc, + )); + } + declarations.insert( + decl_id, + DeclarationLoc::Instruction { + block_index, + instr_local_index: local_idx, + }, + ); + if !(kind.is_none() || kind == Some(InstructionKind::Const)) { + return Err(invariant_error_with_loc( + "Expected consistent kind for destructuring", + Some(format!( + "Other places were `{}` but '{}' is const", + format_kind(kind), + format_place(&place, env), + )), + place.loc, + )); + } + kind = Some(InstructionKind::Const); + } + } + } + let kind = + kind.ok_or_else(|| 
invariant_error("Expected at least one operand", None))?; + destructure_kind_locs.push((block_index, local_idx, kind)); + } + InstructionValue::PostfixUpdate { lvalue, .. } + | InstructionValue::PrefixUpdate { lvalue, .. } => { + let ident = &env.identifiers[lvalue.identifier.0 as usize]; + let decl_id = ident.declaration_id; + let Some(existing) = declarations.get(&decl_id) else { + return Err(invariant_error_with_loc( + "Expected variable to have been defined", + Some(format!("No declaration for {}", format_place(lvalue, env),)), + lvalue.loc, + )); + }; + match existing { + DeclarationLoc::Instruction { + block_index: bi, + instr_local_index: ili, + } => { + let_locs.push((*bi, *ili)); + } + DeclarationLoc::ParamOrContext => { + // Already Let + } + } + } + _ => {} + } + } + } + + // Phase 2: Apply all collected mutations. + + // Helper: given (block_index, instr_local_index), get the InstructionId + // and mutate the instruction's lvalue kind. + for (bi, ili) in const_locs { + let block_id = &block_keys[bi]; + let instr_id = func.body.blocks[block_id].instructions[ili]; + let instr = &mut func.instructions[instr_id.0 as usize]; + match &mut instr.value { + InstructionValue::StoreLocal { lvalue, .. } => { + lvalue.kind = InstructionKind::Const; + } + _ => {} + } + } + + for (bi, ili) in reassign_locs { + let block_id = &block_keys[bi]; + let instr_id = func.body.blocks[block_id].instructions[ili]; + let instr = &mut func.instructions[instr_id.0 as usize]; + match &mut instr.value { + InstructionValue::StoreLocal { lvalue, .. } => { + lvalue.kind = InstructionKind::Reassign; + } + _ => {} + } + } + + // Apply destructure_kind_locs BEFORE let_locs: a Destructure that first + // declares a variable gets kind=Const here, but if a later instruction + // reassigns that variable the Destructure must become Let. 
Applying + // let_locs afterwards allows it to override the Const set here, matching + // the TS behaviour where `declaration.kind = Let` mutates the original + // lvalue reference after the Destructure's own `lvalue.kind = kind`. + for (bi, ili, kind) in destructure_kind_locs { + let block_id = &block_keys[bi]; + let instr_id = func.body.blocks[block_id].instructions[ili]; + let instr = &mut func.instructions[instr_id.0 as usize]; + match &mut instr.value { + InstructionValue::Destructure { lvalue, .. } => { + lvalue.kind = kind; + } + _ => {} + } + } + + for (bi, ili) in let_locs { + let block_id = &block_keys[bi]; + let instr_id = func.body.blocks[block_id].instructions[ili]; + let instr = &mut func.instructions[instr_id.0 as usize]; + match &mut instr.value { + InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::StoreLocal { lvalue, .. } => { + lvalue.kind = InstructionKind::Let; + } + InstructionValue::Destructure { lvalue, .. } => { + lvalue.kind = InstructionKind::Let; + } + _ => {} + } + } + + Ok(()) +} diff --git a/crates/react_compiler_typeinference/Cargo.toml b/crates/react_compiler_typeinference/Cargo.toml new file mode 100644 index 000000000000..670d42c186ba --- /dev/null +++ b/crates/react_compiler_typeinference/Cargo.toml @@ -0,0 +1,12 @@ +[package] +description = "Vendored React Compiler type inference from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_typeinference" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +react_compiler_diagnostics = { path = "../react_compiler_diagnostics" } +react_compiler_hir = { path = "../react_compiler_hir" } +react_compiler_ssa = { path = "../react_compiler_ssa" } diff --git a/crates/react_compiler_typeinference/src/infer_types.rs b/crates/react_compiler_typeinference/src/infer_types.rs new file mode 100644 index 000000000000..435aac1fe514 --- /dev/null +++ b/crates/react_compiler_typeinference/src/infer_types.rs 
@@ -0,0 +1,1648 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Type inference pass. +//! +//! Generates type equations from the HIR, unifies them, and applies the +//! resolved types back to identifiers. Analogous to TS `InferTypes.ts`. + +use std::collections::HashMap; + +use react_compiler_diagnostics::{CompilerDiagnostic, ErrorCategory}; +use react_compiler_hir::{ + environment::{is_hook_name, Environment}, + object_shape::{ + ShapeRegistry, BUILT_IN_ARRAY_ID, BUILT_IN_FUNCTION_ID, BUILT_IN_JSX_ID, + BUILT_IN_MIXED_READONLY_ID, BUILT_IN_OBJECT_ID, BUILT_IN_PROPS_ID, BUILT_IN_REF_VALUE_ID, + BUILT_IN_SET_STATE_ID, BUILT_IN_USE_REF_ID, + }, + ArrayPatternElement, BinaryOperator, FunctionId, HirFunction, Identifier, IdentifierId, + IdentifierName, InstructionId, InstructionKind, InstructionValue, JsxAttribute, + LoweredFunction, ManualMemoDependencyRoot, NonLocalBinding, ObjectPropertyKey, + ObjectPropertyOrSpread, ParamPattern, Pattern, PropertyLiteral, PropertyNameKind, + ReactFunctionType, SourceLocation, Terminal, Type, TypeId, +}; +use react_compiler_ssa::enter_ssa::placeholder_function; + +// ============================================================================= +// Public API +// ============================================================================= + +pub fn infer_types( + func: &mut HirFunction, + env: &mut Environment, +) -> Result<(), CompilerDiagnostic> { + let enable_treat_ref_like_identifiers_as_refs = + env.config.enable_treat_ref_like_identifiers_as_refs; + let enable_treat_set_identifiers_as_state_setters = + env.config.enable_treat_set_identifiers_as_state_setters; + // Pre-compute custom hook type for property resolution fallback + let custom_hook_type = env.get_custom_hook_type_opt(); + let mut unifier = Unifier::new( + enable_treat_ref_like_identifiers_as_refs, + custom_hook_type, + 
enable_treat_set_identifiers_as_state_setters, + ); + generate(func, env, &mut unifier)?; + + apply_function( + func, + &env.functions, + &mut env.identifiers, + &mut env.types, + &mut unifier, + ); + Ok(()) +} + +// ============================================================================= +// Helpers +// ============================================================================= + +/// Get the type for an identifier as a TypeVar referencing its type slot. +fn get_type(id: IdentifierId, identifiers: &[Identifier]) -> Type { + let type_id = identifiers[id.0 as usize].type_; + Type::TypeVar { id: type_id } +} + +/// Allocate a new TypeVar in the types arena (standalone, no &mut Environment +/// needed). +fn make_type(types: &mut Vec) -> Type { + let id = TypeId(types.len() as u32); + types.push(Type::TypeVar { id }); + Type::TypeVar { id } +} + +/// Pre-resolve LoadGlobal types for a single function's instructions. +fn pre_resolve_globals( + func: &HirFunction, + function_key: u32, + env: &mut Environment, + global_types: &mut HashMap<(u32, InstructionId), Type>, +) { + for &instr_id in func.body.blocks.values().flat_map(|b| &b.instructions) { + let instr = &func.instructions[instr_id.0 as usize]; + if let InstructionValue::LoadGlobal { binding, loc, .. } = &instr.value { + if let Some(global_type) = env.get_global_declaration(binding, *loc).ok().flatten() { + global_types.insert((function_key, instr_id), global_type); + } + } + } +} + +/// Recursively pre-resolve LoadGlobal types for an inner function and its +/// children. +fn pre_resolve_globals_recursive( + func_id: FunctionId, + env: &mut Environment, + global_types: &mut HashMap<(u32, InstructionId), Type>, +) { + // Collect LoadGlobal bindings and child function IDs in one pass to avoid + // borrow conflicts (we need &env.functions to read, then &mut env for + // get_global_declaration). 
+ let inner = &env.functions[func_id.0 as usize]; + let mut load_globals: Vec<(InstructionId, NonLocalBinding, Option)> = + Vec::new(); + let mut child_func_ids: Vec = Vec::new(); + + for block in inner.body.blocks.values() { + for &instr_id in &block.instructions { + let instr = &inner.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::LoadGlobal { binding, loc, .. } => { + load_globals.push((instr_id, binding.clone(), *loc)); + } + InstructionValue::FunctionExpression { + lowered_func: LoweredFunction { func: fid }, + .. + } + | InstructionValue::ObjectMethod { + lowered_func: LoweredFunction { func: fid }, + .. + } => { + child_func_ids.push(*fid); + } + _ => {} + } + } + } + + // Now resolve globals (no longer borrowing env.functions) + for (instr_id, binding, loc) in load_globals { + if let Some(global_type) = env.get_global_declaration(&binding, loc).ok().flatten() { + global_types.insert((func_id.0, instr_id), global_type); + } + } + + // Recurse into child functions + for child_id in child_func_ids { + pre_resolve_globals_recursive(child_id, env, global_types); + } +} + +fn is_primitive_binary_op(op: &BinaryOperator) -> bool { + matches!( + op, + BinaryOperator::Add + | BinaryOperator::Subtract + | BinaryOperator::Divide + | BinaryOperator::Modulo + | BinaryOperator::Multiply + | BinaryOperator::Exponent + | BinaryOperator::BitwiseAnd + | BinaryOperator::BitwiseOr + | BinaryOperator::ShiftRight + | BinaryOperator::ShiftLeft + | BinaryOperator::BitwiseXor + | BinaryOperator::GreaterThan + | BinaryOperator::LessThan + | BinaryOperator::GreaterEqual + | BinaryOperator::LessEqual + ) +} + +/// Resolve a property type from the shapes registry. +/// If `custom_hook_type` is provided and the property name looks like a hook, +/// it will be used as a fallback when no matching property is found (matching +/// TS `getPropertyType` behavior). 
+fn resolve_property_type( + shapes: &ShapeRegistry, + resolved_object: &Type, + property_name: &PropertyNameKind, + custom_hook_type: Option<&Type>, +) -> Option { + let shape_id = match resolved_object { + Type::Object { shape_id } | Type::Function { shape_id, .. } => shape_id.as_deref(), + _ => { + // No shape, but if property name is hook-like, return hook type + if let Some(hook_type) = custom_hook_type { + if let PropertyNameKind::Literal { + value: PropertyLiteral::String(s), + } = property_name + { + if is_hook_name(s) { + return Some(hook_type.clone()); + } + } + } + return None; + } + }; + let shape_id = match shape_id { + Some(id) => id, + None => { + // Object/Function with no shapeId: TS getPropertyType falls through + // to hook-name check, TS getFallthroughPropertyType returns null + if let PropertyNameKind::Literal { + value: PropertyLiteral::String(s), + } = property_name + { + if is_hook_name(s) { + return custom_hook_type.cloned(); + } + } + return None; + } + }; + let shape = shapes.get(shape_id)?; + + match property_name { + PropertyNameKind::Literal { value } => match value { + PropertyLiteral::String(s) => shape + .properties + .get(s.as_str()) + .or_else(|| shape.properties.get("*")) + .cloned() + // Hook-name fallback: if property is not found in shape but looks + // like a hook name, return the custom hook type + .or_else(|| { + if is_hook_name(s) { + custom_hook_type.cloned() + } else { + None + } + }), + PropertyLiteral::Number(_) => shape.properties.get("*").cloned(), + }, + PropertyNameKind::Computed { .. } => shape.properties.get("*").cloned(), + } +} + +/// Check if a property access looks like a ref pattern (e.g. `ref.current`, +/// `fooRef.current`). Matches TS `isRefLikeName` in InferTypes.ts. 
+fn is_ref_like_name(object_name: &str, property_name: &PropertyNameKind) -> bool { + let is_current = match property_name { + PropertyNameKind::Literal { + value: PropertyLiteral::String(s), + } => s == "current", + _ => false, + }; + if !is_current { + return false; + } + // Match TS regex: /^(?:[a-zA-Z$_][a-zA-Z$_0-9]*)Ref$|^ref$/ + // "Ref" alone does NOT match — requires at least one character before "Ref" + // (e.g., "fooRef", "aRef" match, but bare "Ref" does not). + object_name == "ref" + || (object_name.len() > 3 + && object_name.ends_with("Ref") + && object_name[..1] + .chars() + .next() + .is_some_and(|c| c.is_ascii_alphabetic() || c == '$' || c == '_')) +} + +/// Type equality matching TS `typeEquals`. +/// +/// Note: Function equality only compares return types (matching TS +/// `funcTypeEquals` which ignores `shapeId` and `isConstructor`). Phi equality +/// always returns false because the TS `phiTypeEquals` has a bug where `return +/// false` is outside the `if` block, so it unconditionally returns false. +fn type_equals(a: &Type, b: &Type) -> bool { + match (a, b) { + (Type::TypeVar { id: id_a }, Type::TypeVar { id: id_b }) => id_a == id_b, + (Type::Primitive, Type::Primitive) => true, + (Type::Poly, Type::Poly) => true, + (Type::ObjectMethod, Type::ObjectMethod) => true, + (Type::Object { shape_id: sa }, Type::Object { shape_id: sb }) => sa == sb, + ( + Type::Function { + return_type: ra, .. + }, + Type::Function { + return_type: rb, .. 
+ }, + ) => type_equals(ra, rb), + _ => false, + } +} + +fn set_name(names: &mut HashMap, id: IdentifierId, source: &Identifier) { + if let Some(IdentifierName::Named(ref name)) = source.name { + names.insert(id, name.clone()); + } +} + +fn get_name(names: &HashMap, id: IdentifierId) -> String { + names.get(&id).cloned().unwrap_or_default() +} + +// ============================================================================= +// Generate equations +// ============================================================================= + +/// Generate type equations from a top-level function. +/// +/// Takes `&mut Environment` for convenience. Inner functions use +/// `generate_for_function_id` with split borrows instead, because the +/// take/replace pattern on `env.functions` requires separate `&mut` access +/// to different fields. +fn generate( + func: &HirFunction, + env: &mut Environment, + unifier: &mut Unifier, +) -> Result<(), CompilerDiagnostic> { + // Component params + if func.fn_type == ReactFunctionType::Component { + if let Some(first) = func.params.first() { + if let ParamPattern::Place(place) = first { + let ty = get_type(place.identifier, &env.identifiers); + unifier.unify( + ty, + Type::Object { + shape_id: Some(BUILT_IN_PROPS_ID.to_string()), + }, + &env.shapes, + )?; + } + } + if let Some(second) = func.params.get(1) { + if let ParamPattern::Place(place) = second { + let ty = get_type(place.identifier, &env.identifiers); + unifier.unify( + ty, + Type::Object { + shape_id: Some(BUILT_IN_USE_REF_ID.to_string()), + }, + &env.shapes, + )?; + } + } + } + + // Pre-resolve LoadGlobal types for all functions (outer + inner). We do + // this before the instruction loop because get_global_declaration needs + // &mut env, but generate_instruction_types takes split borrows on env fields. + // The key is (function_key, InstructionId) where function_key is u32::MAX + // for the outer function and FunctionId.0 for inner functions. 
+ let mut global_types: HashMap<(u32, InstructionId), Type> = HashMap::new(); + pre_resolve_globals(func, u32::MAX, env, &mut global_types); + // Also pre-resolve inner functions recursively + for &instr_id in func.body.blocks.values().flat_map(|b| &b.instructions) { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::FunctionExpression { + lowered_func: LoweredFunction { func: func_id }, + .. + } + | InstructionValue::ObjectMethod { + lowered_func: LoweredFunction { func: func_id }, + .. + } => { + pre_resolve_globals_recursive(*func_id, env, &mut global_types); + } + _ => {} + } + } + + let mut names: HashMap = HashMap::new(); + let mut return_types: Vec = Vec::new(); + + for (_block_id, block) in &func.body.blocks { + // Phis + for phi in &block.phis { + let left = get_type(phi.place.identifier, &env.identifiers); + let operands: Vec = phi + .operands + .values() + .map(|p| get_type(p.identifier, &env.identifiers)) + .collect(); + unifier.unify(left, Type::Phi { operands }, &env.shapes)?; + } + + // Instructions — use split borrows: &env.identifiers, &env.shapes + // are immutable, while &mut env.types and &mut env.functions are mutable. + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + generate_instruction_types( + instr, + instr_id, + u32::MAX, + &env.identifiers, + &mut env.types, + &mut env.functions, + &mut names, + &global_types, + &env.shapes, + unifier, + )?; + } + + // Return terminals + if let Terminal::Return { ref value, .. 
} = block.terminal { + return_types.push(get_type(value.identifier, &env.identifiers)); + } + } + + // Unify return types + let returns_type = get_type(func.returns.identifier, &env.identifiers); + if return_types.len() > 1 { + unifier.unify( + returns_type, + Type::Phi { + operands: return_types, + }, + &env.shapes, + )?; + } else if return_types.len() == 1 { + unifier.unify( + returns_type, + return_types.into_iter().next().unwrap(), + &env.shapes, + )?; + } + Ok(()) +} + +/// Recursively generate equations for an inner function (accessed via +/// FunctionId). +fn generate_for_function_id( + func_id: FunctionId, + identifiers: &[Identifier], + types: &mut Vec, + functions: &mut Vec, + global_types: &HashMap<(u32, InstructionId), Type>, + shapes: &ShapeRegistry, + unifier: &mut Unifier, +) -> Result<(), CompilerDiagnostic> { + // Take the function out temporarily to avoid borrow conflicts + let inner = std::mem::replace(&mut functions[func_id.0 as usize], placeholder_function()); + + // Process params for component inner functions + if inner.fn_type == ReactFunctionType::Component { + if let Some(first) = inner.params.first() { + if let ParamPattern::Place(place) = first { + let ty = get_type(place.identifier, identifiers); + unifier.unify( + ty, + Type::Object { + shape_id: Some(BUILT_IN_PROPS_ID.to_string()), + }, + shapes, + )?; + } + } + if let Some(second) = inner.params.get(1) { + if let ParamPattern::Place(place) = second { + let ty = get_type(place.identifier, identifiers); + unifier.unify( + ty, + Type::Object { + shape_id: Some(BUILT_IN_USE_REF_ID.to_string()), + }, + shapes, + )?; + } + } + } + + // TS creates a fresh `names` Map per recursive `generate` call, so inner + // functions don't inherit or pollute the outer function's name mappings. 
+ let mut inner_names: HashMap = HashMap::new(); + let mut inner_return_types: Vec = Vec::new(); + + for (_block_id, block) in &inner.body.blocks { + for phi in &block.phis { + let left = get_type(phi.place.identifier, identifiers); + let operands: Vec = phi + .operands + .values() + .map(|p| get_type(p.identifier, identifiers)) + .collect(); + unifier.unify(left, Type::Phi { operands }, shapes)?; + } + + for &instr_id in &block.instructions { + let instr = &inner.instructions[instr_id.0 as usize]; + generate_instruction_types( + instr, + instr_id, + func_id.0, + identifiers, + types, + functions, + &mut inner_names, + global_types, + shapes, + unifier, + )?; + } + + if let Terminal::Return { ref value, .. } = block.terminal { + inner_return_types.push(get_type(value.identifier, identifiers)); + } + } + + let returns_type = get_type(inner.returns.identifier, identifiers); + if inner_return_types.len() > 1 { + unifier.unify( + returns_type, + Type::Phi { + operands: inner_return_types, + }, + shapes, + )?; + } else if inner_return_types.len() == 1 { + unifier.unify( + returns_type, + inner_return_types.into_iter().next().unwrap(), + shapes, + )?; + } + + // Put the function back + functions[func_id.0 as usize] = inner; + Ok(()) +} + +fn generate_instruction_types( + instr: &react_compiler_hir::Instruction, + instr_id: InstructionId, + function_key: u32, + identifiers: &[Identifier], + types: &mut Vec, + functions: &mut Vec, + names: &mut HashMap, + global_types: &HashMap<(u32, InstructionId), Type>, + shapes: &ShapeRegistry, + unifier: &mut Unifier, +) -> Result<(), CompilerDiagnostic> { + let left = get_type(instr.lvalue.identifier, identifiers); + + match &instr.value { + InstructionValue::TemplateLiteral { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::Primitive { .. } => { + unifier.unify(left, Type::Primitive, shapes)?; + } + + InstructionValue::UnaryExpression { .. 
} => { + unifier.unify(left, Type::Primitive, shapes)?; + } + + InstructionValue::LoadLocal { place, .. } => { + set_name( + names, + instr.lvalue.identifier, + &identifiers[place.identifier.0 as usize], + ); + let place_type = get_type(place.identifier, identifiers); + unifier.unify(left, place_type, shapes)?; + } + + InstructionValue::DeclareContext { .. } | InstructionValue::LoadContext { .. } => { + // Intentionally skip type inference for most context variables + } + + InstructionValue::StoreContext { lvalue, value, .. } => { + if lvalue.kind == InstructionKind::Const { + let lvalue_type = get_type(lvalue.place.identifier, identifiers); + let value_type = get_type(value.identifier, identifiers); + unifier.unify(lvalue_type, value_type, shapes)?; + } + } + + InstructionValue::StoreLocal { lvalue, value, .. } => { + let value_type = get_type(value.identifier, identifiers); + unifier.unify(left, value_type.clone(), shapes)?; + let lvalue_type = get_type(lvalue.place.identifier, identifiers); + unifier.unify(lvalue_type, value_type, shapes)?; + } + + InstructionValue::StoreGlobal { value, .. } => { + let value_type = get_type(value.identifier, identifiers); + unifier.unify(left, value_type, shapes)?; + } + + InstructionValue::BinaryExpression { + operator, + left: bin_left, + right: bin_right, + .. + } => { + if is_primitive_binary_op(operator) { + let left_operand_type = get_type(bin_left.identifier, identifiers); + unifier.unify(left_operand_type, Type::Primitive, shapes)?; + let right_operand_type = get_type(bin_right.identifier, identifiers); + unifier.unify(right_operand_type, Type::Primitive, shapes)?; + } + unifier.unify(left, Type::Primitive, shapes)?; + } + + InstructionValue::PostfixUpdate { value, lvalue, .. } + | InstructionValue::PrefixUpdate { value, lvalue, .. 
} => { + let value_type = get_type(value.identifier, identifiers); + unifier.unify(value_type, Type::Primitive, shapes)?; + let lvalue_type = get_type(lvalue.identifier, identifiers); + unifier.unify(lvalue_type, Type::Primitive, shapes)?; + unifier.unify(left, Type::Primitive, shapes)?; + } + + InstructionValue::LoadGlobal { .. } => { + // Type was pre-resolved in generate() via env.get_global_declaration() + if let Some(global_type) = global_types.get(&(function_key, instr_id)) { + unifier.unify(left, global_type.clone(), shapes)?; + } + } + + InstructionValue::CallExpression { callee, .. } => { + let return_type = make_type(types); + let mut shape_id = None; + if unifier.enable_treat_set_identifiers_as_state_setters { + let name = get_name(names, callee.identifier); + if name.starts_with("set") { + shape_id = Some(BUILT_IN_SET_STATE_ID.to_string()); + } + } + let callee_type = get_type(callee.identifier, identifiers); + unifier.unify( + callee_type, + Type::Function { + shape_id, + return_type: Box::new(return_type.clone()), + is_constructor: false, + }, + shapes, + )?; + unifier.unify(left, return_type, shapes)?; + } + + InstructionValue::TaggedTemplateExpression { tag, .. } => { + let return_type = make_type(types); + let tag_type = get_type(tag.identifier, identifiers); + unifier.unify( + tag_type, + Type::Function { + shape_id: None, + return_type: Box::new(return_type.clone()), + is_constructor: false, + }, + shapes, + )?; + unifier.unify(left, return_type, shapes)?; + } + + InstructionValue::ObjectExpression { properties, .. 
} => { + for prop in properties { + if let ObjectPropertyOrSpread::Property(obj_prop) = prop { + if let ObjectPropertyKey::Computed { name } = &obj_prop.key { + let name_type = get_type(name.identifier, identifiers); + unifier.unify(name_type, Type::Primitive, shapes)?; + } + } + } + unifier.unify( + left, + Type::Object { + shape_id: Some(BUILT_IN_OBJECT_ID.to_string()), + }, + shapes, + )?; + } + + InstructionValue::ArrayExpression { .. } => { + unifier.unify( + left, + Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + shapes, + )?; + } + + InstructionValue::PropertyLoad { + object, property, .. + } => { + let object_type = get_type(object.identifier, identifiers); + let object_name = get_name(names, object.identifier); + unifier.unify( + left, + Type::Property { + object_type: Box::new(object_type), + object_name, + property_name: PropertyNameKind::Literal { + value: property.clone(), + }, + }, + shapes, + )?; + } + + InstructionValue::ComputedLoad { + object, property, .. + } => { + let object_type = get_type(object.identifier, identifiers); + let object_name = get_name(names, object.identifier); + let prop_type = get_type(property.identifier, identifiers); + unifier.unify( + left, + Type::Property { + object_type: Box::new(object_type), + object_name, + property_name: PropertyNameKind::Computed { + value: Box::new(prop_type), + }, + }, + shapes, + )?; + } + + InstructionValue::MethodCall { property, .. } => { + let return_type = make_type(types); + let prop_type = get_type(property.identifier, identifiers); + unifier.unify( + prop_type, + Type::Function { + return_type: Box::new(return_type.clone()), + shape_id: None, + is_constructor: false, + }, + shapes, + )?; + unifier.unify(left, return_type, shapes)?; + } + + InstructionValue::Destructure { lvalue, value, .. 
} => match &lvalue.pattern { + Pattern::Array(array_pattern) => { + for (i, item) in array_pattern.items.iter().enumerate() { + match item { + ArrayPatternElement::Place(place) => { + let item_type = get_type(place.identifier, identifiers); + let value_type = get_type(value.identifier, identifiers); + let object_name = get_name(names, value.identifier); + unifier.unify( + item_type, + Type::Property { + object_type: Box::new(value_type), + object_name, + property_name: PropertyNameKind::Literal { + value: PropertyLiteral::String(i.to_string()), + }, + }, + shapes, + )?; + } + ArrayPatternElement::Spread(spread) => { + let spread_type = get_type(spread.place.identifier, identifiers); + unifier.unify( + spread_type, + Type::Object { + shape_id: Some(BUILT_IN_ARRAY_ID.to_string()), + }, + shapes, + )?; + } + ArrayPatternElement::Hole => { + continue; + } + } + } + } + Pattern::Object(object_pattern) => { + for prop in &object_pattern.properties { + if let ObjectPropertyOrSpread::Property(obj_prop) = prop { + match &obj_prop.key { + ObjectPropertyKey::Identifier { name } + | ObjectPropertyKey::String { name } => { + let prop_place_type = + get_type(obj_prop.place.identifier, identifiers); + let value_type = get_type(value.identifier, identifiers); + let object_name = get_name(names, value.identifier); + unifier.unify( + prop_place_type, + Type::Property { + object_type: Box::new(value_type), + object_name, + property_name: PropertyNameKind::Literal { + value: PropertyLiteral::String(name.clone()), + }, + }, + shapes, + )?; + } + _ => {} + } + } + } + } + }, + + InstructionValue::TypeCastExpression { value, .. } => { + let value_type = get_type(value.identifier, identifiers); + unifier.unify(left, value_type, shapes)?; + } + + InstructionValue::PropertyDelete { .. } | InstructionValue::ComputedDelete { .. } => { + unifier.unify(left, Type::Primitive, shapes)?; + } + + InstructionValue::FunctionExpression { + lowered_func: LoweredFunction { func: func_id }, + .. 
+ } => { + // Recurse into inner function first + generate_for_function_id( + *func_id, + identifiers, + types, + functions, + global_types, + shapes, + unifier, + )?; + // Get the inner function's return type + let inner_func = &functions[func_id.0 as usize]; + let inner_return_type = get_type(inner_func.returns.identifier, identifiers); + unifier.unify( + left, + Type::Function { + shape_id: Some(BUILT_IN_FUNCTION_ID.to_string()), + return_type: Box::new(inner_return_type), + is_constructor: false, + }, + shapes, + )?; + } + + InstructionValue::NextPropertyOf { .. } => { + unifier.unify(left, Type::Primitive, shapes)?; + } + + InstructionValue::ObjectMethod { + lowered_func: LoweredFunction { func: func_id }, + .. + } => { + generate_for_function_id( + *func_id, + identifiers, + types, + functions, + global_types, + shapes, + unifier, + )?; + unifier.unify(left, Type::ObjectMethod, shapes)?; + } + + InstructionValue::JsxExpression { props, .. } => { + if unifier.enable_treat_ref_like_identifiers_as_refs { + for prop in props { + if let JsxAttribute::Attribute { name, place } = prop { + if name == "ref" { + let ref_type = get_type(place.identifier, identifiers); + unifier.unify( + ref_type, + Type::Object { + shape_id: Some(BUILT_IN_USE_REF_ID.to_string()), + }, + shapes, + )?; + } + } + } + } + unifier.unify( + left, + Type::Object { + shape_id: Some(BUILT_IN_JSX_ID.to_string()), + }, + shapes, + )?; + } + + InstructionValue::JsxFragment { .. } => { + unifier.unify( + left, + Type::Object { + shape_id: Some(BUILT_IN_JSX_ID.to_string()), + }, + shapes, + )?; + } + + InstructionValue::NewExpression { callee, .. 
} => { + let return_type = make_type(types); + let callee_type = get_type(callee.identifier, identifiers); + unifier.unify( + callee_type, + Type::Function { + return_type: Box::new(return_type.clone()), + shape_id: None, + is_constructor: true, + }, + shapes, + )?; + unifier.unify(left, return_type, shapes)?; + } + + InstructionValue::PropertyStore { + object, property, .. + } => { + let dummy = make_type(types); + let object_type = get_type(object.identifier, identifiers); + let object_name = get_name(names, object.identifier); + unifier.unify( + dummy, + Type::Property { + object_type: Box::new(object_type), + object_name, + property_name: PropertyNameKind::Literal { + value: property.clone(), + }, + }, + shapes, + )?; + } + + InstructionValue::DeclareLocal { .. } + | InstructionValue::RegExpLiteral { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::ComputedStore { .. } + | InstructionValue::Await { .. } + | InstructionValue::GetIterator { .. } + | InstructionValue::IteratorNext { .. } + | InstructionValue::UnsupportedNode { .. } + | InstructionValue::Debugger { .. } + | InstructionValue::FinishMemoize { .. } => { + // No type equations for these + } + + InstructionValue::StartMemoize { .. 
} => { + // No type equations for StartMemoize itself + } + } + Ok(()) +} + +// ============================================================================= +// Apply resolved types +// ============================================================================= + +fn apply_function( + func: &HirFunction, + functions: &[HirFunction], + identifiers: &mut [Identifier], + types: &mut Vec, + unifier: &Unifier, +) { + for (_block_id, block) in &func.body.blocks { + // Phi places + for phi in &block.phis { + resolve_identifier(phi.place.identifier, identifiers, types, unifier); + } + + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + + // Instruction lvalue + resolve_identifier(instr.lvalue.identifier, identifiers, types, unifier); + + // LValues from instruction values (StoreLocal, StoreContext, DeclareLocal, + // DeclareContext, Destructure) + apply_instruction_lvalues(&instr.value, identifiers, types, unifier); + + // Operands + apply_instruction_operands(&instr.value, identifiers, types, unifier); + + // Recurse into inner functions + match &instr.value { + InstructionValue::FunctionExpression { + lowered_func: LoweredFunction { func: func_id }, + .. + } + | InstructionValue::ObjectMethod { + lowered_func: LoweredFunction { func: func_id }, + .. 
+ } => { + let inner_func = &functions[func_id.0 as usize]; + // Resolve types for captured context variable places (matching TS + // where eachInstructionValueOperand yields func.context places) + for ctx in &inner_func.context { + resolve_identifier(ctx.identifier, identifiers, types, unifier); + } + apply_function(inner_func, functions, identifiers, types, unifier); + } + _ => {} + } + } + } + + // Resolve return type + resolve_identifier(func.returns.identifier, identifiers, types, unifier); +} + +fn resolve_identifier( + id: IdentifierId, + identifiers: &mut [Identifier], + types: &mut Vec, + unifier: &Unifier, +) { + let type_id = identifiers[id.0 as usize].type_; + let current_type = types[type_id.0 as usize].clone(); + let resolved = unifier.get(&current_type); + types[type_id.0 as usize] = resolved; +} + +/// Resolve types for instruction lvalues (mirrors TS eachInstructionLValue). +fn apply_instruction_lvalues( + value: &InstructionValue, + identifiers: &mut [Identifier], + types: &mut Vec, + unifier: &Unifier, +) { + match value { + InstructionValue::StoreLocal { lvalue, .. } + | InstructionValue::StoreContext { lvalue, .. } => { + resolve_identifier(lvalue.place.identifier, identifiers, types, unifier); + } + InstructionValue::DeclareLocal { lvalue, .. } + | InstructionValue::DeclareContext { lvalue, .. } => { + resolve_identifier(lvalue.place.identifier, identifiers, types, unifier); + } + InstructionValue::Destructure { lvalue, .. 
} => match &lvalue.pattern { + Pattern::Array(array_pattern) => { + for item in &array_pattern.items { + match item { + ArrayPatternElement::Place(place) => { + resolve_identifier(place.identifier, identifiers, types, unifier); + } + ArrayPatternElement::Spread(spread) => { + resolve_identifier( + spread.place.identifier, + identifiers, + types, + unifier, + ); + } + ArrayPatternElement::Hole => {} + } + } + } + Pattern::Object(object_pattern) => { + for prop in &object_pattern.properties { + match prop { + ObjectPropertyOrSpread::Property(obj_prop) => { + resolve_identifier( + obj_prop.place.identifier, + identifiers, + types, + unifier, + ); + } + ObjectPropertyOrSpread::Spread(spread) => { + resolve_identifier( + spread.place.identifier, + identifiers, + types, + unifier, + ); + } + } + } + } + }, + _ => {} + } +} + +/// Resolve types for instruction operands (mirrors TS eachInstructionOperand). +fn apply_instruction_operands( + value: &InstructionValue, + identifiers: &mut [Identifier], + types: &mut Vec, + unifier: &Unifier, +) { + match value { + InstructionValue::LoadLocal { place, .. } | InstructionValue::LoadContext { place, .. } => { + resolve_identifier(place.identifier, identifiers, types, unifier); + } + InstructionValue::StoreLocal { value: val, .. } => { + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::StoreContext { value: val, .. } => { + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::StoreGlobal { value: val, .. } => { + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::Destructure { value: val, .. } => { + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::BinaryExpression { left, right, .. } => { + resolve_identifier(left.identifier, identifiers, types, unifier); + resolve_identifier(right.identifier, identifiers, types, unifier); + } + InstructionValue::UnaryExpression { value: val, .. 
} => { + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::TypeCastExpression { value: val, .. } => { + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::CallExpression { callee, args, .. } => { + resolve_identifier(callee.identifier, identifiers, types, unifier); + for arg in args { + match arg { + react_compiler_hir::PlaceOrSpread::Place(p) => { + resolve_identifier(p.identifier, identifiers, types, unifier); + } + react_compiler_hir::PlaceOrSpread::Spread(s) => { + resolve_identifier(s.place.identifier, identifiers, types, unifier); + } + } + } + } + InstructionValue::MethodCall { + receiver, + property, + args, + .. + } => { + resolve_identifier(receiver.identifier, identifiers, types, unifier); + resolve_identifier(property.identifier, identifiers, types, unifier); + for arg in args { + match arg { + react_compiler_hir::PlaceOrSpread::Place(p) => { + resolve_identifier(p.identifier, identifiers, types, unifier); + } + react_compiler_hir::PlaceOrSpread::Spread(s) => { + resolve_identifier(s.place.identifier, identifiers, types, unifier); + } + } + } + } + InstructionValue::NewExpression { callee, args, .. } => { + resolve_identifier(callee.identifier, identifiers, types, unifier); + for arg in args { + match arg { + react_compiler_hir::PlaceOrSpread::Place(p) => { + resolve_identifier(p.identifier, identifiers, types, unifier); + } + react_compiler_hir::PlaceOrSpread::Spread(s) => { + resolve_identifier(s.place.identifier, identifiers, types, unifier); + } + } + } + } + InstructionValue::TaggedTemplateExpression { tag, .. } => { + resolve_identifier(tag.identifier, identifiers, types, unifier); + // The template quasi's subexpressions are not separate operands in + // this HIR + } + InstructionValue::PropertyLoad { object, .. } => { + resolve_identifier(object.identifier, identifiers, types, unifier); + } + InstructionValue::PropertyStore { + object, value: val, .. 
+ } => { + resolve_identifier(object.identifier, identifiers, types, unifier); + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::PropertyDelete { object, .. } => { + resolve_identifier(object.identifier, identifiers, types, unifier); + } + InstructionValue::ComputedLoad { + object, property, .. + } => { + resolve_identifier(object.identifier, identifiers, types, unifier); + resolve_identifier(property.identifier, identifiers, types, unifier); + } + InstructionValue::ComputedStore { + object, + property, + value: val, + .. + } => { + resolve_identifier(object.identifier, identifiers, types, unifier); + resolve_identifier(property.identifier, identifiers, types, unifier); + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::ComputedDelete { + object, property, .. + } => { + resolve_identifier(object.identifier, identifiers, types, unifier); + resolve_identifier(property.identifier, identifiers, types, unifier); + } + InstructionValue::ObjectExpression { properties, .. } => { + for prop in properties { + match prop { + ObjectPropertyOrSpread::Property(obj_prop) => { + resolve_identifier(obj_prop.place.identifier, identifiers, types, unifier); + if let ObjectPropertyKey::Computed { name } = &obj_prop.key { + resolve_identifier(name.identifier, identifiers, types, unifier); + } + } + ObjectPropertyOrSpread::Spread(spread) => { + resolve_identifier(spread.place.identifier, identifiers, types, unifier); + } + } + } + } + InstructionValue::ArrayExpression { elements, .. } => { + for elem in elements { + match elem { + react_compiler_hir::ArrayElement::Place(p) => { + resolve_identifier(p.identifier, identifiers, types, unifier); + } + react_compiler_hir::ArrayElement::Spread(s) => { + resolve_identifier(s.place.identifier, identifiers, types, unifier); + } + react_compiler_hir::ArrayElement::Hole => {} + } + } + } + InstructionValue::JsxExpression { + tag, + props, + children, + .. 
+ } => { + if let react_compiler_hir::JsxTag::Place(p) = tag { + resolve_identifier(p.identifier, identifiers, types, unifier); + } + for attr in props { + match attr { + JsxAttribute::Attribute { place, .. } => { + resolve_identifier(place.identifier, identifiers, types, unifier); + } + JsxAttribute::SpreadAttribute { argument } => { + resolve_identifier(argument.identifier, identifiers, types, unifier); + } + } + } + if let Some(children) = children { + for child in children { + resolve_identifier(child.identifier, identifiers, types, unifier); + } + } + } + InstructionValue::JsxFragment { children, .. } => { + for child in children { + resolve_identifier(child.identifier, identifiers, types, unifier); + } + } + InstructionValue::FunctionExpression { .. } | InstructionValue::ObjectMethod { .. } => { + // Inner functions are handled separately via recursion in + // apply_function + } + InstructionValue::TemplateLiteral { subexprs, .. } => { + for sub in subexprs { + resolve_identifier(sub.identifier, identifiers, types, unifier); + } + } + InstructionValue::PrefixUpdate { + value: val, lvalue, .. + } + | InstructionValue::PostfixUpdate { + value: val, lvalue, .. + } => { + resolve_identifier(val.identifier, identifiers, types, unifier); + resolve_identifier(lvalue.identifier, identifiers, types, unifier); + } + InstructionValue::Await { value: val, .. } => { + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::GetIterator { collection, .. } => { + resolve_identifier(collection.identifier, identifiers, types, unifier); + } + InstructionValue::IteratorNext { + iterator, + collection, + .. + } => { + resolve_identifier(iterator.identifier, identifiers, types, unifier); + resolve_identifier(collection.identifier, identifiers, types, unifier); + } + InstructionValue::NextPropertyOf { value: val, .. } => { + resolve_identifier(val.identifier, identifiers, types, unifier); + } + InstructionValue::FinishMemoize { decl, .. 
} => { + resolve_identifier(decl.identifier, identifiers, types, unifier); + } + InstructionValue::StartMemoize { deps, .. } => { + // Resolve types for deps with NamedLocal kind (matching TS + // eachInstructionOperand which yields dep.root.value for NamedLocal deps) + if let Some(deps) = deps { + for dep in deps { + if let ManualMemoDependencyRoot::NamedLocal { value, .. } = &dep.root { + resolve_identifier(value.identifier, identifiers, types, unifier); + } + } + } + } + InstructionValue::Primitive { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::LoadGlobal { .. } + | InstructionValue::DeclareLocal { .. } + | InstructionValue::DeclareContext { .. } + | InstructionValue::RegExpLiteral { .. } + | InstructionValue::MetaProperty { .. } + | InstructionValue::Debugger { .. } + | InstructionValue::UnsupportedNode { .. } => { + // No operand places + } + } +} + +// ============================================================================= +// Unifier +// ============================================================================= + +struct Unifier { + substitutions: HashMap, + enable_treat_ref_like_identifiers_as_refs: bool, + enable_treat_set_identifiers_as_state_setters: bool, + custom_hook_type: Option, +} + +impl Unifier { + fn new( + enable_treat_ref_like_identifiers_as_refs: bool, + custom_hook_type: Option, + enable_treat_set_identifiers_as_state_setters: bool, + ) -> Self { + Unifier { + substitutions: HashMap::new(), + enable_treat_ref_like_identifiers_as_refs, + enable_treat_set_identifiers_as_state_setters, + custom_hook_type, + } + } + + fn unify( + &mut self, + t_a: Type, + t_b: Type, + shapes: &ShapeRegistry, + ) -> Result<(), CompilerDiagnostic> { + self.unify_impl(t_a, t_b, shapes) + } + + fn unify_impl( + &mut self, + t_a: Type, + t_b: Type, + shapes: &ShapeRegistry, + ) -> Result<(), CompilerDiagnostic> { + // Handle Property in the RHS position + if let Type::Property { + ref object_type, + ref object_name, + ref property_name, + 
} = t_b + { + // Check enableTreatRefLikeIdentifiersAsRefs + if self.enable_treat_ref_like_identifiers_as_refs + && is_ref_like_name(object_name, property_name) + { + self.unify_impl( + *object_type.clone(), + Type::Object { + shape_id: Some(BUILT_IN_USE_REF_ID.to_string()), + }, + shapes, + )?; + self.unify_impl( + t_a, + Type::Object { + shape_id: Some(BUILT_IN_REF_VALUE_ID.to_string()), + }, + shapes, + )?; + return Ok(()); + } + + // Resolve property type via the shapes registry + let resolved_object = self.get(object_type); + let property_type = resolve_property_type( + shapes, + &resolved_object, + property_name, + self.custom_hook_type.as_ref(), + ); + if let Some(property_type) = property_type { + self.unify_impl(t_a, property_type, shapes)?; + } + return Ok(()); + } + + if type_equals(&t_a, &t_b) { + return Ok(()); + } + + if let Type::TypeVar { .. } = &t_a { + self.bind_variable_to(t_a, t_b, shapes)?; + return Ok(()); + } + + if let Type::TypeVar { .. } = &t_b { + self.bind_variable_to(t_b, t_a, shapes)?; + return Ok(()); + } + + if let ( + Type::Function { + return_type: ret_a, + is_constructor: con_a, + .. + }, + Type::Function { + return_type: ret_b, + is_constructor: con_b, + .. 
+ }, + ) = (&t_a, &t_b) + { + if con_a == con_b { + self.unify_impl(*ret_a.clone(), *ret_b.clone(), shapes)?; + } + } + Ok(()) + } + + fn bind_variable_to( + &mut self, + v: Type, + ty: Type, + shapes: &ShapeRegistry, + ) -> Result<(), CompilerDiagnostic> { + let v_id = match &v { + Type::TypeVar { id } => *id, + _ => return Ok(()), + }; + + if let Type::Poly = &ty { + // Ignore PolyType + return Ok(()); + } + + if let Some(existing) = self.substitutions.get(&v_id).cloned() { + self.unify_impl(existing, ty, shapes)?; + return Ok(()); + } + + if let Type::TypeVar { id: ty_id } = &ty { + if let Some(existing) = self.substitutions.get(ty_id).cloned() { + self.unify_impl(v, existing, shapes)?; + return Ok(()); + } + } + + if let Type::Phi { ref operands } = ty { + if operands.is_empty() { + return Err(CompilerDiagnostic { + category: ErrorCategory::Invariant, + reason: "there should be at least one operand".to_string(), + description: None, + details: vec![], + suggestions: None, + }); + } + + let mut candidate_type: Option = None; + for operand in operands { + let resolved = self.get(operand); + match &candidate_type { + None => { + candidate_type = Some(resolved); + } + Some(candidate) => { + if !type_equals(&resolved, candidate) { + let union_type = try_union_types(&resolved, candidate); + if let Some(union) = union_type { + candidate_type = Some(union); + } else { + candidate_type = None; + break; + } + } + // else same type, continue + } + } + } + + if let Some(candidate) = candidate_type { + self.unify_impl(v, candidate, shapes)?; + return Ok(()); + } + } + + if self.occurs_check(&v, &ty) { + let resolved_type = self.try_resolve_type(&v, &ty); + if let Some(resolved) = resolved_type { + self.substitutions.insert(v_id, resolved); + return Ok(()); + } + return Err(CompilerDiagnostic { + category: ErrorCategory::Invariant, + reason: "cycle detected".to_string(), + description: None, + details: vec![], + suggestions: None, + }); + } + + 
self.substitutions.insert(v_id, ty); + Ok(()) + } + + fn try_resolve_type(&mut self, v: &Type, ty: &Type) -> Option { + match ty { + Type::Phi { operands } => { + let mut new_operands = Vec::new(); + for operand in operands { + if let Type::TypeVar { id } = operand { + if let Type::TypeVar { id: v_id } = v { + if id == v_id { + continue; // skip self-reference + } + } + } + let resolved = self.try_resolve_type(v, operand)?; + new_operands.push(resolved); + } + Some(Type::Phi { + operands: new_operands, + }) + } + Type::TypeVar { id } => { + let substitution = self.get(ty); + if !type_equals(&substitution, ty) { + let resolved = self.try_resolve_type(v, &substitution)?; + self.substitutions.insert(*id, resolved.clone()); + Some(resolved) + } else { + Some(ty.clone()) + } + } + Type::Property { + object_type, + object_name, + property_name, + } => { + let resolved_obj = self.get(object_type); + let object_type = self.try_resolve_type(v, &resolved_obj)?; + Some(Type::Property { + object_type: Box::new(object_type), + object_name: object_name.clone(), + property_name: property_name.clone(), + }) + } + Type::Function { + shape_id, + return_type, + is_constructor, + } => { + let resolved_ret = self.get(return_type); + let return_type = self.try_resolve_type(v, &resolved_ret)?; + Some(Type::Function { + shape_id: shape_id.clone(), + return_type: Box::new(return_type), + is_constructor: *is_constructor, + }) + } + Type::ObjectMethod | Type::Object { .. } | Type::Primitive | Type::Poly => { + Some(ty.clone()) + } + } + } + + fn occurs_check(&self, v: &Type, ty: &Type) -> bool { + if type_equals(v, ty) { + return true; + } + + if let Type::TypeVar { id } = ty { + if let Some(sub) = self.substitutions.get(id) { + return self.occurs_check(v, sub); + } + } + + if let Type::Phi { operands } = ty { + return operands.iter().any(|o| self.occurs_check(v, o)); + } + + if let Type::Function { return_type, .. 
} = ty { + return self.occurs_check(v, return_type); + } + + false + } + + fn get(&self, ty: &Type) -> Type { + if let Type::TypeVar { id } = ty { + if let Some(sub) = self.substitutions.get(id) { + return self.get(sub); + } + } + + if let Type::Phi { operands } = ty { + return Type::Phi { + operands: operands.iter().map(|o| self.get(o)).collect(), + }; + } + + if let Type::Function { + is_constructor, + shape_id, + return_type, + } = ty + { + return Type::Function { + is_constructor: *is_constructor, + shape_id: shape_id.clone(), + return_type: Box::new(self.get(return_type)), + }; + } + + ty.clone() + } +} + +// ============================================================================= +// Union types helper +// ============================================================================= + +fn try_union_types(ty1: &Type, ty2: &Type) -> Option { + let (readonly_type, other_type) = if matches!(ty1, Type::Object { shape_id } if shape_id.as_deref() == Some(BUILT_IN_MIXED_READONLY_ID)) + { + (ty1, ty2) + } else if matches!(ty2, Type::Object { shape_id } if shape_id.as_deref() == Some(BUILT_IN_MIXED_READONLY_ID)) + { + (ty2, ty1) + } else { + return None; + }; + + if matches!(other_type, Type::Primitive) { + // Union(Primitive | MixedReadonly) = MixedReadonly + return Some(readonly_type.clone()); + } else if matches!(other_type, Type::Object { shape_id } if shape_id.as_deref() == Some(BUILT_IN_ARRAY_ID)) + { + // Union(Array | MixedReadonly) = Array + return Some(other_type.clone()); + } + + None +} diff --git a/crates/react_compiler_typeinference/src/lib.rs b/crates/react_compiler_typeinference/src/lib.rs new file mode 100644 index 000000000000..4003812f62e7 --- /dev/null +++ b/crates/react_compiler_typeinference/src/lib.rs @@ -0,0 +1,7 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. 
+#![allow(clippy::all)] + +pub mod infer_types; + +pub use infer_types::infer_types; diff --git a/crates/react_compiler_utils/Cargo.toml b/crates/react_compiler_utils/Cargo.toml new file mode 100644 index 000000000000..7ef5353f1277 --- /dev/null +++ b/crates/react_compiler_utils/Cargo.toml @@ -0,0 +1,10 @@ +[package] +description = "Vendored React Compiler utilities from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_utils" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +indexmap = { workspace = true } diff --git a/crates/react_compiler_utils/src/disjoint_set.rs b/crates/react_compiler_utils/src/disjoint_set.rs new file mode 100644 index 000000000000..e4f744985560 --- /dev/null +++ b/crates/react_compiler_utils/src/disjoint_set.rs @@ -0,0 +1,147 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! A generic disjoint-set (union-find) data structure. +//! +//! Ported from TypeScript `src/Utils/DisjointSet.ts`. + +use std::{collections::HashSet, hash::Hash}; + +use indexmap::IndexMap; + +/// A Union-Find data structure for grouping items into disjoint sets. +/// +/// Corresponds to TS `DisjointSet` in `src/Utils/DisjointSet.ts`. +/// Uses `IndexMap` to preserve insertion order (matching TS `Map` behavior). +pub struct DisjointSet { + entries: IndexMap, +} + +impl DisjointSet { + pub fn new() -> Self { + DisjointSet { + entries: IndexMap::new(), + } + } + + /// Updates the graph to reflect that the given items form a set, + /// linking any previous sets that the items were part of into a single set. + /// + /// Corresponds to TS `union(items: Array): void`. + pub fn union(&mut self, items: &[K]) { + if items.is_empty() { + return; + } + let root = self.find(items[0]); + for &item in &items[1..] 
{ + let item_root = self.find(item); + if item_root != root { + self.entries.insert(item_root, root); + } + } + } + + /// Find the root of the set containing `item`, with path compression. + /// If `item` is not in the set, it is inserted as its own root. + /// + /// Note: callers that need null/None semantics for missing items should + /// use `find_opt()` instead. + pub fn find(&mut self, item: K) -> K { + let parent = match self.entries.get(&item) { + Some(&p) => p, + None => { + self.entries.insert(item, item); + return item; + } + }; + if parent == item { + return item; + } + let root = self.find(parent); + self.entries.insert(item, root); + root + } + + /// Find the root of the set containing `item`, returning `None` if the item + /// was never added to the set. + /// + /// Corresponds to TS `find(item: T): T | null`. + pub fn find_opt(&mut self, item: K) -> Option { + if !self.entries.contains_key(&item) { + return None; + } + Some(self.find(item)) + } + + /// Returns true if the item is present in the set. + /// + /// Corresponds to TS `has(item: T): boolean`. + pub fn has(&self, item: K) -> bool { + self.entries.contains_key(&item) + } + + /// Forces the set into canonical form (all items pointing directly to their + /// root) and returns a map of items to their roots. + /// + /// Corresponds to TS `canonicalize(): Map`. + pub fn canonicalize(&mut self) -> IndexMap { + let mut result = IndexMap::new(); + let keys: Vec = self.entries.keys().copied().collect(); + for item in keys { + let root = self.find(item); + result.insert(item, root); + } + result + } + + /// Calls the provided callback once for each item in the disjoint set, + /// passing the item and the group root to which it belongs. + /// + /// Corresponds to TS `forEach(fn: (item: T, group: T) => void): void`. 
+ pub fn for_each(&mut self, mut f: F) + where + F: FnMut(K, K), + { + let keys: Vec = self.entries.keys().copied().collect(); + for item in keys { + let group = self.find(item); + f(item, group); + } + } + + /// Groups all items by their root and returns the groups as a list of sets. + /// + /// Corresponds to TS `buildSets(): Array>`. + pub fn build_sets(&mut self) -> Vec> { + let mut group_to_index: IndexMap = IndexMap::new(); + let mut sets: Vec> = Vec::new(); + let keys: Vec = self.entries.keys().copied().collect(); + for item in keys { + let group = self.find(item); + let idx = match group_to_index.get(&group) { + Some(&idx) => idx, + None => { + let idx = sets.len(); + group_to_index.insert(group, idx); + sets.push(HashSet::new()); + idx + } + }; + sets[idx].insert(item); + } + sets + } + + /// Returns the number of items in the set. + /// + /// Corresponds to TS `get size(): number`. + pub fn len(&self) -> usize { + self.entries.len() + } + + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } +} diff --git a/crates/react_compiler_utils/src/lib.rs b/crates/react_compiler_utils/src/lib.rs new file mode 100644 index 000000000000..b25462be75df --- /dev/null +++ b/crates/react_compiler_utils/src/lib.rs @@ -0,0 +1,7 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. 
+#![allow(clippy::all)] + +pub mod disjoint_set; + +pub use disjoint_set::DisjointSet; diff --git a/crates/react_compiler_validation/Cargo.toml b/crates/react_compiler_validation/Cargo.toml new file mode 100644 index 000000000000..bb3c0fe9e100 --- /dev/null +++ b/crates/react_compiler_validation/Cargo.toml @@ -0,0 +1,12 @@ +[package] +description = "Vendored React Compiler validation passes from facebook/react#36173" +edition = { workspace = true } +license = { workspace = true } +name = "react_compiler_validation" +repository = { workspace = true } +version = "0.1.0" + +[dependencies] +indexmap = { workspace = true } +react_compiler_diagnostics = { path = "../react_compiler_diagnostics" } +react_compiler_hir = { path = "../react_compiler_hir" } diff --git a/crates/react_compiler_validation/src/lib.rs b/crates/react_compiler_validation/src/lib.rs new file mode 100644 index 000000000000..c4afa0bd8a8a --- /dev/null +++ b/crates/react_compiler_validation/src/lib.rs @@ -0,0 +1,37 @@ +// Vendored from facebook/react#36173. Keep upstream style intact to reduce +// merge drift. 
+#![allow(clippy::all)] + +pub mod validate_context_variable_lvalues; +pub mod validate_exhaustive_dependencies; +pub mod validate_hooks_usage; +pub mod validate_locals_not_reassigned_after_render; +pub mod validate_no_capitalized_calls; +pub mod validate_no_derived_computations_in_effects; +pub mod validate_no_freezing_known_mutable_functions; +pub mod validate_no_jsx_in_try_statement; +pub mod validate_no_ref_access_in_render; +pub mod validate_no_set_state_in_effects; +pub mod validate_no_set_state_in_render; +pub mod validate_preserved_manual_memoization; +pub mod validate_static_components; +pub mod validate_use_memo; + +pub use validate_context_variable_lvalues::{ + validate_context_variable_lvalues, validate_context_variable_lvalues_with_errors, +}; +pub use validate_exhaustive_dependencies::validate_exhaustive_dependencies; +pub use validate_hooks_usage::validate_hooks_usage; +pub use validate_locals_not_reassigned_after_render::validate_locals_not_reassigned_after_render; +pub use validate_no_capitalized_calls::validate_no_capitalized_calls; +pub use validate_no_derived_computations_in_effects::{ + validate_no_derived_computations_in_effects, validate_no_derived_computations_in_effects_exp, +}; +pub use validate_no_freezing_known_mutable_functions::validate_no_freezing_known_mutable_functions; +pub use validate_no_jsx_in_try_statement::validate_no_jsx_in_try_statement; +pub use validate_no_ref_access_in_render::validate_no_ref_access_in_render; +pub use validate_no_set_state_in_effects::validate_no_set_state_in_effects; +pub use validate_no_set_state_in_render::validate_no_set_state_in_render; +pub use validate_preserved_manual_memoization::validate_preserved_manual_memoization; +pub use validate_static_components::validate_static_components; +pub use validate_use_memo::validate_use_memo; diff --git a/crates/react_compiler_validation/src/validate_context_variable_lvalues.rs b/crates/react_compiler_validation/src/validate_context_variable_lvalues.rs new 
file mode 100644 index 000000000000..b700938b395f --- /dev/null +++ b/crates/react_compiler_validation/src/validate_context_variable_lvalues.rs @@ -0,0 +1,243 @@ +use std::collections::HashMap; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, ErrorCategory, +}; +use react_compiler_hir::{ + environment::Environment, + visitors::{each_instruction_value_lvalue, each_pattern_operand}, + FunctionId, HirFunction, Identifier, IdentifierId, InstructionValue, Place, +}; + +/// Variable reference kind: local, context, or destructure. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum VarRefKind { + Local, + Context, + Destructure, +} + +impl std::fmt::Display for VarRefKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VarRefKind::Local => write!(f, "local"), + VarRefKind::Context => write!(f, "context"), + VarRefKind::Destructure => write!(f, "destructure"), + } + } +} + +type IdentifierKinds = HashMap; + +/// Validates that context variable lvalues are used consistently. +/// +/// Port of ValidateContextVariableLValues.ts +pub fn validate_context_variable_lvalues( + func: &HirFunction, + env: &mut Environment, +) -> Result<(), CompilerDiagnostic> { + validate_context_variable_lvalues_with_errors( + func, + &env.functions, + &env.identifiers, + &mut env.errors, + ) +} + +/// Like [`validate_context_variable_lvalues`], but writes diagnostics into the +/// provided `errors` instead of `env.errors`. Useful when the caller wants to +/// discard the diagnostics (e.g. when lowering is incomplete). 
+pub fn validate_context_variable_lvalues_with_errors( + func: &HirFunction, + functions: &[HirFunction], + identifiers: &[Identifier], + errors: &mut CompilerError, +) -> Result<(), CompilerDiagnostic> { + let mut identifier_kinds: IdentifierKinds = HashMap::new(); + validate_context_variable_lvalues_impl( + func, + &mut identifier_kinds, + functions, + identifiers, + errors, + ) +} + +fn validate_context_variable_lvalues_impl( + func: &HirFunction, + identifier_kinds: &mut IdentifierKinds, + functions: &[HirFunction], + identifiers: &[Identifier], + errors: &mut CompilerError, +) -> Result<(), CompilerDiagnostic> { + let mut inner_function_ids: Vec = Vec::new(); + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let value = &instr.value; + + match value { + InstructionValue::DeclareContext { lvalue, .. } + | InstructionValue::StoreContext { lvalue, .. } => { + visit( + identifier_kinds, + &lvalue.place, + VarRefKind::Context, + identifiers, + errors, + )?; + } + InstructionValue::LoadContext { place, .. } => { + visit( + identifier_kinds, + place, + VarRefKind::Context, + identifiers, + errors, + )?; + } + InstructionValue::StoreLocal { lvalue, .. } + | InstructionValue::DeclareLocal { lvalue, .. } => { + visit( + identifier_kinds, + &lvalue.place, + VarRefKind::Local, + identifiers, + errors, + )?; + } + InstructionValue::LoadLocal { place, .. } => { + visit( + identifier_kinds, + place, + VarRefKind::Local, + identifiers, + errors, + )?; + } + InstructionValue::PostfixUpdate { lvalue, .. } + | InstructionValue::PrefixUpdate { lvalue, .. } => { + visit( + identifier_kinds, + lvalue, + VarRefKind::Local, + identifiers, + errors, + )?; + } + InstructionValue::Destructure { lvalue, .. 
} => { + for place in each_pattern_operand(&lvalue.pattern) { + visit( + identifier_kinds, + &place, + VarRefKind::Destructure, + identifiers, + errors, + )?; + } + } + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + inner_function_ids.push(lowered_func.func); + } + _ => { + for _ in each_instruction_value_lvalue(value) { + errors.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Todo, + "ValidateContextVariableLValues: unhandled instruction variant", + None, + ) + .with_detail( + CompilerDiagnosticDetail::Error { + loc: value.loc().copied(), + message: None, + identifier_name: None, + }, + ), + ); + } + } + } + } + } + + // Process inner functions after the block loop to avoid borrow conflicts + for func_id in inner_function_ids { + let inner_func = &functions[func_id.0 as usize]; + validate_context_variable_lvalues_impl( + inner_func, + identifier_kinds, + functions, + identifiers, + errors, + )?; + } + + Ok(()) +} + +/// Format a place like TS `printPlace()`: `<effect> <name>$<id>` +fn format_place(place: &Place, identifiers: &[Identifier]) -> String { + let id = place.identifier; + let ident = &identifiers[id.0 as usize]; + let name = match &ident.name { + Some(n) => n.value().to_string(), + None => String::new(), + }; + format!("{} {}${}", place.effect, name, id.0) +} + +fn visit( + identifiers: &mut IdentifierKinds, + place: &Place, + kind: VarRefKind, + env_identifiers: &[Identifier], + errors: &mut CompilerError, +) -> Result<(), CompilerDiagnostic> { + if let Some((prev_place, prev_kind)) = identifiers.get(&place.identifier) { + let was_context = *prev_kind == VarRefKind::Context; + let is_context = kind == VarRefKind::Context; + if was_context != is_context { + if *prev_kind == VarRefKind::Destructure || kind == VarRefKind::Destructure { + let loc = if kind == VarRefKind::Destructure { + place.loc + } else { + prev_place.loc + }; + errors.push_diagnostic( + CompilerDiagnostic::new( + 
ErrorCategory::Todo, + "Support destructuring of context variables", + None, + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc, + message: None, + identifier_name: None, + }), + ); + return Ok(()); + } + let place_str = format_place(place, env_identifiers); + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Expected all references to a variable to be consistently local or context \ + references", + Some(format!( + "Identifier {} is referenced as a {} variable, but was previously referenced \ + as a {} variable", + place_str, kind, prev_kind + )), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: place.loc, + message: Some(format!("this is {}", prev_kind)), + identifier_name: None, + })); + } + } + identifiers.insert(place.identifier, (place.clone(), kind)); + Ok(()) +} diff --git a/crates/react_compiler_validation/src/validate_exhaustive_dependencies.rs b/crates/react_compiler_validation/src/validate_exhaustive_dependencies.rs new file mode 100644 index 000000000000..91ba1859ce48 --- /dev/null +++ b/crates/react_compiler_validation/src/validate_exhaustive_dependencies.rs @@ -0,0 +1,1825 @@ +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerSuggestion, CompilerSuggestionOperation, + ErrorCategory, SourceLocation, +}; +use react_compiler_hir::{ + environment::Environment, + environment_config::ExhaustiveEffectDepsMode, + visitors::{ + each_instruction_value_lvalue, each_instruction_value_operand_with_functions, + each_terminal_operand, + }, + ArrayElement, BlockId, DependencyPathEntry, HirFunction, Identifier, IdentifierId, + InstructionKind, InstructionValue, ManualMemoDependency, ManualMemoDependencyRoot, + NonLocalBinding, ParamPattern, Place, PlaceOrSpread, PropertyLiteral, Terminal, Type, +}; + +/// Port of ValidateExhaustiveDependencies.ts +/// +/// Validates that existing manual memoization is exhaustive and does not +/// have extraneous 
dependencies. The goal is to ensure auto-memoization +/// will not substantially change program behavior. +/// +/// Note: takes `&mut HirFunction` (deviating from the read-only validation +/// convention) because it sets `has_invalid_deps` on StartMemoize instructions +/// when validation errors are found, so that ValidatePreservedManualMemoization +/// can skip those blocks. +pub fn validate_exhaustive_dependencies( + func: &mut HirFunction, + env: &mut Environment, +) -> Result<(), CompilerDiagnostic> { + let reactive = collect_reactive_identifiers(func, &env.functions); + let validate_memo = env.config.validate_exhaustive_memoization_dependencies; + let validate_effect = env.config.validate_exhaustive_effect_dependencies.clone(); + + let mut temporaries: HashMap = HashMap::new(); + for param in &func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + temporaries.insert( + place.identifier, + Temporary::Local { + identifier: place.identifier, + path: Vec::new(), + context: false, + loc: place.loc, + }, + ); + } + + let mut start_memo: Option = None; + let mut memo_locals: HashSet = HashSet::new(); + + // Callbacks struct holding the mutable state + let mut callbacks = Callbacks { + start_memo: &mut start_memo, + memo_locals: &mut memo_locals, + validate_memo, + validate_effect: validate_effect.clone(), + reactive: &reactive, + diagnostics: Vec::new(), + invalid_memo_ids: HashSet::new(), + }; + + collect_dependencies( + func, + &env.identifiers, + &env.types, + &env.functions, + &mut temporaries, + &mut Some(&mut callbacks), + false, + )?; + + // Set has_invalid_deps on StartMemoize instructions that had validation errors + if !callbacks.invalid_memo_ids.is_empty() { + for instr in func.instructions.iter_mut() { + if let InstructionValue::StartMemoize { + manual_memo_id, + has_invalid_deps, + .. 
+ } = &mut instr.value + { + if callbacks.invalid_memo_ids.contains(manual_memo_id) { + *has_invalid_deps = true; + } + } + } + } + + // Record all diagnostics on the environment + for diagnostic in callbacks.diagnostics { + env.record_diagnostic(diagnostic); + } + Ok(()) +} + +// ============================================================================= +// Internal types +// ============================================================================= + +/// Info extracted from a StartMemoize instruction +struct StartMemoInfo { + manual_memo_id: u32, + deps: Option>, + deps_loc: Option>, + #[allow(dead_code)] + loc: Option, +} + +/// A temporary value tracked during dependency collection +#[derive(Debug, Clone)] +enum Temporary { + Local { + identifier: IdentifierId, + path: Vec, + context: bool, + loc: Option, + }, + Global { + binding: NonLocalBinding, + }, + Aggregate { + dependencies: Vec, + loc: Option, + }, +} + +/// An inferred dependency (Local or Global) +#[derive(Debug, Clone)] +enum InferredDependency { + Local { + identifier: IdentifierId, + path: Vec, + #[allow(dead_code)] + context: bool, + loc: Option, + }, + Global { + binding: NonLocalBinding, + }, +} + +/// Hashable key for deduplicating inferred dependencies in a Set +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum InferredDependencyKey { + Local { + identifier: IdentifierId, + path_key: String, + }, + Global { + name: String, + }, +} + +fn dep_to_key(dep: &InferredDependency) -> InferredDependencyKey { + match dep { + InferredDependency::Local { + identifier, path, .. + } => InferredDependencyKey::Local { + identifier: *identifier, + path_key: path_to_string(path), + }, + InferredDependency::Global { binding } => InferredDependencyKey::Global { + name: binding.name().to_string(), + }, + } +} + +fn path_to_string(path: &[DependencyPathEntry]) -> String { + path.iter() + .map(|p| format!("{}{}", if p.optional { "?." } else { "." 
}, p.property)) + .collect::>() + .join("") +} + +/// Callbacks for StartMemoize/FinishMemoize/Effect events +struct Callbacks<'a> { + start_memo: &'a mut Option, + #[allow(dead_code)] + memo_locals: &'a mut HashSet, + validate_memo: bool, + validate_effect: ExhaustiveEffectDepsMode, + reactive: &'a HashSet, + diagnostics: Vec, + /// manual_memo_ids that had validation errors (to set has_invalid_deps) + invalid_memo_ids: HashSet, +} + +// ============================================================================= +// Helper: type checking functions +// ============================================================================= + +fn is_effect_event_function_type(ty: &Type) -> bool { + matches!(ty, Type::Function { shape_id: Some(id), .. } if id == "BuiltInEffectEventFunction") +} + +fn is_stable_type(ty: &Type) -> bool { + match ty { + Type::Function { + shape_id: Some(id), .. + } => matches!( + id.as_str(), + "BuiltInSetState" + | "BuiltInSetActionState" + | "BuiltInDispatch" + | "BuiltInStartTransition" + | "BuiltInSetOptimistic" + ), + Type::Object { shape_id: Some(id) } => matches!(id.as_str(), "BuiltInUseRefId"), + _ => false, + } +} + +fn is_effect_hook(ty: &Type) -> bool { + matches!(ty, Type::Function { shape_id: Some(id), .. 
} + if id == "BuiltInUseEffectHook" + || id == "BuiltInUseLayoutEffectHook" + || id == "BuiltInUseInsertionEffectHook" + ) +} + +fn is_primitive_type(ty: &Type) -> bool { + matches!(ty, Type::Primitive) +} + +fn is_use_ref_type(ty: &Type) -> bool { + matches!(ty, Type::Object { shape_id: Some(id) } if id == "BuiltInUseRefId") +} + +fn get_identifier_type<'a>( + id: IdentifierId, + identifiers: &'a [Identifier], + types: &'a [Type], +) -> &'a Type { + let ident = &identifiers[id.0 as usize]; + &types[ident.type_.0 as usize] +} + +fn get_identifier_name(id: IdentifierId, identifiers: &[Identifier]) -> Option { + identifiers[id.0 as usize] + .name + .as_ref() + .map(|n| n.value().to_string()) +} + +// ============================================================================= +// Path helpers (matching TS areEqualPaths, isSubPath, +// isSubPathIgnoringOptionals) +// ============================================================================= + +fn are_equal_paths(a: &[DependencyPathEntry], b: &[DependencyPathEntry]) -> bool { + a.len() == b.len() + && a.iter() + .zip(b.iter()) + .all(|(ai, bi)| ai.property == bi.property && ai.optional == bi.optional) +} + +fn is_sub_path(subpath: &[DependencyPathEntry], path: &[DependencyPathEntry]) -> bool { + subpath.len() <= path.len() + && subpath + .iter() + .zip(path.iter()) + .all(|(a, b)| a.property == b.property && a.optional == b.optional) +} + +fn is_sub_path_ignoring_optionals( + subpath: &[DependencyPathEntry], + path: &[DependencyPathEntry], +) -> bool { + subpath.len() <= path.len() + && subpath + .iter() + .zip(path.iter()) + .all(|(a, b)| a.property == b.property) +} + +// ============================================================================= +// Collect reactive identifiers +// ============================================================================= + +fn collect_reactive_identifiers( + func: &HirFunction, + functions: &[HirFunction], +) -> HashSet { + let mut reactive = HashSet::new(); + for 
(_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + // Check instruction lvalue + if instr.lvalue.reactive { + reactive.insert(instr.lvalue.identifier); + } + // Check inner lvalues (Destructure patterns, StoreLocal, DeclareLocal, etc.) + // Matches TS eachInstructionLValue which yields both instr.lvalue and + // eachInstructionValueLValue(instr.value) + for lvalue in each_instruction_value_lvalue(&instr.value) { + if lvalue.reactive { + reactive.insert(lvalue.identifier); + } + } + for operand in each_instruction_value_operand_with_functions(&instr.value, functions) { + if operand.reactive { + reactive.insert(operand.identifier); + } + } + } + for operand in each_terminal_operand(&block.terminal) { + if operand.reactive { + reactive.insert(operand.identifier); + } + } + } + reactive +} + +// ============================================================================= +// findOptionalPlaces +// ============================================================================= + +fn find_optional_places(func: &HirFunction) -> HashMap { + let mut optionals: HashMap = HashMap::new(); + let mut visited: HashSet = HashSet::new(); + + for (_block_id, block) in &func.body.blocks { + if visited.contains(&block.id) { + continue; + } + if let Terminal::Optional { + test, + fallthrough: optional_fallthrough, + optional, + .. + } = &block.terminal + { + visited.insert(block.id); + let mut test_block_id = *test; + let mut queue: Vec> = vec![Some(*optional)]; + + 'outer: loop { + let test_block = &func.body.blocks[&test_block_id]; + visited.insert(test_block.id); + match &test_block.terminal { + Terminal::Branch { + test: test_place, + consequent, + fallthrough, + .. 
+ } => { + let is_optional = queue + .pop() + .expect("Expected an optional value for each optional test condition"); + if let Some(opt) = is_optional { + optionals.insert(test_place.identifier, opt); + } + if fallthrough == optional_fallthrough { + // Found the end of the optional chain + let consequent_block = &func.body.blocks[consequent]; + if let Some(last_id) = consequent_block.instructions.last() { + let last_instr = &func.instructions[last_id.0 as usize]; + if let InstructionValue::StoreLocal { value, .. } = + &last_instr.value + { + if let Some(opt) = is_optional { + optionals.insert(value.identifier, opt); + } + } + } + break 'outer; + } else { + test_block_id = *fallthrough; + } + } + Terminal::Optional { + optional: opt, + test: inner_test, + .. + } => { + queue.push(Some(*opt)); + test_block_id = *inner_test; + } + Terminal::Logical { + test: inner_test, .. + } + | Terminal::Ternary { + test: inner_test, .. + } => { + queue.push(None); + test_block_id = *inner_test; + } + Terminal::Sequence { + block: seq_block, .. + } => { + test_block_id = *seq_block; + } + Terminal::MaybeThrow { continuation, .. } => { + test_block_id = *continuation; + } + _ => { + // Unexpected terminal in optional — skip rather than panic + break 'outer; + } + } + } + // TS asserts queue.length === 0 here, but we skip the assertion + // to avoid panicking on edge cases. + } + } + + optionals +} + +// ============================================================================= +// Dependency collection +// ============================================================================= + +fn add_dependency( + dep: &Temporary, + dependencies: &mut Vec, + dep_keys: &mut HashSet, + locals: &HashSet, +) { + match dep { + Temporary::Aggregate { + dependencies: agg_deps, + .. 
+ } => { + for d in agg_deps { + add_dependency_inferred(d, dependencies, dep_keys, locals); + } + } + Temporary::Global { binding } => { + let inferred = InferredDependency::Global { + binding: binding.clone(), + }; + let key = dep_to_key(&inferred); + if dep_keys.insert(key) { + dependencies.push(inferred); + } + } + Temporary::Local { + identifier, + path, + context, + loc, + } => { + if !locals.contains(identifier) { + let inferred = InferredDependency::Local { + identifier: *identifier, + path: path.clone(), + context: *context, + loc: *loc, + }; + let key = dep_to_key(&inferred); + if dep_keys.insert(key) { + dependencies.push(inferred); + } + } + } + } +} + +fn add_dependency_inferred( + dep: &InferredDependency, + dependencies: &mut Vec, + dep_keys: &mut HashSet, + locals: &HashSet, +) { + match dep { + InferredDependency::Global { .. } => { + let key = dep_to_key(dep); + if dep_keys.insert(key) { + dependencies.push(dep.clone()); + } + } + InferredDependency::Local { identifier, .. 
} => { + if !locals.contains(identifier) { + let key = dep_to_key(dep); + if dep_keys.insert(key) { + dependencies.push(dep.clone()); + } + } + } + } +} + +fn visit_candidate_dependency( + place: &Place, + temporaries: &HashMap, + dependencies: &mut Vec, + dep_keys: &mut HashSet, + locals: &HashSet, +) { + if let Some(dep) = temporaries.get(&place.identifier) { + add_dependency(dep, dependencies, dep_keys, locals); + } +} + +fn collect_dependencies( + func: &HirFunction, + identifiers: &[Identifier], + types: &[Type], + functions: &[HirFunction], + temporaries: &mut HashMap, + callbacks: &mut Option<&mut Callbacks<'_>>, + is_function_expression: bool, +) -> Result { + let optionals = find_optional_places(func); + let mut locals: HashSet = HashSet::new(); + + if is_function_expression { + for param in &func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + locals.insert(place.identifier); + } + } + + let mut dependencies: Vec = Vec::new(); + let mut dep_keys: HashSet = HashSet::new(); + + // Saved state for when we're inside a memo block (StartMemoize..FinishMemoize). + // In TS, `dependencies` and `locals` are shared by reference between the main + // collection loop and the callbacks — StartMemoize clears them, FinishMemoize + // reads and clears them. We simulate this by saving/restoring. + let mut saved_dependencies: Option> = None; + let mut saved_dep_keys: Option> = None; + let mut saved_locals: Option> = None; + + for (_block_id, block) in &func.body.blocks { + // Process phis + for phi in &block.phis { + let mut deps: Vec = Vec::new(); + for (_pred_id, operand) in &phi.operands { + if let Some(dep) = temporaries.get(&operand.identifier) { + match dep { + Temporary::Aggregate { + dependencies: agg, .. 
+ } => { + deps.extend(agg.iter().cloned()); + } + Temporary::Local { + identifier, + path, + context, + loc, + } => { + deps.push(InferredDependency::Local { + identifier: *identifier, + path: path.clone(), + context: *context, + loc: *loc, + }); + } + Temporary::Global { binding } => { + deps.push(InferredDependency::Global { + binding: binding.clone(), + }); + } + } + } + } + if deps.is_empty() { + continue; + } else if deps.len() == 1 { + let dep = &deps[0]; + match dep { + InferredDependency::Local { + identifier, + path, + context, + loc, + } => { + temporaries.insert( + phi.place.identifier, + Temporary::Local { + identifier: *identifier, + path: path.clone(), + context: *context, + loc: *loc, + }, + ); + } + InferredDependency::Global { binding } => { + temporaries.insert( + phi.place.identifier, + Temporary::Global { + binding: binding.clone(), + }, + ); + } + } + } else { + temporaries.insert( + phi.place.identifier, + Temporary::Aggregate { + dependencies: deps, + loc: None, + }, + ); + } + } + + // Process instructions + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + + match &instr.value { + InstructionValue::LoadGlobal { binding, .. } => { + temporaries.insert( + lvalue_id, + Temporary::Global { + binding: binding.clone(), + }, + ); + } + InstructionValue::LoadContext { place, .. } + | InstructionValue::LoadLocal { place, .. } => { + if let Some(temp) = temporaries.get(&place.identifier).cloned() { + match &temp { + Temporary::Local { .. } => { + // Update loc to the load site + let mut updated = temp.clone(); + if let Temporary::Local { loc, .. } = &mut updated { + *loc = place.loc; + } + temporaries.insert(lvalue_id, updated); + } + _ => { + temporaries.insert(lvalue_id, temp); + } + } + if locals.contains(&place.identifier) { + locals.insert(lvalue_id); + } + } + } + InstructionValue::DeclareLocal { + lvalue: decl_lv, .. 
+ } => { + temporaries.insert( + decl_lv.place.identifier, + Temporary::Local { + identifier: decl_lv.place.identifier, + path: Vec::new(), + context: false, + loc: decl_lv.place.loc, + }, + ); + locals.insert(decl_lv.place.identifier); + } + InstructionValue::StoreLocal { + lvalue: store_lv, + value: store_val, + .. + } => { + let has_name = identifiers[store_lv.place.identifier.0 as usize] + .name + .is_some(); + if !has_name { + // Unnamed: propagate temporary + if let Some(temp) = temporaries.get(&store_val.identifier).cloned() { + temporaries.insert(store_lv.place.identifier, temp); + } + } else { + // Named: visit the value and create a new local + visit_candidate_dependency( + store_val, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + if store_lv.kind != InstructionKind::Reassign { + temporaries.insert( + store_lv.place.identifier, + Temporary::Local { + identifier: store_lv.place.identifier, + path: Vec::new(), + context: false, + loc: store_lv.place.loc, + }, + ); + locals.insert(store_lv.place.identifier); + } + } + } + InstructionValue::DeclareContext { + lvalue: decl_lv, .. + } => { + temporaries.insert( + decl_lv.place.identifier, + Temporary::Local { + identifier: decl_lv.place.identifier, + path: Vec::new(), + context: true, + loc: decl_lv.place.loc, + }, + ); + } + InstructionValue::StoreContext { + lvalue: store_lv, + value: store_val, + .. + } => { + visit_candidate_dependency( + store_val, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + if store_lv.kind != InstructionKind::Reassign { + temporaries.insert( + store_lv.place.identifier, + Temporary::Local { + identifier: store_lv.place.identifier, + path: Vec::new(), + context: true, + loc: store_lv.place.loc, + }, + ); + locals.insert(store_lv.place.identifier); + } + } + InstructionValue::Destructure { + value: destr_val, + lvalue: destr_lv, + .. 
+ } => { + visit_candidate_dependency( + destr_val, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + if destr_lv.kind != InstructionKind::Reassign { + for lv_place in each_instruction_value_lvalue(&instr.value) { + temporaries.insert( + lv_place.identifier, + Temporary::Local { + identifier: lv_place.identifier, + path: Vec::new(), + context: false, + loc: lv_place.loc, + }, + ); + locals.insert(lv_place.identifier); + } + } + } + InstructionValue::PropertyLoad { + object, property, .. + } => { + // Number properties or ref.current: visit the object directly + let is_numeric = matches!(property, PropertyLiteral::Number(_)); + let is_ref_current = + is_use_ref_type(get_identifier_type(object.identifier, identifiers, types)) + && *property == PropertyLiteral::String("current".to_string()); + + if is_numeric || is_ref_current { + visit_candidate_dependency( + object, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + } else { + // Extend path + let obj_temp = temporaries.get(&object.identifier).cloned(); + if let Some(Temporary::Local { + identifier, + path, + context, + .. + }) = obj_temp + { + let optional = + optionals.get(&object.identifier).copied().unwrap_or(false); + let mut new_path = path.clone(); + new_path.push(DependencyPathEntry { + optional, + property: property.clone(), + loc: instr.value.loc().copied(), + }); + temporaries.insert( + lvalue_id, + Temporary::Local { + identifier, + path: new_path, + context, + loc: instr.value.loc().copied(), + }, + ); + } + } + } + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. 
} => { + let inner_func = &functions[lowered_func.func.0 as usize]; + let function_deps = collect_dependencies( + inner_func, + identifiers, + types, + functions, + temporaries, + &mut None, + true, + )?; + temporaries.insert(lvalue_id, function_deps.clone()); + add_dependency(&function_deps, &mut dependencies, &mut dep_keys, &locals); + } + InstructionValue::StartMemoize { + manual_memo_id, + deps, + deps_loc, + loc, + .. + } => { + if let Some(cb) = callbacks.as_mut() { + // onStartMemoize — mirrors TS behavior of clearing dependencies and locals + *cb.start_memo = Some(StartMemoInfo { + manual_memo_id: *manual_memo_id, + deps: deps.clone(), + deps_loc: *deps_loc, + loc: *loc, + }); + // Save current state and clear, matching TS which clears the shared + // dependencies/locals sets on StartMemoize + saved_dependencies = Some(std::mem::take(&mut dependencies)); + saved_dep_keys = Some(std::mem::take(&mut dep_keys)); + saved_locals = Some(std::mem::take(&mut locals)); + } + } + InstructionValue::FinishMemoize { + manual_memo_id, + decl, + .. + } => { + if let Some(cb) = callbacks.as_mut() { + // onFinishMemoize — mirrors TS behavior + let sm = cb.start_memo.take(); + if let Some(sm) = sm { + assert_eq!( + sm.manual_memo_id, *manual_memo_id, + "Found FinishMemoize without corresponding StartMemoize" + ); + + if cb.validate_memo { + // Visit the decl to add it as a dependency candidate + // (matches TS: visitCandidateDependency(value.decl, ...)) + visit_candidate_dependency( + decl, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + + // Use ALL dependencies collected since StartMemoize cleared the + // set. 
This matches TS: `const + // inferred = Array.from(dependencies)` + let inferred: Vec = dependencies.clone(); + + let diagnostic = validate_dependencies( + inferred, + &sm.deps.unwrap_or_default(), + cb.reactive, + sm.deps_loc.unwrap_or(None), + ErrorCategory::MemoDependencies, + "all", + identifiers, + types, + )?; + if let Some(diag) = diagnostic { + cb.diagnostics.push(diag); + cb.invalid_memo_ids.insert(sm.manual_memo_id); + } + } + + // Restore saved state (matching TS: dependencies.clear(), + // locals.clear()) We restore instead of + // just clearing because we need the outer deps back + if let Some(saved) = saved_dependencies.take() { + // Merge current memo-block deps into the restored outer deps + let memo_deps = std::mem::replace(&mut dependencies, saved); + let _memo_keys = std::mem::replace( + &mut dep_keys, + saved_dep_keys.take().unwrap_or_default(), + ); + locals = saved_locals.take().unwrap_or_default(); + // Add memo deps to outer deps (they're still valid outer deps) + for d in memo_deps { + let key = dep_to_key(&d); + if dep_keys.insert(key) { + dependencies.push(d); + } + } + } + } + } + } + InstructionValue::ArrayExpression { elements, loc, .. } => { + let mut array_deps: Vec = Vec::new(); + let mut array_keys: HashSet = HashSet::new(); + let empty_locals = HashSet::new(); + for elem in elements { + let place = match elem { + ArrayElement::Place(p) => Some(p), + ArrayElement::Spread(s) => Some(&s.place), + ArrayElement::Hole => None, + }; + if let Some(place) = place { + // Visit with empty locals for manual deps + visit_candidate_dependency( + place, + temporaries, + &mut array_deps, + &mut array_keys, + &empty_locals, + ); + // Visit normally + visit_candidate_dependency( + place, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + } + } + temporaries.insert( + lvalue_id, + Temporary::Aggregate { + dependencies: array_deps, + loc: *loc, + }, + ); + } + InstructionValue::CallExpression { callee, args, .. 
} => { + // Check if this is an effect hook call + if let Some(cb) = callbacks.as_mut() { + let callee_ty = get_identifier_type(callee.identifier, identifiers, types); + if is_effect_hook(callee_ty) + && !matches!(cb.validate_effect, ExhaustiveEffectDepsMode::Off) + { + if args.len() >= 2 { + let fn_arg = match &args[0] { + PlaceOrSpread::Place(p) => Some(p), + _ => None, + }; + let deps_arg = match &args[1] { + PlaceOrSpread::Place(p) => Some(p), + _ => None, + }; + if let (Some(fn_place), Some(deps_place)) = (fn_arg, deps_arg) { + let fn_deps = temporaries.get(&fn_place.identifier).cloned(); + let manual_deps = + temporaries.get(&deps_place.identifier).cloned(); + if let ( + Some(Temporary::Aggregate { + dependencies: fn_dep_list, + .. + }), + Some(Temporary::Aggregate { + dependencies: manual_dep_list, + loc: manual_loc, + }), + ) = (fn_deps, manual_deps) + { + let effect_report_mode = match &cb.validate_effect { + ExhaustiveEffectDepsMode::All => "all", + ExhaustiveEffectDepsMode::MissingOnly => "missing-only", + ExhaustiveEffectDepsMode::ExtraOnly => "extra-only", + ExhaustiveEffectDepsMode::Off => unreachable!(), + }; + // Convert manual deps to ManualMemoDependency format + let manual_memo_deps: Vec = + manual_dep_list + .iter() + .map(|dep| match dep { + InferredDependency::Local { + identifier, + path, + loc, + .. 
+ } => ManualMemoDependency { + root: ManualMemoDependencyRoot::NamedLocal { + value: Place { + identifier: *identifier, + effect: + react_compiler_hir::Effect::Read, + reactive: cb + .reactive + .contains(identifier), + loc: *loc, + }, + constant: false, + }, + path: path.clone(), + loc: *loc, + }, + InferredDependency::Global { binding } => { + ManualMemoDependency { + root: + ManualMemoDependencyRoot::Global { + identifier_name: binding + .name() + .to_string(), + }, + path: Vec::new(), + loc: None, + } + } + }) + .collect(); + + let diagnostic = validate_dependencies( + fn_dep_list, + &manual_memo_deps, + cb.reactive, + manual_loc, + ErrorCategory::EffectExhaustiveDependencies, + effect_report_mode, + identifiers, + types, + )?; + if let Some(diag) = diagnostic { + cb.diagnostics.push(diag); + } + } + } + } + } + } + + // Visit all operands except for MethodCall's property + for operand in + each_instruction_value_operand_with_functions(&instr.value, functions) + { + visit_candidate_dependency( + &operand, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + } + } + InstructionValue::MethodCall { + receiver, + property, + args, + .. + } => { + // Check if this is an effect hook call + if let Some(cb) = callbacks.as_mut() { + let prop_ty = get_identifier_type(property.identifier, identifiers, types); + if is_effect_hook(prop_ty) + && !matches!(cb.validate_effect, ExhaustiveEffectDepsMode::Off) + { + if args.len() >= 2 { + let fn_arg = match &args[0] { + PlaceOrSpread::Place(p) => Some(p), + _ => None, + }; + let deps_arg = match &args[1] { + PlaceOrSpread::Place(p) => Some(p), + _ => None, + }; + if let (Some(fn_place), Some(deps_place)) = (fn_arg, deps_arg) { + let fn_deps = temporaries.get(&fn_place.identifier).cloned(); + let manual_deps = + temporaries.get(&deps_place.identifier).cloned(); + if let ( + Some(Temporary::Aggregate { + dependencies: fn_dep_list, + .. 
+ }), + Some(Temporary::Aggregate { + dependencies: manual_dep_list, + loc: manual_loc, + }), + ) = (fn_deps, manual_deps) + { + let effect_report_mode = match &cb.validate_effect { + ExhaustiveEffectDepsMode::All => "all", + ExhaustiveEffectDepsMode::MissingOnly => "missing-only", + ExhaustiveEffectDepsMode::ExtraOnly => "extra-only", + ExhaustiveEffectDepsMode::Off => unreachable!(), + }; + let manual_memo_deps: Vec = + manual_dep_list + .iter() + .map(|dep| match dep { + InferredDependency::Local { + identifier, + path, + loc, + .. + } => ManualMemoDependency { + root: ManualMemoDependencyRoot::NamedLocal { + value: Place { + identifier: *identifier, + effect: + react_compiler_hir::Effect::Read, + reactive: cb + .reactive + .contains(identifier), + loc: *loc, + }, + constant: false, + }, + path: path.clone(), + loc: *loc, + }, + InferredDependency::Global { binding } => { + ManualMemoDependency { + root: + ManualMemoDependencyRoot::Global { + identifier_name: binding + .name() + .to_string(), + }, + path: Vec::new(), + loc: None, + } + } + }) + .collect(); + + let diagnostic = validate_dependencies( + fn_dep_list, + &manual_memo_deps, + cb.reactive, + manual_loc, + ErrorCategory::EffectExhaustiveDependencies, + effect_report_mode, + identifiers, + types, + )?; + if let Some(diag) = diagnostic { + cb.diagnostics.push(diag); + } + } + } + } + } + } + + // Visit operands, skipping the method property itself + visit_candidate_dependency( + receiver, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + // Skip property — matches TS behavior + for arg in args { + let place = match arg { + PlaceOrSpread::Place(p) => p, + PlaceOrSpread::Spread(s) => &s.place, + }; + visit_candidate_dependency( + place, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + } + } + _ => { + // Default: visit all operands + for operand in + each_instruction_value_operand_with_functions(&instr.value, functions) + { + visit_candidate_dependency( + &operand, 
+ temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + } + // Track lvalues as locals + for lv in each_instruction_lvalue_ids(&instr.value, lvalue_id) { + locals.insert(lv); + } + } + } + } + + // Terminal operands + for operand in &each_terminal_operand(&block.terminal) { + if optionals.contains_key(&operand.identifier) { + continue; + } + visit_candidate_dependency( + operand, + temporaries, + &mut dependencies, + &mut dep_keys, + &locals, + ); + } + } + + Ok(Temporary::Aggregate { + dependencies, + loc: None, + }) +} + +// ============================================================================= +// validateDependencies +// ============================================================================= + +fn validate_dependencies( + mut inferred: Vec, + manual_dependencies: &[ManualMemoDependency], + reactive: &HashSet, + manual_memo_loc: Option, + category: ErrorCategory, + exhaustive_deps_report_mode: &str, + identifiers: &[Identifier], + types: &[Type], +) -> Result, CompilerDiagnostic> { + // Sort dependencies by name and path + inferred.sort_by(|a, b| { + match (a, b) { + ( + InferredDependency::Global { binding: ab }, + InferredDependency::Global { binding: bb }, + ) => ab.name().cmp(bb.name()), + ( + InferredDependency::Local { + identifier: a_id, + path: a_path, + .. + }, + InferredDependency::Local { + identifier: b_id, + path: b_path, + .. 
+ }, + ) => { + let a_name = get_identifier_name(*a_id, identifiers); + let b_name = get_identifier_name(*b_id, identifiers); + match (a_name.as_deref(), b_name.as_deref()) { + (Some(an), Some(bn)) => { + if *a_id != *b_id { + an.cmp(bn) + } else if a_path.len() != b_path.len() { + a_path.len().cmp(&b_path.len()) + } else { + // Compare path entries + for (ap, bp) in a_path.iter().zip(b_path.iter()) { + let a_opt = if ap.optional { 0i32 } else { 1 }; + let b_opt = if bp.optional { 0i32 } else { 1 }; + if a_opt != b_opt { + return a_opt.cmp(&b_opt); + } + let prop_cmp = + ap.property.to_string().cmp(&bp.property.to_string()); + if prop_cmp != std::cmp::Ordering::Equal { + return prop_cmp; + } + } + std::cmp::Ordering::Equal + } + } + _ => std::cmp::Ordering::Equal, + } + } + ( + InferredDependency::Global { binding: ab }, + InferredDependency::Local { + identifier: b_id, .. + }, + ) => { + let a_name = ab.name(); + let b_name = get_identifier_name(*b_id, identifiers); + match b_name.as_deref() { + Some(bn) => a_name.cmp(bn), + None => std::cmp::Ordering::Equal, + } + } + ( + InferredDependency::Local { + identifier: a_id, .. 
+ }, + InferredDependency::Global { binding: bb }, + ) => { + let a_name = get_identifier_name(*a_id, identifiers); + let b_name = bb.name(); + match a_name.as_deref() { + Some(an) => an.cmp(b_name), + None => std::cmp::Ordering::Equal, + } + } + } + }); + + // Remove redundant inferred dependencies + // retainWhere logic: keep dep[ix] only if no earlier entry is equal or a + // subpath prefix Mirrors TS: retainWhere(inferred, (dep, ix) => { + // const match = inferred.findIndex(prevDep => isEqualTemporary(prevDep, dep) + // || ...); return match === -1 || match >= ix; + // }) + { + let snapshot = inferred.clone(); + let mut write_index = 0; + for ix in 0..snapshot.len() { + let dep = &snapshot[ix]; + let first_match = snapshot.iter().position(|prev_dep| { + is_equal_temporary(prev_dep, dep) + || (matches!( + (prev_dep, dep), + ( + InferredDependency::Local { .. }, + InferredDependency::Local { .. } + ) + ) && { + if let ( + InferredDependency::Local { + identifier: prev_id, + path: prev_path, + .. + }, + InferredDependency::Local { + identifier: dep_id, + path: dep_path, + .. 
+ }, + ) = (prev_dep, dep) + { + prev_id == dep_id && is_sub_path(prev_path, dep_path) + } else { + false + } + }) + }); + + let keep = match first_match { + None => true, + Some(m) => m >= ix, + }; + if keep { + inferred[write_index] = snapshot[ix].clone(); + write_index += 1; + } + } + inferred.truncate(write_index); + } + + // Validate manual deps + let mut matched: HashSet = HashSet::new(); // indices into manual_dependencies + let mut missing: Vec<&InferredDependency> = Vec::new(); + let mut extra: Vec<&ManualMemoDependency> = Vec::new(); + + for inferred_dep in &inferred { + match inferred_dep { + InferredDependency::Global { binding } => { + for (i, manual_dep) in manual_dependencies.iter().enumerate() { + if let ManualMemoDependencyRoot::Global { identifier_name } = &manual_dep.root { + if identifier_name == binding.name() { + matched.insert(i); + extra.push(manual_dep); + } + } + } + continue; + } + InferredDependency::Local { + identifier, + path, + loc: _, + .. + } => { + // Skip effect event functions + let ty = get_identifier_type(*identifier, identifiers, types); + if is_effect_event_function_type(ty) { + continue; + } + + let mut has_matching = false; + for (i, manual_dep) in manual_dependencies.iter().enumerate() { + if let ManualMemoDependencyRoot::NamedLocal { value, .. } = &manual_dep.root { + if value.identifier == *identifier + && (are_equal_paths(&manual_dep.path, path) + || is_sub_path_ignoring_optionals(&manual_dep.path, path)) + { + has_matching = true; + matched.insert(i); + } + } + } + + if has_matching || is_optional_dependency(*identifier, reactive, identifiers, types) + { + continue; + } + + missing.push(inferred_dep); + } + } + } + + // Check for extra dependencies + for (i, dep) in manual_dependencies.iter().enumerate() { + if matched.contains(&i) { + continue; + } + if let ManualMemoDependencyRoot::NamedLocal { + constant, value, .. 
+ } = &dep.root + { + if *constant { + let dep_ty = get_identifier_type(value.identifier, identifiers, types); + // Constant-folded primitives: skip + if !value.reactive && is_primitive_type(dep_ty) { + continue; + } + } + } + extra.push(dep); + } + + // Filter based on report mode + let filtered_missing: Vec<&InferredDependency> = if exhaustive_deps_report_mode == "extra-only" + { + Vec::new() + } else { + missing + }; + let filtered_extra: Vec<&ManualMemoDependency> = + if exhaustive_deps_report_mode == "missing-only" { + Vec::new() + } else { + extra + }; + + if filtered_missing.is_empty() && filtered_extra.is_empty() { + return Ok(None); + } + + // Build suggestion when we have valid index info (matches TS behavior) + let suggestion = manual_memo_loc.and_then(|loc| { + let start_index = loc.start.index?; + let end_index = loc.end.index?; + let text = format!( + "[{}]", + inferred + .iter() + .filter(|dep| { + match dep { + InferredDependency::Local { identifier, .. } => { + let ty = get_identifier_type(*identifier, identifiers, types); + !is_optional_dependency(*identifier, reactive, identifiers, types) + && !is_effect_event_function_type(ty) + } + InferredDependency::Global { .. } => false, + } + }) + .map(|dep| print_inferred_dependency(dep, identifiers)) + .collect::>() + .join(", ") + ); + Some(CompilerSuggestion { + op: CompilerSuggestionOperation::Replace, + range: (start_index as usize, end_index as usize), + description: "Update dependencies".to_string(), + text: Some(text), + }) + }); + + let mut diagnostic = create_diagnostic( + category, + &filtered_missing, + &filtered_extra, + suggestion, + identifiers, + )?; + + // Add detail items for missing deps + for dep in &filtered_missing { + if let InferredDependency::Local { + identifier, + path: _, + loc, + .. + } = dep + { + let mut hint = String::new(); + let ty = get_identifier_type(*identifier, identifiers, types); + if is_stable_type(ty) { + hint = ". 
Refs, setState functions, and other \"stable\" values generally do not \ + need to be added as dependencies, but this variable may change over time \ + to point to different values" + .to_string(); + } + let dep_str = print_inferred_dependency(dep, identifiers); + diagnostic.details.push(CompilerDiagnosticDetail::Error { + loc: *loc, + message: Some(format!("Missing dependency `{dep_str}`{hint}")), + identifier_name: None, + }); + } + } + + // Add detail items for extra deps + for dep in &filtered_extra { + match &dep.root { + ManualMemoDependencyRoot::Global { .. } => { + let dep_str = print_manual_memo_dependency(dep, identifiers); + diagnostic.details.push(CompilerDiagnosticDetail::Error { + loc: dep.loc.or(manual_memo_loc), + message: Some(format!( + "Unnecessary dependency `{dep_str}`. Values declared outside of a \ + component/hook should not be listed as dependencies as the component \ + will not re-render if they change" + )), + identifier_name: None, + }); + } + ManualMemoDependencyRoot::NamedLocal { value, .. } => { + // Check if there's a matching inferred dep + let matching_inferred = inferred.iter().find(|inf_dep| { + if let InferredDependency::Local { + identifier: inf_id, + path: inf_path, + .. + } = inf_dep + { + *inf_id == value.identifier + && is_sub_path_ignoring_optionals(inf_path, &dep.path) + } else { + false + } + }); + + if let Some(matching) = matching_inferred { + if let InferredDependency::Local { identifier, .. } = matching { + let matching_ty = get_identifier_type(*identifier, identifiers, types); + if is_effect_event_function_type(matching_ty) { + let dep_str = print_manual_memo_dependency(dep, identifiers); + diagnostic.details.push(CompilerDiagnosticDetail::Error { + loc: dep.loc.or(manual_memo_loc), + message: Some(format!( + "Functions returned from `useEffectEvent` must not be \ + included in the dependency array. Remove `{dep_str}` from \ + the dependencies." 
+ )), + identifier_name: None, + }); + } else if !is_optional_dependency_inferred( + matching, + reactive, + identifiers, + types, + ) { + let dep_str = print_manual_memo_dependency(dep, identifiers); + let inferred_str = print_inferred_dependency(matching, identifiers); + diagnostic.details.push(CompilerDiagnosticDetail::Error { + loc: dep.loc.or(manual_memo_loc), + message: Some(format!( + "Overly precise dependency `{dep_str}`, use `{inferred_str}` \ + instead" + )), + identifier_name: None, + }); + } else { + let dep_str = print_manual_memo_dependency(dep, identifiers); + diagnostic.details.push(CompilerDiagnosticDetail::Error { + loc: dep.loc.or(manual_memo_loc), + message: Some(format!("Unnecessary dependency `{dep_str}`")), + identifier_name: None, + }); + } + } + } else { + let dep_str = print_manual_memo_dependency(dep, identifiers); + diagnostic.details.push(CompilerDiagnosticDetail::Error { + loc: dep.loc.or(manual_memo_loc), + message: Some(format!("Unnecessary dependency `{dep_str}`")), + identifier_name: None, + }); + } + } + } + } + + // Add hint showing inferred dependencies when a suggestion was generated + // (matches TS: only adds hint when suggestion != null, using suggestion.text) + if let Some(ref suggestions) = diagnostic.suggestions { + if let Some(suggestion) = suggestions.first() { + if let Some(ref text) = suggestion.text { + diagnostic.details.push(CompilerDiagnosticDetail::Hint { + message: format!("Inferred dependencies: `{text}`"), + }); + } + } + } + + Ok(Some(diagnostic)) +} + +// ============================================================================= +// Printing helpers +// ============================================================================= + +fn print_inferred_dependency(dep: &InferredDependency, identifiers: &[Identifier]) -> String { + match dep { + InferredDependency::Global { binding } => binding.name().to_string(), + InferredDependency::Local { + identifier, path, .. 
+ } => { + let name = get_identifier_name(*identifier, identifiers) + .unwrap_or_else(|| "".to_string()); + let path_str: String = path + .iter() + .map(|p| format!("{}.{}", if p.optional { "?" } else { "" }, p.property)) + .collect(); + format!("{name}{path_str}") + } + } +} + +fn print_manual_memo_dependency(dep: &ManualMemoDependency, identifiers: &[Identifier]) -> String { + let name = match &dep.root { + ManualMemoDependencyRoot::Global { identifier_name } => identifier_name.clone(), + ManualMemoDependencyRoot::NamedLocal { value, .. } => { + get_identifier_name(value.identifier, identifiers) + .unwrap_or_else(|| "".to_string()) + } + }; + let path_str: String = dep + .path + .iter() + .map(|p| format!("{}.{}", if p.optional { "?" } else { "" }, p.property)) + .collect(); + format!("{name}{path_str}") +} + +// ============================================================================= +// Optional dependency check +// ============================================================================= + +fn is_optional_dependency( + identifier: IdentifierId, + reactive: &HashSet, + identifiers: &[Identifier], + types: &[Type], +) -> bool { + if reactive.contains(&identifier) { + return false; + } + let ty = get_identifier_type(identifier, identifiers, types); + is_stable_type(ty) || is_primitive_type(ty) +} + +fn is_optional_dependency_inferred( + dep: &InferredDependency, + reactive: &HashSet, + identifiers: &[Identifier], + types: &[Type], +) -> bool { + match dep { + InferredDependency::Local { identifier, .. } => { + is_optional_dependency(*identifier, reactive, identifiers, types) + } + InferredDependency::Global { .. 
} => false, + } +} + +// ============================================================================= +// Equality check for temporaries +// ============================================================================= + +fn is_equal_temporary(a: &InferredDependency, b: &InferredDependency) -> bool { + match (a, b) { + ( + InferredDependency::Global { binding: ab }, + InferredDependency::Global { binding: bb }, + ) => ab.name() == bb.name(), + ( + InferredDependency::Local { + identifier: a_id, + path: a_path, + .. + }, + InferredDependency::Local { + identifier: b_id, + path: b_path, + .. + }, + ) => a_id == b_id && are_equal_paths(a_path, b_path), + _ => false, + } +} + +// ============================================================================= +// createDiagnostic +// ============================================================================= + +fn create_diagnostic( + category: ErrorCategory, + missing: &[&InferredDependency], + extra: &[&ManualMemoDependency], + suggestion: Option, + _identifiers: &[Identifier], +) -> Result { + let missing_str = if !missing.is_empty() { + Some("missing") + } else { + None + }; + let extra_str = if !extra.is_empty() { + Some("extra") + } else { + None + }; + + let (reason, description) = match category { + ErrorCategory::MemoDependencies => { + let reason_parts: Vec<&str> = + [missing_str, extra_str].iter().filter_map(|x| *x).collect(); + let reason = format!("Found {} memoization dependencies", reason_parts.join("/")); + + let desc_parts: Vec<&str> = [ + if !missing.is_empty() { + Some( + "Missing dependencies can cause a value to update less often than it \ + should, resulting in stale UI", + ) + } else { + None + }, + if !extra.is_empty() { + Some( + "Extra dependencies can cause a value to update more often than it \ + should, resulting in performance problems such as excessive renders or \ + effects firing too often", + ) + } else { + None + }, + ] + .iter() + .filter_map(|x| *x) + .collect(); + let description = 
desc_parts.join(". "); + (reason, description) + } + ErrorCategory::EffectExhaustiveDependencies => { + let reason_parts: Vec<&str> = + [missing_str, extra_str].iter().filter_map(|x| *x).collect(); + let reason = format!("Found {} effect dependencies", reason_parts.join("/")); + + let desc_parts: Vec<&str> = [ + if !missing.is_empty() { + Some( + "Missing dependencies can cause an effect to fire less often than it \ + should", + ) + } else { + None + }, + if !extra.is_empty() { + Some( + "Extra dependencies can cause an effect to fire more often than it \ + should, resulting in performance problems such as excessive renders and \ + side effects", + ) + } else { + None + }, + ] + .iter() + .filter_map(|x| *x) + .collect(); + let description = desc_parts.join(". "); + (reason, description) + } + _ => { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + format!("Unexpected error category: {:?}", category), + None, + )); + } + }; + + Ok(CompilerDiagnostic { + category, + reason, + description: Some(description), + details: Vec::new(), + suggestions: suggestion.map(|s| vec![s]), + }) +} + +/// Collect lvalue identifier ids from instruction value (for the default +/// branch). Thin wrapper around canonical `each_instruction_value_lvalue` that +/// maps to ids. +fn each_instruction_lvalue_ids( + value: &InstructionValue, + lvalue_id: IdentifierId, +) -> Vec { + let mut ids = vec![lvalue_id]; + for place in each_instruction_value_lvalue(value) { + ids.push(place.identifier); + } + ids +} diff --git a/crates/react_compiler_validation/src/validate_hooks_usage.rs b/crates/react_compiler_validation/src/validate_hooks_usage.rs new file mode 100644 index 000000000000..f0427dc5d8c9 --- /dev/null +++ b/crates/react_compiler_validation/src/validate_hooks_usage.rs @@ -0,0 +1,527 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. 
+
+//! Validates hooks usage rules.
+//!
+//! Port of ValidateHooksUsage.ts.
+//! Ensures hooks are called unconditionally, not passed as values,
+//! and not called dynamically. Also validates that hooks are not
+//! called inside function expressions.
+
+use std::collections::HashMap;
+
+use indexmap::IndexMap;
+use react_compiler_diagnostics::{
+    CompilerDiagnostic, CompilerError, CompilerErrorDetail, ErrorCategory, SourceLocation,
+};
+use react_compiler_hir::{
+    dominator::compute_unconditional_blocks,
+    environment::{is_hook_name, Environment},
+    object_shape::HookKind,
+    visitors,
+    visitors::{each_pattern_operand, each_terminal_operand},
+    FunctionId, HirFunction, Identifier, IdentifierId, InstructionValue, ParamPattern, Place,
+    PropertyLiteral, Type,
+};
+
+/// Value classification for hook validation.
+///
+/// Variants are listed from highest to lowest precedence; `join_kinds`
+/// resolves a pair of kinds to whichever one appears first in this list.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Kind {
+    Error,
+    KnownHook,
+    PotentialHook,
+    Global,
+    Local,
+}
+
+/// Merges two classifications, keeping the one with the higher precedence
+/// (`Error` > `KnownHook` > `PotentialHook` > `Global` > `Local`).
+fn join_kinds(a: Kind, b: Kind) -> Kind {
+    for kind in [
+        Kind::Error,
+        Kind::KnownHook,
+        Kind::PotentialHook,
+        Kind::Global,
+    ] {
+        if a == kind || b == kind {
+            return kind;
+        }
+    }
+    Kind::Local
+}
+
+/// Returns the classification of `place`.
+///
+/// Starts from the kind recorded in `value_kinds` (defaulting to
+/// `Kind::Local`) and joins it with `Kind::PotentialHook` when the
+/// identifier's name is hook-like.
+fn get_kind_for_place(
+    place: &Place,
+    value_kinds: &HashMap<IdentifierId, Kind>,
+    identifiers: &[Identifier],
+) -> Kind {
+    let known_kind = value_kinds
+        .get(&place.identifier)
+        .copied()
+        .unwrap_or(Kind::Local);
+    if ident_is_hook_name(place.identifier, identifiers) {
+        join_kinds(known_kind, Kind::PotentialHook)
+    } else {
+        known_kind
+    }
+}
+
+/// Returns true when the identifier has a name and `is_hook_name` accepts it.
+fn ident_is_hook_name(identifier_id: IdentifierId, identifiers: &[Identifier]) -> bool {
+    identifiers[identifier_id.0 as usize]
+        .name
+        .as_ref()
+        .is_some_and(|name| is_hook_name(name.value()))
+}
+
+/// Looks up the identifier's type and asks the environment which hook kind
+/// (if any) that type corresponds to.
+fn get_hook_kind_for_id<'a>(
+    identifier_id: IdentifierId,
+    identifiers: &[Identifier],
+    types: &[Type],
+    env: &'a Environment,
+) -> Result<Option<&'a HookKind>, CompilerDiagnostic> {
+    let identifier = &identifiers[identifier_id.0 as usize];
+    let ty = &types[identifier.type_.0 as usize];
+    env.get_hook_kind_for_type(ty)
+}
+
+/// Builds the `Hooks`-category error detail shared by every error in this
+/// pass (no description, no suggestions).
+fn hooks_error_detail(reason: String, loc: Option<SourceLocation>) -> CompilerErrorDetail {
+    CompilerErrorDetail {
+        category: ErrorCategory::Hooks,
+        reason,
+        description: None,
+        loc,
+        suggestions: None,
+    }
+}
+
+/// Reports an "invalid hook usage" error if `place` is currently classified
+/// as a known hook, i.e. a hook is being referenced as an ordinary value.
+fn visit_place(
+    place: &Place,
+    value_kinds: &HashMap<IdentifierId, Kind>,
+    errors_by_loc: &mut IndexMap<SourceLocation, CompilerErrorDetail>,
+    env: &mut Environment,
+) -> Result<(), CompilerError> {
+    if value_kinds.get(&place.identifier) == Some(&Kind::KnownHook) {
+        record_invalid_hook_usage_error(place, errors_by_loc, env)?;
+    }
+    Ok(())
+}
+
+/// Marks `place` as `Kind::Error` and records a conditional-hook-call error.
+///
+/// With a source location the error is stored in `errors_by_loc`, replacing
+/// any entry at that location whose reason differs; without one it is
+/// recorded on the environment immediately.
+fn record_conditional_hook_error(
+    place: &Place,
+    value_kinds: &mut HashMap<IdentifierId, Kind>,
+    errors_by_loc: &mut IndexMap<SourceLocation, CompilerErrorDetail>,
+    env: &mut Environment,
+) -> Result<(), CompilerError> {
+    value_kinds.insert(place.identifier, Kind::Error);
+    let reason = "Hooks must always be called in a consistent order, and may not be called conditionally. See the Rules of Hooks (https://react.dev/warnings/invalid-hook-call-warning)".to_string();
+    match place.loc {
+        Some(loc) => {
+            // Overwrite unless the same conditional-call error is already stored here.
+            let should_store = errors_by_loc
+                .get(&loc)
+                .map_or(true, |previous| previous.reason != reason);
+            if should_store {
+                errors_by_loc.insert(loc, hooks_error_detail(reason, Some(loc)));
+            }
+        }
+        None => {
+            env.record_error(hooks_error_detail(reason, None))?;
+        }
+    }
+    Ok(())
+}
+
+/// Records a "hook referenced as a normal value" error.
+///
+/// With a source location the error is stored only if no error is already
+/// registered for that location; without one it is recorded on the
+/// environment immediately.
+fn record_invalid_hook_usage_error(
+    place: &Place,
+    errors_by_loc: &mut IndexMap<SourceLocation, CompilerErrorDetail>,
+    env: &mut Environment,
+) -> Result<(), CompilerError> {
+    let reason = "Hooks may not be referenced as normal values, they must be called. See https://react.dev/reference/rules/react-calls-components-and-hooks#never-pass-around-hooks-as-regular-values".to_string();
+    match place.loc {
+        Some(loc) => {
+            // First error at a location wins.
+            errors_by_loc
+                .entry(loc)
+                .or_insert_with(|| hooks_error_detail(reason, Some(loc)));
+        }
+        None => {
+            env.record_error(hooks_error_detail(reason, None))?;
+        }
+    }
+    Ok(())
+}
+
+/// Records a "hook may change between renders" (dynamic hook) error.
+///
+/// Same storage policy as `record_invalid_hook_usage_error`: keyed by
+/// location when one exists (first error wins), otherwise recorded on the
+/// environment immediately.
+fn record_dynamic_hook_usage_error(
+    place: &Place,
+    errors_by_loc: &mut IndexMap<SourceLocation, CompilerErrorDetail>,
+    env: &mut Environment,
+) -> Result<(), CompilerError> {
+    let reason = "Hooks must be the same function on every render, but this value may change over time to a different function. See https://react.dev/reference/rules/react-calls-components-and-hooks#dont-dynamically-use-hooks".to_string();
+    match place.loc {
+        Some(loc) => {
+            // First error at a location wins.
+            errors_by_loc
+                .entry(loc)
+                .or_insert_with(|| hooks_error_detail(reason, Some(loc)));
+        }
+        None => {
+            env.record_error(hooks_error_detail(reason, None))?;
+        }
+    }
+    Ok(())
+}
+
+/// Validates hooks usage rules for a function.
+pub fn validate_hooks_usage( + func: &HirFunction, + env: &mut Environment, +) -> Result<(), react_compiler_diagnostics::CompilerDiagnostic> { + let unconditional_blocks = compute_unconditional_blocks(func, env.next_block_id().0)?; + let mut errors_by_loc: IndexMap = IndexMap::new(); + let mut value_kinds: HashMap = HashMap::new(); + + // Process params + for param in &func.params { + let place = match param { + ParamPattern::Place(p) => p, + ParamPattern::Spread(s) => &s.place, + }; + let kind = get_kind_for_place(place, &value_kinds, &env.identifiers); + value_kinds.insert(place.identifier, kind); + } + + // Process blocks + for (_block_id, block) in &func.body.blocks { + // Process phis + for phi in &block.phis { + let mut kind = if ident_is_hook_name(phi.place.identifier, &env.identifiers) { + Kind::PotentialHook + } else { + Kind::Local + }; + for (_, operand) in &phi.operands { + if let Some(&operand_kind) = value_kinds.get(&operand.identifier) { + kind = join_kinds(kind, operand_kind); + } + } + value_kinds.insert(phi.place.identifier, kind); + } + + // Process instructions + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + + match &instr.value { + InstructionValue::LoadGlobal { .. } => { + if get_hook_kind_for_id(lvalue_id, &env.identifiers, &env.types, env)?.is_some() + { + value_kinds.insert(lvalue_id, Kind::KnownHook); + } else { + value_kinds.insert(lvalue_id, Kind::Global); + } + } + InstructionValue::LoadContext { place, .. } + | InstructionValue::LoadLocal { place, .. } => { + visit_place(place, &value_kinds, &mut errors_by_loc, env)?; + let kind = get_kind_for_place(place, &value_kinds, &env.identifiers); + value_kinds.insert(lvalue_id, kind); + } + InstructionValue::StoreLocal { lvalue, value, .. } + | InstructionValue::StoreContext { lvalue, value, .. 
} => { + visit_place(value, &value_kinds, &mut errors_by_loc, env)?; + let kind = join_kinds( + get_kind_for_place(value, &value_kinds, &env.identifiers), + get_kind_for_place(&lvalue.place, &value_kinds, &env.identifiers), + ); + value_kinds.insert(lvalue.place.identifier, kind); + value_kinds.insert(lvalue_id, kind); + } + InstructionValue::ComputedLoad { object, .. } => { + visit_place(object, &value_kinds, &mut errors_by_loc, env)?; + let kind = get_kind_for_place(object, &value_kinds, &env.identifiers); + let lvalue_kind = + get_kind_for_place(&instr.lvalue, &value_kinds, &env.identifiers); + value_kinds.insert(lvalue_id, join_kinds(lvalue_kind, kind)); + } + InstructionValue::PropertyLoad { + object, property, .. + } => { + let object_kind = get_kind_for_place(object, &value_kinds, &env.identifiers); + let is_hook_property = match property { + PropertyLiteral::String(s) => is_hook_name(s), + PropertyLiteral::Number(_) => false, + }; + let kind = match object_kind { + Kind::Error => Kind::Error, + Kind::KnownHook => { + if is_hook_property { + Kind::KnownHook + } else { + Kind::Local + } + } + Kind::PotentialHook => Kind::PotentialHook, + Kind::Global => { + if is_hook_property { + Kind::KnownHook + } else { + Kind::Global + } + } + Kind::Local => { + if is_hook_property { + Kind::PotentialHook + } else { + Kind::Local + } + } + }; + value_kinds.insert(lvalue_id, kind); + } + InstructionValue::CallExpression { callee, args, .. 
} => { + let callee_kind = get_kind_for_place(callee, &value_kinds, &env.identifiers); + let is_hook_callee = + callee_kind == Kind::KnownHook || callee_kind == Kind::PotentialHook; + if is_hook_callee && !unconditional_blocks.contains(&block.id) { + record_conditional_hook_error( + callee, + &mut value_kinds, + &mut errors_by_loc, + env, + )?; + } else if callee_kind == Kind::PotentialHook { + record_dynamic_hook_usage_error(callee, &mut errors_by_loc, env)?; + } + // Visit all operands except callee + for arg in args { + let place = match arg { + react_compiler_hir::PlaceOrSpread::Place(p) => p, + react_compiler_hir::PlaceOrSpread::Spread(s) => &s.place, + }; + visit_place(place, &value_kinds, &mut errors_by_loc, env)?; + } + } + InstructionValue::MethodCall { + receiver, + property, + args, + .. + } => { + let callee_kind = get_kind_for_place(property, &value_kinds, &env.identifiers); + let is_hook_callee = + callee_kind == Kind::KnownHook || callee_kind == Kind::PotentialHook; + if is_hook_callee && !unconditional_blocks.contains(&block.id) { + record_conditional_hook_error( + property, + &mut value_kinds, + &mut errors_by_loc, + env, + )?; + } else if callee_kind == Kind::PotentialHook { + record_dynamic_hook_usage_error(property, &mut errors_by_loc, env)?; + } + // Visit receiver and args (not property) + visit_place(receiver, &value_kinds, &mut errors_by_loc, env)?; + for arg in args { + let place = match arg { + react_compiler_hir::PlaceOrSpread::Place(p) => p, + react_compiler_hir::PlaceOrSpread::Spread(s) => &s.place, + }; + visit_place(place, &value_kinds, &mut errors_by_loc, env)?; + } + } + InstructionValue::Destructure { lvalue, value, .. 
} => { + visit_place(value, &value_kinds, &mut errors_by_loc, env)?; + let object_kind = get_kind_for_place(value, &value_kinds, &env.identifiers); + // Process instr.lvalue and all pattern operands (matching TS + // eachInstructionLValue) + let pattern_places = each_pattern_operand(&lvalue.pattern); + let all_lvalues = + std::iter::once(instr.lvalue.clone()).chain(pattern_places.into_iter()); + for place in all_lvalues { + let is_hook_property = + ident_is_hook_name(place.identifier, &env.identifiers); + let kind = match object_kind { + Kind::Error => Kind::Error, + Kind::KnownHook => Kind::KnownHook, + Kind::PotentialHook => Kind::PotentialHook, + Kind::Global => { + if is_hook_property { + Kind::KnownHook + } else { + Kind::Global + } + } + Kind::Local => { + if is_hook_property { + Kind::PotentialHook + } else { + Kind::Local + } + } + }; + value_kinds.insert(place.identifier, kind); + } + } + InstructionValue::ObjectMethod { lowered_func, .. } + | InstructionValue::FunctionExpression { lowered_func, .. } => { + visit_function_expression(env, lowered_func.func)?; + } + _ => { + // For all other instructions: visit operands, set lvalue kinds + // Matches TS which uses eachInstructionOperand + eachInstructionLValue + visit_all_operands(&instr.value, &value_kinds, &mut errors_by_loc, env)?; + // Set kind for instr.lvalue + let kind = get_kind_for_place(&instr.lvalue, &value_kinds, &env.identifiers); + value_kinds.insert(lvalue_id, kind); + // Also set kind for value-level lvalues (e.g. 
DeclareLocal, PrefixUpdate, + // PostfixUpdate) + for lv in visitors::each_instruction_value_lvalue(&instr.value) { + let lv_kind = get_kind_for_place(&lv, &value_kinds, &env.identifiers); + value_kinds.insert(lv.identifier, lv_kind); + } + } + } + } + + // Visit terminal operands + for place in each_terminal_operand(&block.terminal) { + visit_place(&place, &value_kinds, &mut errors_by_loc, env)?; + } + } + + // Record all accumulated errors (in insertion order, matching TS Map iteration) + for (_, error_detail) in errors_by_loc { + env.record_error(error_detail)?; + } + Ok(()) +} + +/// Visit a function expression to check for hook calls inside it. +/// Processes instructions in order, visiting nested functions immediately +/// (before processing subsequent calls) to match TS error ordering. +fn visit_function_expression( + env: &mut Environment, + func_id: FunctionId, +) -> Result<(), CompilerError> { + // Collect items in instruction order to process them sequentially. + // Each item is either a call to check or a nested function to visit. + enum Item { + Call(IdentifierId, Option), + NestedFunc(FunctionId), + } + + let func = &env.functions[func_id.0 as usize]; + let mut items: Vec = Vec::new(); + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::ObjectMethod { lowered_func, .. } + | InstructionValue::FunctionExpression { lowered_func, .. } => { + items.push(Item::NestedFunc(lowered_func.func)); + } + InstructionValue::CallExpression { callee, .. } => { + items.push(Item::Call(callee.identifier, callee.loc)); + } + InstructionValue::MethodCall { property, .. 
} => { + items.push(Item::Call(property.identifier, property.loc)); + } + _ => {} + } + } + } + + // Process items in instruction order (matching TS which visits nested + // functions immediately before processing subsequent calls) + for item in items { + match item { + Item::Call(identifier_id, loc) => { + let identifier = &env.identifiers[identifier_id.0 as usize]; + let ty = &env.types[identifier.type_.0 as usize]; + let hook_kind = env.get_hook_kind_for_type(ty).ok().flatten().cloned(); + if let Some(hook_kind) = hook_kind { + let description = format!( + "Cannot call {} within a function expression", + if hook_kind == HookKind::Custom { + "hook" + } else { + hook_kind_display(&hook_kind) + } + ); + env.record_error(CompilerErrorDetail { + category: ErrorCategory::Hooks, + reason: "Hooks must be called at the top level in the body of a function component or custom hook, and may not be called within function expressions. See the Rules of Hooks (https://react.dev/warnings/invalid-hook-call-warning)".to_string(), + description: Some(description), + loc, + suggestions: None, + })?; + } + } + Item::NestedFunc(nested_func_id) => { + visit_function_expression(env, nested_func_id)?; + } + } + } + Ok(()) +} + +fn hook_kind_display(kind: &HookKind) -> &'static str { + match kind { + HookKind::UseContext => "useContext", + HookKind::UseState => "useState", + HookKind::UseActionState => "useActionState", + HookKind::UseReducer => "useReducer", + HookKind::UseRef => "useRef", + HookKind::UseEffect => "useEffect", + HookKind::UseLayoutEffect => "useLayoutEffect", + HookKind::UseInsertionEffect => "useInsertionEffect", + HookKind::UseMemo => "useMemo", + HookKind::UseCallback => "useCallback", + HookKind::UseTransition => "useTransition", + HookKind::UseImperativeHandle => "useImperativeHandle", + HookKind::UseEffectEvent => "useEffectEvent", + HookKind::UseOptimistic => "useOptimistic", + HookKind::Custom => "hook", + } +} + +/// Visit all operands of an instruction value 
(generic fallback). +/// Uses the canonical `each_instruction_value_operand` from visitors. +fn visit_all_operands( + value: &InstructionValue, + value_kinds: &HashMap, + errors_by_loc: &mut IndexMap, + env: &mut Environment, +) -> Result<(), CompilerError> { + let operands = visitors::each_instruction_value_operand(value, &*env); + for place in &operands { + visit_place(place, value_kinds, errors_by_loc, env)?; + } + Ok(()) +} diff --git a/crates/react_compiler_validation/src/validate_locals_not_reassigned_after_render.rs b/crates/react_compiler_validation/src/validate_locals_not_reassigned_after_render.rs new file mode 100644 index 000000000000..67ecc7f0ae44 --- /dev/null +++ b/crates/react_compiler_validation/src/validate_locals_not_reassigned_after_render.rs @@ -0,0 +1,283 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, + visitors::{ + each_instruction_lvalue_ids, each_instruction_value_operand, each_terminal_operand, + }, + Effect, HirFunction, Identifier, IdentifierId, IdentifierName, InstructionValue, Place, Type, +}; + +/// Validates that local variables cannot be reassigned after render. +/// This prevents a category of bugs in which a closure captures a +/// binding from one render but does not update. 
+pub fn validate_locals_not_reassigned_after_render(func: &HirFunction, env: &mut Environment) { + let mut context_variables: HashSet = HashSet::new(); + let mut diagnostics: Vec = Vec::new(); + + let reassignment = get_context_reassignment( + func, + &env.identifiers, + &env.types, + &env.functions, + env, + &mut context_variables, + false, + false, + &mut diagnostics, + ); + + // Record accumulated errors (from async function checks in inner functions) + // first + for diagnostic in diagnostics { + env.record_diagnostic(diagnostic); + } + + // Then record the top-level reassignment error if any + if let Some(reassignment_place) = reassignment { + let variable_name = format_variable_name(&reassignment_place, &env.identifiers); + env.record_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::Immutability, + "Cannot reassign variable after render completes", + Some(format!( + "Reassigning {} after render has completed can cause inconsistent behavior on \ + subsequent renders. Consider using state instead", + variable_name + )), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: reassignment_place.loc, + message: Some(format!( + "Cannot reassign {} after render completes", + variable_name + )), + identifier_name: None, + }), + ); + } +} + +/// Format a variable name for error messages. Uses the named identifier if +/// available, otherwise falls back to "variable". +fn format_variable_name(place: &Place, identifiers: &[Identifier]) -> String { + let identifier = &identifiers[place.identifier.0 as usize]; + match &identifier.name { + Some(IdentifierName::Named(name)) => format!("`{}`", name), + _ => "variable".to_string(), + } +} + +/// Recursively checks whether a function (or its dependencies) reassigns a +/// context variable. Returns the reassigned place if found, or None. +/// +/// Side effects: accumulates async-function reassignment diagnostics into +/// `diagnostics`. 
+fn get_context_reassignment( + func: &HirFunction, + identifiers: &[Identifier], + types: &[Type], + functions: &[HirFunction], + env: &Environment, + context_variables: &mut HashSet, + is_function_expression: bool, + is_async: bool, + diagnostics: &mut Vec, +) -> Option { + // Maps identifiers to the place that they reassign + let mut reassigning_functions: HashMap = HashMap::new(); + + for (_block_id, block) in &func.body.blocks { + for &instruction_id in &block.instructions { + let instr = &func.instructions[instruction_id.0 as usize]; + + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } + | InstructionValue::ObjectMethod { lowered_func, .. } => { + let inner_function = &functions[lowered_func.func.0 as usize]; + let inner_is_async = is_async || inner_function.is_async; + + // Recursively check the inner function + let mut reassignment = get_context_reassignment( + inner_function, + identifiers, + types, + functions, + env, + context_variables, + true, + inner_is_async, + diagnostics, + ); + + // If the function itself doesn't reassign, check if one of its + // dependencies (operands) is a reassigning function + if reassignment.is_none() { + for context_place in &inner_function.context { + if let Some(reassignment_place) = + reassigning_functions.get(&context_place.identifier) + { + reassignment = Some(reassignment_place.clone()); + break; + } + } + } + + // If the function or its dependencies reassign, handle it + if let Some(ref reassignment_place) = reassignment { + if inner_is_async { + // Async functions that reassign get an immediate error + let variable_name = + format_variable_name(reassignment_place, identifiers); + diagnostics.push( + CompilerDiagnostic::new( + ErrorCategory::Immutability, + "Cannot reassign variable in async function", + Some( + "Reassigning a variable in an async function can cause \ + inconsistent behavior on subsequent renders. 
Consider \ + using state instead" + .to_string(), + ), + ) + .with_detail( + CompilerDiagnosticDetail::Error { + loc: reassignment_place.loc, + message: Some(format!("Cannot reassign {}", variable_name)), + identifier_name: None, + }, + ), + ); + // Return null (don't propagate further) — matches + // TS behavior + } else { + // Propagate reassignment info on the lvalue + reassigning_functions + .insert(instr.lvalue.identifier, reassignment_place.clone()); + } + } + } + + InstructionValue::StoreLocal { lvalue, value, .. } => { + if let Some(reassignment_place) = reassigning_functions.get(&value.identifier) { + let reassignment_place = reassignment_place.clone(); + reassigning_functions + .insert(lvalue.place.identifier, reassignment_place.clone()); + reassigning_functions.insert(instr.lvalue.identifier, reassignment_place); + } + } + + InstructionValue::LoadLocal { place, .. } => { + if let Some(reassignment_place) = reassigning_functions.get(&place.identifier) { + reassigning_functions + .insert(instr.lvalue.identifier, reassignment_place.clone()); + } + } + + InstructionValue::DeclareContext { lvalue, .. } => { + if !is_function_expression { + context_variables.insert(lvalue.place.identifier); + } + } + + InstructionValue::StoreContext { lvalue, value, .. 
} => { + // If we're inside a function expression and the target is a + // context variable from the outer scope, this is a reassignment + if is_function_expression + && context_variables.contains(&lvalue.place.identifier) + { + return Some(lvalue.place.clone()); + } + + // In the outer function, track context variables + if !is_function_expression { + context_variables.insert(lvalue.place.identifier); + } + + // Propagate reassigning function info through StoreContext + if let Some(reassignment_place) = reassigning_functions.get(&value.identifier) { + let reassignment_place = reassignment_place.clone(); + reassigning_functions + .insert(lvalue.place.identifier, reassignment_place.clone()); + reassigning_functions.insert(instr.lvalue.identifier, reassignment_place); + } + } + + _ => { + // For calls with noAlias signatures, only check the callee/receiver + // (not args) to avoid false positives from callbacks that reassign + // context variables. + let operands: Vec = match &instr.value { + InstructionValue::CallExpression { callee, .. } => { + if env.has_no_alias_signature(callee.identifier) { + vec![callee.clone()] + } else { + each_instruction_value_operand(&instr.value, env) + } + } + InstructionValue::MethodCall { + receiver, property, .. + } => { + if env.has_no_alias_signature(property.identifier) { + vec![receiver.clone(), property.clone()] + } else { + each_instruction_value_operand(&instr.value, env) + } + } + InstructionValue::TaggedTemplateExpression { tag, .. 
} => { + if env.has_no_alias_signature(tag.identifier) { + vec![tag.clone()] + } else { + each_instruction_value_operand(&instr.value, env) + } + } + _ => each_instruction_value_operand(&instr.value, env), + }; + + for operand in &operands { + // Invariant: effects must be inferred before this pass runs + assert!( + operand.effect != Effect::Unknown, + "Expected effects to be inferred prior to \ + ValidateLocalsNotReassignedAfterRender" + ); + + if let Some(reassignment_place) = + reassigning_functions.get(&operand.identifier).cloned() + { + if operand.effect == Effect::Freeze { + // Functions that reassign local variables are inherently + // mutable and unsafe to pass where a frozen value is expected. + return Some(reassignment_place); + } else { + // If the operand is not frozen but does reassign, then the + // lvalues of the instruction could also be reassigning + for lvalue_id in each_instruction_lvalue_ids(instr) { + reassigning_functions + .insert(lvalue_id, reassignment_place.clone()); + } + } + } + } + } + } + } + + // Check terminal operands for reassigning functions + for operand in each_terminal_operand(&block.terminal) { + if let Some(reassignment_place) = reassigning_functions.get(&operand.identifier) { + return Some(reassignment_place.clone()); + } + } + } + + None +} diff --git a/crates/react_compiler_validation/src/validate_no_capitalized_calls.rs b/crates/react_compiler_validation/src/validate_no_capitalized_calls.rs new file mode 100644 index 000000000000..893258bfee32 --- /dev/null +++ b/crates/react_compiler_validation/src/validate_no_capitalized_calls.rs @@ -0,0 +1,88 @@ +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{CompilerError, CompilerErrorDetail, ErrorCategory}; +use react_compiler_hir::{ + environment::Environment, HirFunction, IdentifierId, InstructionValue, PropertyLiteral, +}; + +/// Validates that capitalized functions are not called directly (they should be +/// rendered as JSX). 
+/// +/// Port of ValidateNoCapitalizedCalls.ts. +pub fn validate_no_capitalized_calls( + func: &HirFunction, + env: &mut Environment, +) -> Result<(), CompilerError> { + // Build the allow list from global registry keys + config entries + let mut allow_list: HashSet = env.globals().keys().cloned().collect(); + if let Some(config_entries) = &env.config.validate_no_capitalized_calls { + for entry in config_entries { + allow_list.insert(entry.clone()); + } + } + + let mut capital_load_globals: HashMap = HashMap::new(); + let mut capitalized_properties: HashMap = HashMap::new(); + + let reason = "Capitalized functions are reserved for components, which must be invoked with \ + JSX. If this is a component, render it with JSX. Otherwise, ensure that it has \ + no hook calls and rename it to begin with a lowercase letter. Alternatively, if \ + you know for a fact that this function is not a component, you can allowlist it \ + via the compiler config"; + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + let value = &instr.value; + + match value { + InstructionValue::LoadGlobal { binding, .. } => { + let name = binding.name(); + if !name.is_empty() + && name.starts_with(|c: char| c.is_ascii_uppercase()) + // We don't want to flag CONSTANTS() + && name != name.to_uppercase() + && !allow_list.contains(name) + { + capital_load_globals.insert(lvalue_id, name.to_string()); + } + } + InstructionValue::CallExpression { callee, loc, .. } => { + let callee_id = callee.identifier; + if let Some(callee_name) = capital_load_globals.get(&callee_id) { + env.record_error(CompilerErrorDetail { + category: ErrorCategory::CapitalizedCalls, + reason: reason.to_string(), + description: Some(format!("{callee_name} may be a component")), + loc: *loc, + suggestions: None, + })?; + continue; + } + } + InstructionValue::PropertyLoad { property, .. 
} => { + if let PropertyLiteral::String(prop_name) = property { + if prop_name.starts_with(|c: char| c.is_ascii_uppercase()) { + capitalized_properties.insert(lvalue_id, prop_name.clone()); + } + } + } + InstructionValue::MethodCall { property, loc, .. } => { + let property_id = property.identifier; + if let Some(prop_name) = capitalized_properties.get(&property_id) { + env.record_error(CompilerErrorDetail { + category: ErrorCategory::CapitalizedCalls, + reason: reason.to_string(), + description: Some(format!("{prop_name} may be a component")), + loc: *loc, + suggestions: None, + })?; + } + } + _ => {} + } + } + } + Ok(()) +} diff --git a/crates/react_compiler_validation/src/validate_no_derived_computations_in_effects.rs b/crates/react_compiler_validation/src/validate_no_derived_computations_in_effects.rs new file mode 100644 index 000000000000..ea92cf61f49c --- /dev/null +++ b/crates/react_compiler_validation/src/validate_no_derived_computations_in_effects.rs @@ -0,0 +1,1461 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Validates that useEffect is not used for derived computations which +//! could/should be performed in render. +//! +//! See https://react.dev/learn/you-might-not-need-an-effect#updating-state-based-on-props-or-state +//! +//! Port of ValidateNoDerivedComputationsInEffects_exp.ts. 
+ +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, CompilerErrorDetail, ErrorCategory, +}; +use react_compiler_hir::{ + environment::Environment, + is_set_state_type, is_use_effect_hook_type, is_use_ref_type, is_use_state_type, + visitors::{ + each_instruction_lvalue_ids, each_instruction_operand as canonical_each_instruction_operand, + }, + ArrayElement, BlockId, Effect, EvaluationOrder, FunctionId, HirFunction, Identifier, + IdentifierId, IdentifierName, InstructionValue, ParamPattern, PlaceOrSpread, ReactFunctionType, + ReturnVariant, SourceLocation, Type, +}; + +/// Get the user-visible name for an identifier, matching Babel's +/// loc.identifierName behavior. First checks the identifier's own name, +/// then falls back to extracting the name from the source code at the +/// given source location. This handles SSA identifiers whose names were +/// lost during compiler passes. +fn get_identifier_name_with_loc( + id: IdentifierId, + identifiers: &[Identifier], + loc: &Option, + source_code: Option<&str>, +) -> Option { + let ident = &identifiers[id.0 as usize]; + match &ident.name { + Some(IdentifierName::Named(name)) | Some(IdentifierName::Promoted(name)) => { + return Some(name.clone()); + } + _ => {} + } + // Fall back: find another identifier with the same declaration_id that has a + // name. + let decl_id = ident.declaration_id; + for other in identifiers { + if other.declaration_id == decl_id { + match &other.name { + Some(IdentifierName::Named(name)) | Some(IdentifierName::Promoted(name)) => { + return Some(name.clone()); + } + _ => {} + } + } + } + // Fall back to extracting from source code using UTF-16 code unit indices. + // Babel/JS positions use UTF-16 code unit offsets, but Rust strings are UTF-8, + // so we need to convert between the two. + if let (Some(loc), Some(code)) = (loc, source_code) { + let start_utf16 = loc.start.index? 
as usize; + let end_utf16 = loc.end.index? as usize; + if start_utf16 < end_utf16 { + // Convert UTF-16 code unit offsets to UTF-8 byte offsets + let mut utf16_pos = 0usize; + let mut byte_start = None; + let mut byte_end = None; + for (byte_idx, ch) in code.char_indices() { + if utf16_pos == start_utf16 { + byte_start = Some(byte_idx); + } + if utf16_pos == end_utf16 { + byte_end = Some(byte_idx); + break; + } + utf16_pos += ch.len_utf16(); + } + // Handle end at the very end of string + if utf16_pos == end_utf16 && byte_end.is_none() { + byte_end = Some(code.len()); + } + if let (Some(start), Some(end)) = (byte_start, byte_end) { + let slice = &code[start..end]; + if !slice.is_empty() + && slice + .chars() + .all(|c| c.is_alphanumeric() || c == '_' || c == '$') + { + return Some(slice.to_string()); + } + } + } + } + None +} + +const MAX_FIXPOINT_ITERATIONS: usize = 100; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TypeOfValue { + Ignored, + FromProps, + FromState, + FromPropsAndState, +} + +#[derive(Debug, Clone)] +struct DerivationMetadata { + type_of_value: TypeOfValue, + place_identifier: IdentifierId, + place_name: Option, + source_ids: indexmap::IndexSet, + is_state_source: bool, +} + +/// Metadata about a useEffect call site. 
+struct EffectMetadata { + effect_func_id: FunctionId, + dep_elements: Vec, +} + +#[derive(Debug, Clone)] +struct DepElement { + identifier: IdentifierId, + loc: Option, +} + +struct ValidationContext { + /// Map from lvalue identifier to the FunctionId of function expressions + functions: HashMap, + /// Map from lvalue identifier to ArrayExpression elements (candidate deps) + candidate_dependencies: HashMap>, + derivation_cache: DerivationCache, + effects_cache: HashMap, + set_state_loads: HashMap>, + set_state_usages: HashMap>, +} + +/// A hashable key for SourceLocation to use in HashSet +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct LocKey { + start_line: u32, + start_col: u32, + end_line: u32, + end_col: u32, +} + +impl LocKey { + fn from_loc(loc: &Option) -> Self { + match loc { + Some(loc) => LocKey { + start_line: loc.start.line, + start_col: loc.start.column, + end_line: loc.end.line, + end_col: loc.end.column, + }, + None => LocKey { + start_line: 0, + start_col: 0, + end_line: 0, + end_col: 0, + }, + } + } +} + +#[derive(Debug, Clone)] +struct DerivationCache { + has_changes: bool, + cache: HashMap, + previous_cache: Option>, +} + +impl DerivationCache { + fn new() -> Self { + DerivationCache { + has_changes: false, + cache: HashMap::new(), + previous_cache: None, + } + } + + fn take_snapshot(&mut self) { + let mut prev = HashMap::new(); + for (key, value) in &self.cache { + prev.insert( + *key, + DerivationMetadata { + place_identifier: value.place_identifier, + place_name: value.place_name.clone(), + source_ids: value.source_ids.clone(), + type_of_value: value.type_of_value, + is_state_source: value.is_state_source, + }, + ); + } + self.previous_cache = Some(prev); + } + + fn check_for_changes(&mut self) { + let prev = match &self.previous_cache { + Some(p) => p, + None => { + self.has_changes = true; + return; + } + }; + + for (key, value) in &self.cache { + match prev.get(key) { + None => { + self.has_changes = true; + return; + } + 
Some(prev_value) => { + if !is_derivation_equal(prev_value, value) { + self.has_changes = true; + return; + } + } + } + } + + if self.cache.len() != prev.len() { + self.has_changes = true; + return; + } + + self.has_changes = false; + } + + fn snapshot(&mut self) -> bool { + let has_changes = self.has_changes; + self.has_changes = false; + has_changes + } + + fn add_derivation_entry( + &mut self, + derived_id: IdentifierId, + derived_name: Option, + source_ids: indexmap::IndexSet, + type_of_value: TypeOfValue, + is_state_source: bool, + ) { + let mut final_is_source = is_state_source; + if !final_is_source { + for source_id in &source_ids { + if let Some(source_metadata) = self.cache.get(source_id) { + if source_metadata.is_state_source + && !matches!(&source_metadata.place_name, Some(IdentifierName::Named(_))) + { + final_is_source = true; + break; + } + } + } + } + + self.cache.insert( + derived_id, + DerivationMetadata { + place_identifier: derived_id, + place_name: derived_name, + source_ids, + type_of_value, + is_state_source: final_is_source, + }, + ); + } +} + +fn is_derivation_equal(a: &DerivationMetadata, b: &DerivationMetadata) -> bool { + if a.type_of_value != b.type_of_value { + return false; + } + if a.source_ids.len() != b.source_ids.len() { + return false; + } + for id in &a.source_ids { + if !b.source_ids.contains(id) { + return false; + } + } + true +} + +fn join_value(lvalue_type: TypeOfValue, value_type: TypeOfValue) -> TypeOfValue { + if lvalue_type == TypeOfValue::Ignored { + return value_type; + } + if value_type == TypeOfValue::Ignored { + return lvalue_type; + } + if lvalue_type == value_type { + return lvalue_type; + } + TypeOfValue::FromPropsAndState +} + +fn get_root_set_state( + key: IdentifierId, + loads: &HashMap>, + visited: &mut HashSet, +) -> Option { + if visited.contains(&key) { + return None; + } + visited.insert(key); + + match loads.get(&key) { + None => None, + Some(None) => Some(key), + Some(Some(parent_id)) => 
get_root_set_state(*parent_id, loads, visited), + } +} + +fn maybe_record_set_state_for_instr( + instr: &react_compiler_hir::Instruction, + env: &Environment, + set_state_loads: &mut HashMap>, + set_state_usages: &mut HashMap>, +) { + let identifiers = &env.identifiers; + let types = &env.types; + + let all_lvalues = each_instruction_lvalue_ids(instr); + for &lvalue_id in &all_lvalues { + // Check if this is a LoadLocal from a known setState + if let InstructionValue::LoadLocal { place, .. } = &instr.value { + if set_state_loads.contains_key(&place.identifier) { + set_state_loads.insert(lvalue_id, Some(place.identifier)); + } else { + // Only check root setState if not a LoadLocal from a known chain + let lvalue_ident = &identifiers[lvalue_id.0 as usize]; + let lvalue_ty = &types[lvalue_ident.type_.0 as usize]; + if is_set_state_type(lvalue_ty) { + set_state_loads.insert(lvalue_id, None); + } + } + } else { + // Check if lvalue is a setState type (root setState) + let lvalue_ident = &identifiers[lvalue_id.0 as usize]; + let lvalue_ty = &types[lvalue_ident.type_.0 as usize]; + if is_set_state_type(lvalue_ty) { + set_state_loads.insert(lvalue_id, None); + } + } + + let root = get_root_set_state(lvalue_id, set_state_loads, &mut HashSet::new()); + if let Some(root_id) = root { + set_state_usages.entry(root_id).or_insert_with(|| { + let mut set = HashSet::new(); + set.insert(LocKey::from_loc(&instr.lvalue.loc)); + set + }); + } + } +} + +fn is_mutable_at( + env: &Environment, + eval_order: EvaluationOrder, + identifier_id: IdentifierId, +) -> bool { + env.identifiers[identifier_id.0 as usize] + .mutable_range + .contains(eval_order) +} + +pub fn validate_no_derived_computations_in_effects_exp( + func: &HirFunction, + env: &Environment, +) -> Result { + let identifiers = &env.identifiers; + + let mut context = ValidationContext { + functions: HashMap::new(), + candidate_dependencies: HashMap::new(), + derivation_cache: DerivationCache::new(), + effects_cache: 
HashMap::new(), + set_state_loads: HashMap::new(), + set_state_usages: HashMap::new(), + }; + + // Initialize derivation cache based on function type + if func.fn_type == ReactFunctionType::Hook { + for param in &func.params { + if let ParamPattern::Place(place) = param { + let name = identifiers[place.identifier.0 as usize].name.clone(); + context.derivation_cache.cache.insert( + place.identifier, + DerivationMetadata { + place_identifier: place.identifier, + place_name: name, + source_ids: indexmap::IndexSet::new(), + type_of_value: TypeOfValue::FromProps, + is_state_source: true, + }, + ); + } + } + } else if func.fn_type == ReactFunctionType::Component { + if let Some(param) = func.params.first() { + if let ParamPattern::Place(place) = param { + let name = identifiers[place.identifier.0 as usize].name.clone(); + context.derivation_cache.cache.insert( + place.identifier, + DerivationMetadata { + place_identifier: place.identifier, + place_name: name, + source_ids: indexmap::IndexSet::new(), + type_of_value: TypeOfValue::FromProps, + is_state_source: true, + }, + ); + } + } + } + + // Fixpoint iteration + let mut is_first_pass = true; + let mut iteration_count = 0; + loop { + context.derivation_cache.take_snapshot(); + + for (_block_id, block) in &func.body.blocks { + record_phi_derivations(block, &mut context, env); + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + record_instruction_derivations(instr, &mut context, is_first_pass, func, env)?; + } + } + + context.derivation_cache.check_for_changes(); + is_first_pass = false; + iteration_count += 1; + assert!( + iteration_count < MAX_FIXPOINT_ITERATIONS, + "[ValidateNoDerivedComputationsInEffects] Fixpoint iteration failed to converge." 
+ ); + + if !context.derivation_cache.snapshot() { + break; + } + } + + // Validate all effect sites + let mut errors = CompilerError::new(); + let effects_cache: Vec<(IdentifierId, FunctionId, Vec)> = context + .effects_cache + .iter() + .map(|(k, v)| (*k, v.effect_func_id, v.dep_elements.clone())) + .collect(); + + for (_key, effect_func_id, dep_elements) in &effects_cache { + validate_effect( + *effect_func_id, + dep_elements, + &mut context, + func, + env, + &mut errors, + ); + } + + Ok(errors) +} + +fn record_phi_derivations( + block: &react_compiler_hir::BasicBlock, + context: &mut ValidationContext, + env: &Environment, +) { + let identifiers = &env.identifiers; + for phi in &block.phis { + let mut type_of_value = TypeOfValue::Ignored; + let mut source_ids: indexmap::IndexSet = indexmap::IndexSet::new(); + + for (_block_id, operand) in &phi.operands { + if let Some(operand_metadata) = context.derivation_cache.cache.get(&operand.identifier) + { + type_of_value = join_value(type_of_value, operand_metadata.type_of_value); + source_ids.insert(operand.identifier); + } + } + + if type_of_value != TypeOfValue::Ignored { + let name = identifiers[phi.place.identifier.0 as usize].name.clone(); + context.derivation_cache.add_derivation_entry( + phi.place.identifier, + name, + source_ids, + type_of_value, + false, + ); + } + } +} + +fn record_instruction_derivations( + instr: &react_compiler_hir::Instruction, + context: &mut ValidationContext, + is_first_pass: bool, + _outer_func: &HirFunction, + env: &Environment, +) -> Result<(), CompilerDiagnostic> { + let identifiers = &env.identifiers; + let types = &env.types; + let functions = &env.functions; + let lvalue_id = instr.lvalue.identifier; + + // maybeRecordSetState + maybe_record_set_state_for_instr( + instr, + env, + &mut context.set_state_loads, + &mut context.set_state_usages, + ); + + let mut type_of_value = TypeOfValue::Ignored; + let is_source = false; + let mut sources: indexmap::IndexSet = 
indexmap::IndexSet::new(); + + match &instr.value { + InstructionValue::FunctionExpression { lowered_func, .. } => { + context.functions.insert(lvalue_id, lowered_func.func); + // Recurse into the inner function + let inner_func = &functions[lowered_func.func.0 as usize]; + for (_block_id, block) in &inner_func.body.blocks { + record_phi_derivations(block, context, env); + for &inner_instr_id in &block.instructions { + let inner_instr = &inner_func.instructions[inner_instr_id.0 as usize]; + record_instruction_derivations( + inner_instr, + context, + is_first_pass, + inner_func, + env, + )?; + } + } + } + InstructionValue::CallExpression { callee, args, .. } => { + let callee_type = &types[identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if is_use_effect_hook_type(callee_type) && args.len() == 2 { + if let ( + react_compiler_hir::PlaceOrSpread::Place(arg0), + react_compiler_hir::PlaceOrSpread::Place(arg1), + ) = (&args[0], &args[1]) + { + let effect_function = context.functions.get(&arg0.identifier).copied(); + let deps = context + .candidate_dependencies + .get(&arg1.identifier) + .cloned(); + if let (Some(effect_func_id), Some(dep_elements)) = (effect_function, deps) { + context.effects_cache.insert( + arg0.identifier, + EffectMetadata { + effect_func_id, + dep_elements, + }, + ); + } + } + } + + // Check if lvalue is useState type + let lvalue_type = &types[identifiers[lvalue_id.0 as usize].type_.0 as usize]; + if is_use_state_type(lvalue_type) { + let name = identifiers[lvalue_id.0 as usize].name.clone(); + context.derivation_cache.add_derivation_entry( + lvalue_id, + name, + indexmap::IndexSet::new(), + TypeOfValue::FromState, + true, + ); + return Ok(()); + } + } + InstructionValue::MethodCall { property, args, .. 
} => { + let prop_type = &types[identifiers[property.identifier.0 as usize].type_.0 as usize]; + if is_use_effect_hook_type(prop_type) && args.len() == 2 { + if let ( + react_compiler_hir::PlaceOrSpread::Place(arg0), + react_compiler_hir::PlaceOrSpread::Place(arg1), + ) = (&args[0], &args[1]) + { + let effect_function = context.functions.get(&arg0.identifier).copied(); + let deps = context + .candidate_dependencies + .get(&arg1.identifier) + .cloned(); + if let (Some(effect_func_id), Some(dep_elements)) = (effect_function, deps) { + context.effects_cache.insert( + arg0.identifier, + EffectMetadata { + effect_func_id, + dep_elements, + }, + ); + } + } + } + + // Check if lvalue is useState type + let lvalue_type = &types[identifiers[lvalue_id.0 as usize].type_.0 as usize]; + if is_use_state_type(lvalue_type) { + let name = identifiers[lvalue_id.0 as usize].name.clone(); + context.derivation_cache.add_derivation_entry( + lvalue_id, + name, + indexmap::IndexSet::new(), + TypeOfValue::FromState, + true, + ); + return Ok(()); + } + } + InstructionValue::ArrayExpression { elements, .. 
} => { + let dep_elements: Vec = elements + .iter() + .filter_map(|el| match el { + ArrayElement::Place(p) => Some(DepElement { + identifier: p.identifier, + loc: p.loc, + }), + _ => None, + }) + .collect(); + context + .candidate_dependencies + .insert(lvalue_id, dep_elements); + } + _ => {} + } + + // Collect operand derivations + for (operand_id, operand_loc) in each_instruction_operand(instr, env) { + // Track setState usages + if context.set_state_loads.contains_key(&operand_id) { + let root = + get_root_set_state(operand_id, &context.set_state_loads, &mut HashSet::new()); + if let Some(root_id) = root { + if let Some(usages) = context.set_state_usages.get_mut(&root_id) { + usages.insert(LocKey::from_loc(&operand_loc)); + } + } + } + + if let Some(operand_metadata) = context.derivation_cache.cache.get(&operand_id) { + type_of_value = join_value(type_of_value, operand_metadata.type_of_value); + sources.insert(operand_id); + } + } + + if type_of_value == TypeOfValue::Ignored { + return Ok(()); + } + + // Record derivation for ALL lvalue places (including destructured variables) + for &lv_id in &each_instruction_lvalue_ids(instr) { + let name = identifiers[lv_id.0 as usize].name.clone(); + context.derivation_cache.add_derivation_entry( + lv_id, + name, + sources.clone(), + type_of_value, + is_source, + ); + } + + if matches!(&instr.value, InstructionValue::FunctionExpression { .. 
}) { + // Don't record mutation effects for FunctionExpressions + return Ok(()); + } + + // Handle mutable operands + for operand in each_instruction_operand_with_effect(instr, env) { + if operand.effect.is_mutable() { + if is_mutable_at(env, instr.id, operand.id) { + if let Some(existing) = context.derivation_cache.cache.get_mut(&operand.id) { + existing.type_of_value = join_value(type_of_value, existing.type_of_value); + } else { + let name = identifiers[operand.id.0 as usize].name.clone(); + context.derivation_cache.add_derivation_entry( + operand.id, + name, + sources.clone(), + type_of_value, + false, + ); + } + } + } else if matches!(operand.effect, Effect::Unknown) { + return Err(CompilerDiagnostic::new( + ErrorCategory::Invariant, + "Unexpected unknown effect", + None, + )); + } + // Freeze | Read => no-op + } + Ok(()) +} + +struct OperandWithEffect { + id: IdentifierId, + effect: Effect, +} + +/// Collects operand (IdentifierId, loc) pairs from an instruction. +/// Thin wrapper around canonical `each_instruction_operand` that maps Places to +/// (id, loc) pairs. +fn each_instruction_operand( + instr: &react_compiler_hir::Instruction, + env: &Environment, +) -> Vec<(IdentifierId, Option)> { + canonical_each_instruction_operand(instr, env) + .into_iter() + .map(|place| (place.identifier, place.loc)) + .collect() +} + +/// Collects operands with their effects. +/// Thin wrapper around canonical `each_instruction_operand` that maps Places to +/// OperandWithEffect. 
+fn each_instruction_operand_with_effect( + instr: &react_compiler_hir::Instruction, + env: &Environment, +) -> Vec { + canonical_each_instruction_operand(instr, env) + .into_iter() + .map(|place| OperandWithEffect { + id: place.identifier, + effect: place.effect, + }) + .collect() +} + +// ============================================================================= +// Tree building and rendering (for error messages) +// ============================================================================= + +struct TreeNode { + name: String, + type_of_value: TypeOfValue, + is_source: bool, + children: Vec, +} + +fn build_tree_node( + source_id: IdentifierId, + context: &ValidationContext, + visited: &HashSet, +) -> Vec { + let source_metadata = match context.derivation_cache.cache.get(&source_id) { + Some(m) => m, + None => return Vec::new(), + }; + + if source_metadata.is_state_source { + if let Some(IdentifierName::Named(name)) = &source_metadata.place_name { + return vec![TreeNode { + name: name.clone(), + type_of_value: source_metadata.type_of_value, + is_source: true, + children: Vec::new(), + }]; + } + } + + let mut children: Vec = Vec::new(); + let mut named_siblings: indexmap::IndexSet = indexmap::IndexSet::new(); + + for child_id in &source_metadata.source_ids { + assert_ne!( + *child_id, source_id, + "Unexpected self-reference: a value should not have itself as a source" + ); + + let mut new_visited = visited.clone(); + if let Some(IdentifierName::Named(name)) = &source_metadata.place_name { + new_visited.insert(name.clone()); + } + + let child_nodes = build_tree_node(*child_id, context, &new_visited); + for child_node in child_nodes { + if !named_siblings.contains(&child_node.name) { + named_siblings.insert(child_node.name.clone()); + children.push(child_node); + } + } + } + + if let Some(IdentifierName::Named(name)) = &source_metadata.place_name { + if !visited.contains(name) { + return vec![TreeNode { + name: name.clone(), + type_of_value: 
source_metadata.type_of_value, + is_source: source_metadata.is_state_source, + children, + }]; + } + } + + children +} + +fn render_tree( + node: &TreeNode, + indent: &str, + is_last: bool, + props_set: &mut indexmap::IndexSet, + state_set: &mut indexmap::IndexSet, +) -> String { + let prefix = format!( + "{}{}", + indent, + if is_last { + "\u{2514}\u{2500}\u{2500} " + } else { + "\u{251c}\u{2500}\u{2500} " + } + ); + let child_indent = format!("{}{}", indent, if is_last { " " } else { "\u{2502} " }); + + let mut result = format!("{}{}", prefix, node.name); + + if node.is_source { + let type_label = match node.type_of_value { + TypeOfValue::FromProps => { + props_set.insert(node.name.clone()); + "Prop" + } + TypeOfValue::FromState => { + state_set.insert(node.name.clone()); + "State" + } + _ => { + props_set.insert(node.name.clone()); + state_set.insert(node.name.clone()); + "Prop and State" + } + }; + result += &format!(" ({})", type_label); + } + + if !node.children.is_empty() { + result += "\n"; + for (index, child) in node.children.iter().enumerate() { + let is_last_child = index == node.children.len() - 1; + result += &render_tree(child, &child_indent, is_last_child, props_set, state_set); + if index < node.children.len() - 1 { + result += "\n"; + } + } + } + + result +} + +fn get_fn_local_deps( + func_id: Option, + env: &Environment, +) -> Option> { + let func_id = func_id?; + let inner = &env.functions[func_id.0 as usize]; + let mut deps: HashSet = HashSet::new(); + + for (_block_id, block) in &inner.body.blocks { + for &instr_id in &block.instructions { + let instr = &inner.instructions[instr_id.0 as usize]; + if let InstructionValue::LoadLocal { place, .. 
} = &instr.value { + deps.insert(place.identifier); + } + } + } + + Some(deps) +} + +fn validate_effect( + effect_func_id: FunctionId, + dependencies: &[DepElement], + context: &mut ValidationContext, + _outer_func: &HirFunction, + env: &Environment, + errors: &mut CompilerError, +) { + let identifiers = &env.identifiers; + let types = &env.types; + let functions = &env.functions; + let effect_function = &functions[effect_func_id.0 as usize]; + let mut seen_blocks: HashSet = HashSet::new(); + + struct DerivedSetStateCall { + callee_loc: Option, + callee_id: IdentifierId, + callee_identifier_name: Option, + source_ids: indexmap::IndexSet, + } + + let mut effect_derived_set_state_calls: Vec = Vec::new(); + let mut effect_set_state_usages: HashMap> = HashMap::new(); + + // Consider setStates in the effect's dependency array as being part of + // effectSetStateUsages + for dep in dependencies { + let root = get_root_set_state( + dep.identifier, + &context.set_state_loads, + &mut HashSet::new(), + ); + if let Some(root_id) = root { + let mut set = HashSet::new(); + set.insert(LocKey::from_loc(&dep.loc)); + effect_set_state_usages.insert(root_id, set); + } + } + + let mut cleanup_function_deps: Option> = None; + let mut globals: HashSet = HashSet::new(); + + for (_block_id, block) in &effect_function.body.blocks { + // Check for return -> cleanup function + if let react_compiler_hir::Terminal::Return { + value, + return_variant: ReturnVariant::Explicit, + .. 
+ } = &block.terminal + { + let func_id = context.functions.get(&value.identifier).copied(); + cleanup_function_deps = get_fn_local_deps(func_id, env); + } + + // Skip if block has a back edge (pred not yet seen) + let has_back_edge = block.preds.iter().any(|pred| !seen_blocks.contains(pred)); + if has_back_edge { + return; + } + + for &instr_id in &block.instructions { + let instr = &effect_function.instructions[instr_id.0 as usize]; + + // Early return if any instruction derives from a ref + let lvalue_type = + &types[identifiers[instr.lvalue.identifier.0 as usize].type_.0 as usize]; + if is_use_ref_type(lvalue_type) { + return; + } + + // maybeRecordSetState for effect instructions + maybe_record_set_state_for_instr( + instr, + env, + &mut context.set_state_loads, + &mut effect_set_state_usages, + ); + + // Track setState usages for operands + for (operand_id, operand_loc) in each_instruction_operand(instr, env) { + if context.set_state_loads.contains_key(&operand_id) { + let root = get_root_set_state( + operand_id, + &context.set_state_loads, + &mut HashSet::new(), + ); + if let Some(root_id) = root { + if let Some(usages) = effect_set_state_usages.get_mut(&root_id) { + usages.insert(LocKey::from_loc(&operand_loc)); + } + } + } + } + + match &instr.value { + InstructionValue::CallExpression { callee, args, .. 
} => { + let callee_type = + &types[identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if is_set_state_type(callee_type) && args.len() == 1 { + if let react_compiler_hir::PlaceOrSpread::Place(arg0) = &args[0] { + let callee_metadata = + context.derivation_cache.cache.get(&callee.identifier); + + // If the setState comes from a source other than local state, skip + if let Some(cm) = callee_metadata { + if cm.type_of_value != TypeOfValue::FromState { + continue; + } + } else { + continue; + } + + let arg_metadata = context.derivation_cache.cache.get(&arg0.identifier); + if let Some(am) = arg_metadata { + // Get the user-visible identifier name, matching Babel's + // loc.identifierName. Falls back to extracting from source code. + let callee_ident_name = get_identifier_name_with_loc( + callee.identifier, + identifiers, + &callee.loc, + env.code.as_deref(), + ); + effect_derived_set_state_calls.push(DerivedSetStateCall { + callee_loc: callee.loc, + callee_id: callee.identifier, + callee_identifier_name: callee_ident_name, + source_ids: am.source_ids.clone(), + }); + } + } + } else { + // Check if callee is from props/propsAndState -> bail + let callee_metadata = + context.derivation_cache.cache.get(&callee.identifier); + if let Some(cm) = callee_metadata { + if cm.type_of_value == TypeOfValue::FromProps + || cm.type_of_value == TypeOfValue::FromPropsAndState + { + return; + } + } + + if globals.contains(&callee.identifier) { + return; + } + } + } + InstructionValue::LoadGlobal { .. 
} => { + globals.insert(instr.lvalue.identifier); + for (operand_id, _) in each_instruction_operand(instr, env) { + globals.insert(operand_id); + } + } + _ => {} + } + } + seen_blocks.insert(block.id); + } + + // Emit errors for derived setState calls + for derived in &effect_derived_set_state_calls { + let root_set_state_call = get_root_set_state( + derived.callee_id, + &context.set_state_loads, + &mut HashSet::new(), + ); + if let Some(root_id) = root_set_state_call { + let effect_usage_count = effect_set_state_usages + .get(&root_id) + .map(|s| s.len()) + .unwrap_or(0); + let total_usage_count = context + .set_state_usages + .get(&root_id) + .map(|s| s.len()) + .unwrap_or(0); + if effect_set_state_usages.contains_key(&root_id) + && context.set_state_usages.contains_key(&root_id) + && effect_usage_count == total_usage_count - 1 + { + let mut props_set: indexmap::IndexSet = indexmap::IndexSet::new(); + let mut state_set: indexmap::IndexSet = indexmap::IndexSet::new(); + + let mut root_nodes_map: indexmap::IndexMap = + indexmap::IndexMap::new(); + for id in &derived.source_ids { + let nodes = build_tree_node(*id, context, &HashSet::new()); + for node in nodes { + if !root_nodes_map.contains_key(&node.name) { + root_nodes_map.insert(node.name.clone(), node); + } + } + } + let root_nodes: Vec<&TreeNode> = root_nodes_map.values().collect(); + + let trees: Vec = root_nodes + .iter() + .enumerate() + .map(|(index, node)| { + render_tree( + node, + "", + index == root_nodes.len() - 1, + &mut props_set, + &mut state_set, + ) + }) + .collect(); + + // Check cleanup function dependencies + let should_skip = if let Some(ref cleanup_deps) = cleanup_function_deps { + derived + .source_ids + .iter() + .any(|dep| cleanup_deps.contains(dep)) + } else { + false + }; + if should_skip { + return; + } + + let mut root_sources = String::new(); + if !props_set.is_empty() { + let props_list: Vec<&str> = props_set.iter().map(|s| s.as_str()).collect(); + root_sources += &format!("Props: 
[{}]", props_list.join(", ")); + } + if !state_set.is_empty() { + if !root_sources.is_empty() { + root_sources += "\n"; + } + let state_list: Vec<&str> = state_set.iter().map(|s| s.as_str()).collect(); + root_sources += &format!("State: [{}]", state_list.join(", ")); + } + + let description = format!( + "Using an effect triggers an additional render which can hurt performance and user experience, potentially briefly showing stale values to the user\n\n\ + This setState call is setting a derived value that depends on the following reactive sources:\n\n\ + {}\n\n\ + Data Flow Tree:\n\ + {}\n\n\ + See: https://react.dev/learn/you-might-not-need-an-effect#updating-state-based-on-props-or-state", + root_sources, + trees.join("\n"), + ); + + errors.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::EffectDerivationsOfState, + "You might not need an effect. Derive values in render, not effects.", + Some(description), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: derived.callee_loc, + message: Some( + "This should be computed during render, not in an effect".to_string(), + ), + identifier_name: derived.callee_identifier_name.clone(), + }), + ); + } + } + } +} + +// ============================================================================= +// Non-exp version: ValidateNoDerivedComputationsInEffects +// Port of ValidateNoDerivedComputationsInEffects.ts +// ============================================================================= + +/// Non-experimental version of the derived-computations-in-effects validation. +/// Records errors directly on the Environment (matching TS `env.recordError()` +/// behavior). +pub fn validate_no_derived_computations_in_effects( + func: &HirFunction, + env: &mut Environment, +) -> Result<(), CompilerError> { + // Phase 1: Collect effect call sites (func_id + resolved deps). + // Done with only immutable borrows of env fields. 
+ let effects_to_validate: Vec<(FunctionId, Vec)> = { + let ids = &env.identifiers; + let tys = &env.types; + let mut candidate_deps: HashMap> = HashMap::new(); + let mut functions_map: HashMap = HashMap::new(); + let mut locals_map: HashMap = HashMap::new(); + let mut result = Vec::new(); + + for (_, block) in &func.body.blocks { + for &iid in &block.instructions { + let instr = &func.instructions[iid.0 as usize]; + match &instr.value { + InstructionValue::LoadLocal { place, .. } => { + locals_map.insert(instr.lvalue.identifier, place.identifier); + } + InstructionValue::ArrayExpression { elements, .. } => { + let elem_ids: Vec = elements + .iter() + .filter_map(|e| match e { + ArrayElement::Place(p) => Some(p.identifier), + _ => None, + }) + .collect(); + if elem_ids.len() == elements.len() { + candidate_deps.insert(instr.lvalue.identifier, elem_ids); + } + } + InstructionValue::FunctionExpression { lowered_func, .. } => { + functions_map.insert(instr.lvalue.identifier, lowered_func.func); + } + InstructionValue::CallExpression { callee, args, .. } => { + let callee_ty = &tys[ids[callee.identifier.0 as usize].type_.0 as usize]; + if is_use_effect_hook_type(callee_ty) && args.len() == 2 { + if let (PlaceOrSpread::Place(arg0), PlaceOrSpread::Place(arg1)) = + (&args[0], &args[1]) + { + if let (Some(&func_id), Some(dep_elements)) = ( + functions_map.get(&arg0.identifier), + candidate_deps.get(&arg1.identifier), + ) { + if !dep_elements.is_empty() { + let resolved: Vec = dep_elements + .iter() + .map(|d| locals_map.get(d).copied().unwrap_or(*d)) + .collect(); + result.push((func_id, resolved)); + } + } + } + } + } + InstructionValue::MethodCall { property, args, .. 
} => { + let callee_ty = &tys[ids[property.identifier.0 as usize].type_.0 as usize]; + if is_use_effect_hook_type(callee_ty) && args.len() == 2 { + if let (PlaceOrSpread::Place(arg0), PlaceOrSpread::Place(arg1)) = + (&args[0], &args[1]) + { + if let (Some(&func_id), Some(dep_elements)) = ( + functions_map.get(&arg0.identifier), + candidate_deps.get(&arg1.identifier), + ) { + if !dep_elements.is_empty() { + let resolved: Vec = dep_elements + .iter() + .map(|d| locals_map.get(d).copied().unwrap_or(*d)) + .collect(); + result.push((func_id, resolved)); + } + } + } + } + } + _ => {} + } + } + } + result + }; + + // Phase 2: Validate each collected effect and record error details. + // Uses ErrorDetail (flat loc format) to match TS behavior where + // env.recordError(new CompilerErrorDetail({...})) is used. + for (func_id, resolved_deps) in effects_to_validate { + let details = validate_effect_non_exp( + &env.functions[func_id.0 as usize], + &resolved_deps, + &env.identifiers, + &env.types, + ); + for detail in details { + env.record_error(detail)?; + } + } + Ok(()) +} + +fn validate_effect_non_exp( + effect_func: &HirFunction, + effect_deps: &[IdentifierId], + ids: &[Identifier], + tys: &[Type], +) -> Vec { + // Check that the effect function only captures effect deps and setState + for ctx in &effect_func.context { + let ctx_ty = &tys[ids[ctx.identifier.0 as usize].type_.0 as usize]; + if is_set_state_type(ctx_ty) { + continue; + } else if effect_deps.iter().any(|d| *d == ctx.identifier) { + continue; + } else { + return Vec::new(); + } + } + + // Check that all effect deps are actually used in the function + for dep in effect_deps { + if !effect_func.context.iter().any(|c| c.identifier == *dep) { + return Vec::new(); + } + } + + let mut seen_blocks: HashSet = HashSet::new(); + let mut dep_values: HashMap> = HashMap::new(); + for dep in effect_deps { + dep_values.insert(*dep, vec![*dep]); + } + + let mut set_state_locs: Vec = Vec::new(); + + for (_, block) in 
&effect_func.body.blocks { + for &pred in &block.preds { + if !seen_blocks.contains(&pred) { + return Vec::new(); + } + } + + for phi in &block.phis { + let mut aggregate: HashSet = HashSet::new(); + for operand in phi.operands.values() { + if let Some(deps) = dep_values.get(&operand.identifier) { + for d in deps { + aggregate.insert(*d); + } + } + } + if !aggregate.is_empty() { + dep_values.insert(phi.place.identifier, aggregate.into_iter().collect()); + } + } + + for &iid in &block.instructions { + let instr = &effect_func.instructions[iid.0 as usize]; + match &instr.value { + InstructionValue::Primitive { .. } + | InstructionValue::JSXText { .. } + | InstructionValue::LoadGlobal { .. } => {} + InstructionValue::LoadLocal { place, .. } => { + if let Some(deps) = dep_values.get(&place.identifier) { + dep_values.insert(instr.lvalue.identifier, deps.clone()); + } + } + InstructionValue::ComputedLoad { .. } + | InstructionValue::PropertyLoad { .. } + | InstructionValue::BinaryExpression { .. } + | InstructionValue::TemplateLiteral { .. } + | InstructionValue::CallExpression { .. } + | InstructionValue::MethodCall { .. } => { + let mut aggregate: HashSet = HashSet::new(); + for operand in non_exp_value_operands(&instr.value) { + if let Some(deps) = dep_values.get(&operand) { + for d in deps { + aggregate.insert(*d); + } + } + } + if !aggregate.is_empty() { + dep_values.insert(instr.lvalue.identifier, aggregate.into_iter().collect()); + } + + if let InstructionValue::CallExpression { callee, args, .. 
} = &instr.value { + let callee_ty = &tys[ids[callee.identifier.0 as usize].type_.0 as usize]; + if is_set_state_type(callee_ty) && args.len() == 1 { + if let PlaceOrSpread::Place(arg) = &args[0] { + if let Some(deps) = dep_values.get(&arg.identifier) { + let dep_set: HashSet<_> = deps.iter().collect(); + if dep_set.len() == effect_deps.len() { + if let Some(loc) = callee.loc { + set_state_locs.push(loc); + } + } else { + return Vec::new(); + } + } else { + return Vec::new(); + } + } + } + } + } + _ => { + return Vec::new(); + } + } + } + + match &block.terminal { + react_compiler_hir::Terminal::Return { value, .. } + | react_compiler_hir::Terminal::Throw { value, .. } => { + if dep_values.contains_key(&value.identifier) { + return Vec::new(); + } + } + react_compiler_hir::Terminal::If { test, .. } + | react_compiler_hir::Terminal::Branch { test, .. } => { + if dep_values.contains_key(&test.identifier) { + return Vec::new(); + } + } + react_compiler_hir::Terminal::Switch { test, .. } => { + if dep_values.contains_key(&test.identifier) { + return Vec::new(); + } + } + _ => {} + } + + seen_blocks.insert(block.id); + } + + set_state_locs + .into_iter() + .map(|loc| { + CompilerErrorDetail { + category: ErrorCategory::EffectDerivationsOfState, + reason: "Values derived from props and state should be calculated during render, not in an effect. (https://react.dev/learn/you-might-not-need-an-effect#updating-state-based-on-props-or-state)".to_string(), + description: None, + loc: Some(loc), + suggestions: None, + } + }) + .collect() +} + +/// Collects operand IdentifierIds for a subset of instruction variants used +/// by `validate_effect_non_exp`. 
+/// +/// NOTE: This intentionally does NOT use the canonical +/// `each_instruction_value_operand` because: (1) `validate_effect_non_exp` only +/// matches specific variants (ComputedLoad, PropertyLoad, BinaryExpression, +/// TemplateLiteral, CallExpression, MethodCall), so +/// FunctionExpression/ObjectMethod context handling is unnecessary; and (2) the +/// caller does not have access to `env` which the canonical function requires +/// for resolving function expression context captures. +fn non_exp_value_operands(value: &InstructionValue) -> Vec { + match value { + InstructionValue::ComputedLoad { + object, property, .. + } => { + vec![object.identifier, property.identifier] + } + InstructionValue::PropertyLoad { object, .. } => vec![object.identifier], + InstructionValue::BinaryExpression { left, right, .. } => { + vec![left.identifier, right.identifier] + } + InstructionValue::TemplateLiteral { subexprs, .. } => { + subexprs.iter().map(|s| s.identifier).collect() + } + InstructionValue::CallExpression { callee, args, .. } => { + let mut op_ids = vec![callee.identifier]; + for a in args { + match a { + PlaceOrSpread::Place(p) => op_ids.push(p.identifier), + PlaceOrSpread::Spread(s) => op_ids.push(s.place.identifier), + } + } + op_ids + } + InstructionValue::MethodCall { + receiver, + property, + args, + .. 
+ } => { + let mut op_ids = vec![receiver.identifier, property.identifier]; + for a in args { + match a { + PlaceOrSpread::Place(p) => op_ids.push(p.identifier), + PlaceOrSpread::Spread(s) => op_ids.push(s.place.identifier), + } + } + op_ids + } + _ => Vec::new(), + } +} diff --git a/crates/react_compiler_validation/src/validate_no_freezing_known_mutable_functions.rs b/crates/react_compiler_validation/src/validate_no_freezing_known_mutable_functions.rs new file mode 100644 index 000000000000..c146e51ffbf2 --- /dev/null +++ b/crates/react_compiler_validation/src/validate_no_freezing_known_mutable_functions.rs @@ -0,0 +1,225 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory, SourceLocation, +}; +use react_compiler_hir::{ + environment::Environment, + visitors::{each_instruction_value_operand, each_terminal_operand}, + AliasingEffect, Effect, HirFunction, Identifier, IdentifierId, IdentifierName, + InstructionValue, Place, Type, +}; + +/// Information about a known mutation effect: which identifier is mutated, and +/// the source location of the mutation. +#[derive(Debug, Clone)] +struct MutationInfo { + value_identifier: IdentifierId, + value_loc: Option, +} + +/// Validates that functions with known mutations (ie due to types) cannot be +/// passed where a frozen value is expected. +/// +/// Because a function that mutates a captured variable is equivalent to a +/// mutable value, and the receiver has no way to avoid calling the function, +/// this pass detects functions with *known* mutations (Mutate or +/// MutateTransitive, not conditional) that are passed where a frozen value is +/// expected and reports an error. 
+pub fn validate_no_freezing_known_mutable_functions(func: &HirFunction, env: &mut Environment) { + let diagnostics = check_no_freezing_known_mutable_functions( + func, + &env.identifiers, + &env.types, + &env.functions, + env, + ); + for diagnostic in diagnostics { + env.record_diagnostic(diagnostic); + } +} + +fn check_no_freezing_known_mutable_functions( + func: &HirFunction, + identifiers: &[Identifier], + types: &[Type], + functions: &[HirFunction], + env: &Environment, +) -> Vec { + // Maps an identifier to the mutation effect that makes it "known mutable" + let mut context_mutation_effects: HashMap = HashMap::new(); + let mut diagnostics: Vec = Vec::new(); + + for (_block_id, block) in &func.body.blocks { + for &instruction_id in &block.instructions { + let instr = &func.instructions[instruction_id.0 as usize]; + + match &instr.value { + InstructionValue::LoadLocal { place, .. } => { + // Propagate known mutation from the loaded place to the lvalue + if let Some(mutation_info) = context_mutation_effects.get(&place.identifier) { + context_mutation_effects + .insert(instr.lvalue.identifier, mutation_info.clone()); + } + } + + InstructionValue::StoreLocal { lvalue, value, .. } => { + // Propagate known mutation from the stored value to both the + // instruction lvalue and the StoreLocal's target lvalue + if let Some(mutation_info) = context_mutation_effects.get(&value.identifier) { + let mutation_info = mutation_info.clone(); + context_mutation_effects + .insert(instr.lvalue.identifier, mutation_info.clone()); + context_mutation_effects.insert(lvalue.place.identifier, mutation_info); + } + } + + InstructionValue::FunctionExpression { lowered_func, .. 
} => { + let inner_function = &functions[lowered_func.func.0 as usize]; + if let Some(ref aliasing_effects) = inner_function.aliasing_effects { + let context_ids: HashSet = inner_function + .context + .iter() + .map(|place| place.identifier) + .collect(); + + 'effects: for effect in aliasing_effects { + match effect { + AliasingEffect::Mutate { value, .. } + | AliasingEffect::MutateTransitive { value, .. } => { + // If the mutated value is already known-mutable, propagate + if let Some(known_mutation) = + context_mutation_effects.get(&value.identifier) + { + context_mutation_effects.insert( + instr.lvalue.identifier, + known_mutation.clone(), + ); + } else if context_ids.contains(&value.identifier) + && !is_ref_or_ref_like_mutable_type( + value.identifier, + identifiers, + types, + ) + { + // New known mutation of a context variable + context_mutation_effects.insert( + instr.lvalue.identifier, + MutationInfo { + value_identifier: value.identifier, + value_loc: value.loc, + }, + ); + break 'effects; + } + } + + AliasingEffect::MutateConditionally { value, .. } + | AliasingEffect::MutateTransitiveConditionally { value, .. 
} => { + // Only propagate existing known mutations for conditional + // effects + if let Some(known_mutation) = + context_mutation_effects.get(&value.identifier) + { + context_mutation_effects.insert( + instr.lvalue.identifier, + known_mutation.clone(), + ); + } + } + + _ => {} + } + } + } + } + + _ => { + // For all other instruction kinds, check operands for freeze violations + for operand in each_instruction_value_operand(&instr.value, env) { + check_operand_for_freeze_violation( + &operand, + &context_mutation_effects, + identifiers, + &mut diagnostics, + ); + } + } + } + } + + // Also check terminal operands + for operand in each_terminal_operand(&block.terminal) { + check_operand_for_freeze_violation( + &operand, + &context_mutation_effects, + identifiers, + &mut diagnostics, + ); + } + } + + diagnostics +} + +/// If an operand with Effect::Freeze is a known-mutable function, emit a +/// diagnostic. +fn check_operand_for_freeze_violation( + operand: &Place, + context_mutation_effects: &HashMap, + identifiers: &[Identifier], + diagnostics: &mut Vec, +) { + if operand.effect == Effect::Freeze { + if let Some(mutation_info) = context_mutation_effects.get(&operand.identifier) { + let identifier = &identifiers[mutation_info.value_identifier.0 as usize]; + let variable_name = match &identifier.name { + Some(IdentifierName::Named(name)) => format!("`{}`", name), + _ => "a local variable".to_string(), + }; + + diagnostics.push( + CompilerDiagnostic::new( + ErrorCategory::Immutability, + "Cannot modify local variables after render completes", + Some(format!( + "This argument is a function which may reassign or mutate {} after \ + render, which can cause inconsistent behavior on subsequent renders. 
\ + Consider using state instead", + variable_name + )), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: operand.loc, + message: Some(format!( + "This function may (indirectly) reassign or modify {} after render", + variable_name + )), + identifier_name: None, + }) + .with_detail(CompilerDiagnosticDetail::Error { + loc: mutation_info.value_loc, + message: Some(format!("This modifies {}", variable_name)), + identifier_name: None, + }), + ); + } + } +} + +/// Check if an identifier's type is a ref or ref-like mutable type. +fn is_ref_or_ref_like_mutable_type( + identifier_id: IdentifierId, + identifiers: &[Identifier], + types: &[Type], +) -> bool { + let identifier = &identifiers[identifier_id.0 as usize]; + react_compiler_hir::is_ref_or_ref_like_mutable_type(&types[identifier.type_.0 as usize]) +} diff --git a/crates/react_compiler_validation/src/validate_no_jsx_in_try_statement.rs b/crates/react_compiler_validation/src/validate_no_jsx_in_try_statement.rs new file mode 100644 index 000000000000..e977269241a7 --- /dev/null +++ b/crates/react_compiler_validation/src/validate_no_jsx_in_try_statement.rs @@ -0,0 +1,66 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Validates against constructing JSX within try/catch blocks. +//! +//! Developers may not be aware of error boundaries and lazy evaluation of JSX, +//! leading them to use patterns such as `let el; try { el = } +//! catch { ... }` to attempt to catch rendering errors. Such code will fail to +//! catch errors in rendering, but developers may not realize this right away. +//! +//! This validation pass errors for JSX created within a try block. JSX is +//! allowed within a catch statement, unless that catch is itself nested inside +//! an outer try. +//! +//! Port of ValidateNoJSXInTryStatement.ts. 
+ +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, ErrorCategory, +}; +use react_compiler_hir::{BlockId, HirFunction, InstructionValue, Terminal}; + +pub fn validate_no_jsx_in_try_statement(func: &HirFunction) -> CompilerError { + let mut active_try_blocks: Vec = Vec::new(); + let mut error = CompilerError::new(); + + for (_block_id, block) in &func.body.blocks { + // Remove completed try blocks (retainWhere equivalent) + active_try_blocks.retain(|id| *id != block.id); + + if !active_try_blocks.is_empty() { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::JsxExpression { loc, .. } + | InstructionValue::JsxFragment { loc, .. } => { + error.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::ErrorBoundaries, + "Avoid constructing JSX within try/catch", + Some( + "React does not immediately render components when JSX is rendered, so any errors from this component will not be caught by the try/catch. To catch errors in rendering a given component, wrap that component in an error boundary. (https://react.dev/reference/react/Component#catching-rendering-errors-with-an-error-boundary)".to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: *loc, + message: Some( + "Avoid constructing JSX within try/catch".to_string(), + ), + identifier_name: None, + }), + ); + } + _ => {} + } + } + } + + if let Terminal::Try { handler, .. 
} = &block.terminal { + active_try_blocks.push(*handler); + } + } + + error +} diff --git a/crates/react_compiler_validation/src/validate_no_ref_access_in_render.rs b/crates/react_compiler_validation/src/validate_no_ref_access_in_render.rs new file mode 100644 index 000000000000..93aca39d5609 --- /dev/null +++ b/crates/react_compiler_validation/src/validate_no_ref_access_in_render.rs @@ -0,0 +1,1201 @@ +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory, SourceLocation, +}; +use react_compiler_hir::{ + environment::Environment, + object_shape::HookKind, + visitors::{ + each_instruction_value_operand as canonical_each_instruction_value_operand, + each_pattern_operand, each_terminal_operand, + }, + AliasingEffect, BlockId, HirFunction, Identifier, IdentifierId, InstructionValue, Place, + PrimitiveValue, PropertyLiteral, Terminal, Type, UnaryOperator, +}; + +const ERROR_DESCRIPTION: &str = "React refs are values that are not needed for rendering. \ + Refs should only be accessed outside of render, such as in event handlers or effects. 
\ + Accessing a ref value (the `current` property) during render can cause your component \ + not to update as expected (https://react.dev/reference/react/useRef)"; + +// --- RefId --- + +type RefId = u32; + +static REF_ID_COUNTER: std::sync::atomic::AtomicU32 = std::sync::atomic::AtomicU32::new(0); + +fn next_ref_id() -> RefId { + REF_ID_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed) +} + +// --- RefAccessType / RefAccessRefType / RefFnType --- + +/// Corresponds to TS `RefAccessType` +#[derive(Debug, Clone, PartialEq)] +enum RefAccessType { + None, + Nullable, + Guard { + ref_id: RefId, + }, + Ref { + ref_id: RefId, + }, + RefValue { + loc: Option, + ref_id: Option, + }, + Structure { + value: Option>, + fn_type: Option, + }, +} + +/// Corresponds to TS `RefAccessRefType` — the subset of `RefAccessType` that +/// can appear inside `Structure.value` and be joined via +/// `join_ref_access_ref_types`. +#[derive(Debug, Clone, PartialEq)] +enum RefAccessRefType { + Ref { + ref_id: RefId, + }, + RefValue { + loc: Option, + ref_id: Option, + }, + Structure { + value: Option>, + fn_type: Option, + }, +} + +#[derive(Debug, Clone, PartialEq)] +struct RefFnType { + read_ref_effect: bool, + return_type: Box, +} + +impl RefAccessType { + /// Try to convert a `RefAccessType` to a `RefAccessRefType` (the + /// Ref/RefValue/Structure subset). + fn to_ref_type(&self) -> Option { + match self { + RefAccessType::Ref { ref_id } => Some(RefAccessRefType::Ref { ref_id: *ref_id }), + RefAccessType::RefValue { loc, ref_id } => Some(RefAccessRefType::RefValue { + loc: *loc, + ref_id: *ref_id, + }), + RefAccessType::Structure { value, fn_type } => Some(RefAccessRefType::Structure { + value: value.clone(), + fn_type: fn_type.clone(), + }), + _ => None, + } + } + + /// Convert a `RefAccessRefType` back to a `RefAccessType`. 
+ fn from_ref_type(ref_type: &RefAccessRefType) -> Self { + match ref_type { + RefAccessRefType::Ref { ref_id } => RefAccessType::Ref { ref_id: *ref_id }, + RefAccessRefType::RefValue { loc, ref_id } => RefAccessType::RefValue { + loc: *loc, + ref_id: *ref_id, + }, + RefAccessRefType::Structure { value, fn_type } => RefAccessType::Structure { + value: value.clone(), + fn_type: fn_type.clone(), + }, + } + } +} + +// --- Join operations --- + +fn join_ref_access_ref_types(a: &RefAccessRefType, b: &RefAccessRefType) -> RefAccessRefType { + match (a, b) { + ( + RefAccessRefType::RefValue { ref_id: a_id, .. }, + RefAccessRefType::RefValue { ref_id: b_id, .. }, + ) => { + if a_id == b_id { + a.clone() + } else { + RefAccessRefType::RefValue { + loc: None, + ref_id: None, + } + } + } + (RefAccessRefType::RefValue { .. }, _) => RefAccessRefType::RefValue { + loc: None, + ref_id: None, + }, + (_, RefAccessRefType::RefValue { .. }) => RefAccessRefType::RefValue { + loc: None, + ref_id: None, + }, + (RefAccessRefType::Ref { ref_id: a_id }, RefAccessRefType::Ref { ref_id: b_id }) => { + if a_id == b_id { + a.clone() + } else { + RefAccessRefType::Ref { + ref_id: next_ref_id(), + } + } + } + (RefAccessRefType::Ref { .. }, _) | (_, RefAccessRefType::Ref { .. 
}) => { + RefAccessRefType::Ref { + ref_id: next_ref_id(), + } + } + ( + RefAccessRefType::Structure { + value: a_value, + fn_type: a_fn, + }, + RefAccessRefType::Structure { + value: b_value, + fn_type: b_fn, + }, + ) => { + let fn_type = match (a_fn, b_fn) { + (None, other) | (other, None) => other.clone(), + (Some(a_fn), Some(b_fn)) => Some(RefFnType { + read_ref_effect: a_fn.read_ref_effect || b_fn.read_ref_effect, + return_type: Box::new(join_ref_access_types( + &a_fn.return_type, + &b_fn.return_type, + )), + }), + }; + let value = match (a_value, b_value) { + (None, other) | (other, None) => other.clone(), + (Some(a_val), Some(b_val)) => { + Some(Box::new(join_ref_access_ref_types(a_val, b_val))) + } + }; + RefAccessRefType::Structure { value, fn_type } + } + } +} + +fn join_ref_access_types(a: &RefAccessType, b: &RefAccessType) -> RefAccessType { + match (a, b) { + (RefAccessType::None, other) | (other, RefAccessType::None) => other.clone(), + (RefAccessType::Guard { ref_id: a_id }, RefAccessType::Guard { ref_id: b_id }) => { + if a_id == b_id { + a.clone() + } else { + RefAccessType::None + } + } + (RefAccessType::Guard { .. }, RefAccessType::Nullable) + | (RefAccessType::Nullable, RefAccessType::Guard { .. }) => RefAccessType::None, + (RefAccessType::Guard { .. }, other) | (other, RefAccessType::Guard { .. 
}) => { + other.clone() + } + (RefAccessType::Nullable, other) | (other, RefAccessType::Nullable) => other.clone(), + _ => match (a.to_ref_type(), b.to_ref_type()) { + (Some(a_ref), Some(b_ref)) => { + RefAccessType::from_ref_type(&join_ref_access_ref_types(&a_ref, &b_ref)) + } + (Some(r), None) | (None, Some(r)) => RefAccessType::from_ref_type(&r), + _ => RefAccessType::None, + }, + } +} + +fn join_ref_access_types_many(types: &[RefAccessType]) -> RefAccessType { + types + .iter() + .fold(RefAccessType::None, |acc, t| join_ref_access_types(&acc, t)) +} + +// --- Env --- + +struct Env { + changed: bool, + data: HashMap, + temporaries: HashMap, +} + +impl Env { + fn new() -> Self { + Self { + changed: false, + data: HashMap::new(), + temporaries: HashMap::new(), + } + } + + fn define(&mut self, key: IdentifierId, value: Place) { + self.temporaries.insert(key, value); + } + + fn reset_changed(&mut self) { + self.changed = false; + } + + fn has_changed(&self) -> bool { + self.changed + } + + fn get(&self, key: IdentifierId) -> Option<&RefAccessType> { + let operand_id = self + .temporaries + .get(&key) + .map(|p| p.identifier) + .unwrap_or(key); + self.data.get(&operand_id) + } + + fn set(&mut self, key: IdentifierId, value: RefAccessType) { + let operand_id = self + .temporaries + .get(&key) + .map(|p| p.identifier) + .unwrap_or(key); + let current = self.data.get(&operand_id); + let widened_value = join_ref_access_types(&value, current.unwrap_or(&RefAccessType::None)); + if current.is_none() && widened_value == RefAccessType::None { + // No change needed + } else if current.map_or(true, |c| c != &widened_value) { + self.changed = true; + } + self.data.insert(operand_id, widened_value); + } +} + +// --- Helper functions --- + +fn ref_type_of_type(id: IdentifierId, identifiers: &[Identifier], types: &[Type]) -> RefAccessType { + let identifier = &identifiers[id.0 as usize]; + let ty = &types[identifier.type_.0 as usize]; + if react_compiler_hir::is_ref_value_type(ty) 
{ + RefAccessType::RefValue { + loc: None, + ref_id: None, + } + } else if react_compiler_hir::is_use_ref_type(ty) { + RefAccessType::Ref { + ref_id: next_ref_id(), + } + } else { + RefAccessType::None + } +} + +fn is_ref_type(id: IdentifierId, identifiers: &[Identifier], types: &[Type]) -> bool { + let identifier = &identifiers[id.0 as usize]; + react_compiler_hir::is_use_ref_type(&types[identifier.type_.0 as usize]) +} + +fn is_ref_value_type(id: IdentifierId, identifiers: &[Identifier], types: &[Type]) -> bool { + let identifier = &identifiers[id.0 as usize]; + react_compiler_hir::is_ref_value_type(&types[identifier.type_.0 as usize]) +} + +fn destructure(ty: &RefAccessType) -> RefAccessType { + match ty { + RefAccessType::Structure { + value: Some(inner), .. + } => destructure(&RefAccessType::from_ref_type(inner)), + other => other.clone(), + } +} + +// --- Validation helpers --- + +fn validate_no_direct_ref_value_access( + errors: &mut Vec, + operand: &Place, + env: &Env, +) { + if let Some(ty) = env.get(operand.identifier) { + let ty = destructure(ty); + if let RefAccessType::RefValue { loc, .. } = &ty { + errors.push( + CompilerDiagnostic::new( + ErrorCategory::Refs, + "Cannot access refs during render", + Some(ERROR_DESCRIPTION.to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: loc.or(operand.loc), + message: Some("Cannot access ref value during render".to_string()), + identifier_name: None, + }), + ); + } + } +} + +fn validate_no_ref_value_access(errors: &mut Vec, env: &Env, operand: &Place) { + if let Some(ty) = env.get(operand.identifier) { + let ty = destructure(ty); + match &ty { + RefAccessType::RefValue { loc, .. 
} => { + errors.push( + CompilerDiagnostic::new( + ErrorCategory::Refs, + "Cannot access refs during render", + Some(ERROR_DESCRIPTION.to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: loc.or(operand.loc), + message: Some("Cannot access ref value during render".to_string()), + identifier_name: None, + }), + ); + } + RefAccessType::Structure { + fn_type: Some(fn_type), + .. + } if fn_type.read_ref_effect => { + errors.push( + CompilerDiagnostic::new( + ErrorCategory::Refs, + "Cannot access refs during render", + Some(ERROR_DESCRIPTION.to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: operand.loc, + message: Some("Cannot access ref value during render".to_string()), + identifier_name: None, + }), + ); + } + _ => {} + } + } +} + +fn validate_no_ref_passed_to_function( + errors: &mut Vec, + env: &Env, + operand: &Place, + loc: Option, +) { + if let Some(ty) = env.get(operand.identifier) { + let ty = destructure(ty); + match &ty { + RefAccessType::Ref { .. } | RefAccessType::RefValue { .. } => { + let error_loc = if let RefAccessType::RefValue { loc: ref_loc, .. } = &ty { + ref_loc.or(loc) + } else { + loc + }; + errors.push( + CompilerDiagnostic::new( + ErrorCategory::Refs, + "Cannot access refs during render", + Some(ERROR_DESCRIPTION.to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: error_loc, + message: Some( + "Passing a ref to a function may read its value during render" + .to_string(), + ), + identifier_name: None, + }), + ); + } + RefAccessType::Structure { + fn_type: Some(fn_type), + .. 
+ } if fn_type.read_ref_effect => { + errors.push( + CompilerDiagnostic::new( + ErrorCategory::Refs, + "Cannot access refs during render", + Some(ERROR_DESCRIPTION.to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc, + message: Some( + "Passing a ref to a function may read its value during render" + .to_string(), + ), + identifier_name: None, + }), + ); + } + _ => {} + } + } +} + +fn validate_no_ref_update( + errors: &mut Vec, + env: &Env, + operand: &Place, + loc: Option, +) { + if let Some(ty) = env.get(operand.identifier) { + let ty = destructure(ty); + match &ty { + RefAccessType::Ref { .. } | RefAccessType::RefValue { .. } => { + let error_loc = if let RefAccessType::RefValue { loc: ref_loc, .. } = &ty { + ref_loc.or(loc) + } else { + loc + }; + errors.push( + CompilerDiagnostic::new( + ErrorCategory::Refs, + "Cannot access refs during render", + Some(ERROR_DESCRIPTION.to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: error_loc, + message: Some("Cannot update ref during render".to_string()), + identifier_name: None, + }), + ); + } + _ => {} + } + } +} + +fn guard_check(errors: &mut Vec, operand: &Place, env: &Env) { + if matches!( + env.get(operand.identifier), + Some(RefAccessType::Guard { .. 
}) + ) { + errors.push( + CompilerDiagnostic::new( + ErrorCategory::Refs, + "Cannot access refs during render", + Some(ERROR_DESCRIPTION.to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: operand.loc, + message: Some("Cannot access ref value during render".to_string()), + identifier_name: None, + }), + ); + } +} + +// --- Main entry point --- + +pub fn validate_no_ref_access_in_render(func: &HirFunction, env: &mut Environment) { + let mut ref_env = Env::new(); + collect_temporaries_sidemap(func, &mut ref_env, &env.identifiers, &env.types); + let mut errors: Vec = Vec::new(); + validate_no_ref_access_in_render_impl( + func, + &env.identifiers, + &env.types, + &env.functions, + &*env, + &mut ref_env, + &mut errors, + ); + for diagnostic in errors { + env.record_diagnostic(diagnostic); + } +} + +fn collect_temporaries_sidemap( + func: &HirFunction, + env: &mut Env, + identifiers: &[Identifier], + types: &[Type], +) { + for (_, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::LoadLocal { place, .. } => { + let temp = env + .temporaries + .get(&place.identifier) + .cloned() + .unwrap_or_else(|| place.clone()); + env.define(instr.lvalue.identifier, temp); + } + InstructionValue::StoreLocal { lvalue, value, .. } => { + let temp = env + .temporaries + .get(&value.identifier) + .cloned() + .unwrap_or_else(|| value.clone()); + env.define(instr.lvalue.identifier, temp.clone()); + env.define(lvalue.place.identifier, temp); + } + InstructionValue::PropertyLoad { + object, property, .. 
+ } => { + if is_ref_type(object.identifier, identifiers, types) + && *property == PropertyLiteral::String("current".to_string()) + { + continue; + } + let temp = env + .temporaries + .get(&object.identifier) + .cloned() + .unwrap_or_else(|| object.clone()); + env.define(instr.lvalue.identifier, temp); + } + _ => {} + } + } + } +} + +fn validate_no_ref_access_in_render_impl( + func: &HirFunction, + identifiers: &[Identifier], + types: &[Type], + functions: &[HirFunction], + env: &Environment, + ref_env: &mut Env, + errors: &mut Vec, +) -> RefAccessType { + let mut return_values: Vec = Vec::new(); + + // Process params + for param in &func.params { + let place = match param { + react_compiler_hir::ParamPattern::Place(p) => p, + react_compiler_hir::ParamPattern::Spread(s) => &s.place, + }; + ref_env.set( + place.identifier, + ref_type_of_type(place.identifier, identifiers, types), + ); + } + + // Collect identifiers that are interpolated as JSX children + let mut interpolated_as_jsx: HashSet = HashSet::new(); + for (_, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::JsxExpression { + children: Some(children), + .. + } => { + for child in children { + interpolated_as_jsx.insert(child.identifier); + } + } + InstructionValue::JsxFragment { children, .. 
} => { + for child in children { + interpolated_as_jsx.insert(child.identifier); + } + } + _ => {} + } + } + } + + // Fixed-point iteration (up to 10 iterations) + for iteration in 0..10 { + if iteration > 0 && !ref_env.has_changed() { + break; + } + ref_env.reset_changed(); + return_values.clear(); + let mut safe_blocks: Vec<(BlockId, RefId)> = Vec::new(); + + for (_, block) in &func.body.blocks { + safe_blocks.retain(|(block_id, _)| *block_id != block.id); + + // Process phis + for phi in &block.phis { + let phi_types: Vec = phi + .operands + .values() + .map(|operand| { + ref_env + .get(operand.identifier) + .cloned() + .unwrap_or(RefAccessType::None) + }) + .collect(); + ref_env.set(phi.place.identifier, join_ref_access_types_many(&phi_types)); + } + + // Process instructions + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::JsxExpression { .. } + | InstructionValue::JsxFragment { .. } => { + for operand in &canonical_each_instruction_value_operand(&instr.value, env) + { + validate_no_direct_ref_value_access(errors, operand, ref_env); + } + } + InstructionValue::ComputedLoad { + object, property, .. + } => { + validate_no_direct_ref_value_access(errors, property, ref_env); + let obj_type = ref_env.get(object.identifier).cloned(); + let lookup_type = match &obj_type { + Some(RefAccessType::Structure { + value: Some(value), .. + }) => Some(RefAccessType::from_ref_type(value)), + Some(RefAccessType::Ref { ref_id }) => Some(RefAccessType::RefValue { + loc: instr.loc, + ref_id: Some(*ref_id), + }), + _ => None, + }; + ref_env.set( + instr.lvalue.identifier, + lookup_type.unwrap_or_else(|| { + ref_type_of_type(instr.lvalue.identifier, identifiers, types) + }), + ); + } + InstructionValue::PropertyLoad { object, .. } => { + let obj_type = ref_env.get(object.identifier).cloned(); + let lookup_type = match &obj_type { + Some(RefAccessType::Structure { + value: Some(value), .. 
+ }) => Some(RefAccessType::from_ref_type(value)), + Some(RefAccessType::Ref { ref_id }) => Some(RefAccessType::RefValue { + loc: instr.loc, + ref_id: Some(*ref_id), + }), + _ => None, + }; + ref_env.set( + instr.lvalue.identifier, + lookup_type.unwrap_or_else(|| { + ref_type_of_type(instr.lvalue.identifier, identifiers, types) + }), + ); + } + InstructionValue::TypeCastExpression { value, .. } => { + ref_env.set( + instr.lvalue.identifier, + ref_env.get(value.identifier).cloned().unwrap_or_else(|| { + ref_type_of_type(instr.lvalue.identifier, identifiers, types) + }), + ); + } + InstructionValue::LoadContext { place, .. } + | InstructionValue::LoadLocal { place, .. } => { + ref_env.set( + instr.lvalue.identifier, + ref_env.get(place.identifier).cloned().unwrap_or_else(|| { + ref_type_of_type(instr.lvalue.identifier, identifiers, types) + }), + ); + } + InstructionValue::StoreContext { lvalue, value, .. } + | InstructionValue::StoreLocal { lvalue, value, .. } => { + ref_env.set( + lvalue.place.identifier, + ref_env.get(value.identifier).cloned().unwrap_or_else(|| { + ref_type_of_type(lvalue.place.identifier, identifiers, types) + }), + ); + ref_env.set( + instr.lvalue.identifier, + ref_env.get(value.identifier).cloned().unwrap_or_else(|| { + ref_type_of_type(instr.lvalue.identifier, identifiers, types) + }), + ); + } + InstructionValue::Destructure { value, lvalue, .. } => { + let obj_type = ref_env.get(value.identifier).cloned(); + let lookup_type = match &obj_type { + Some(RefAccessType::Structure { + value: Some(value), .. 
+ }) => Some(RefAccessType::from_ref_type(value)), + _ => None, + }; + ref_env.set( + instr.lvalue.identifier, + lookup_type.clone().unwrap_or_else(|| { + ref_type_of_type(instr.lvalue.identifier, identifiers, types) + }), + ); + for pattern_place in each_pattern_operand(&lvalue.pattern) { + ref_env.set( + pattern_place.identifier, + lookup_type.clone().unwrap_or_else(|| { + ref_type_of_type(pattern_place.identifier, identifiers, types) + }), + ); + } + } + InstructionValue::ObjectMethod { lowered_func, .. } + | InstructionValue::FunctionExpression { lowered_func, .. } => { + let inner = &functions[lowered_func.func.0 as usize]; + let mut inner_errors: Vec = Vec::new(); + let result = validate_no_ref_access_in_render_impl( + inner, + identifiers, + types, + functions, + env, + ref_env, + &mut inner_errors, + ); + let (return_type, read_ref_effect) = if inner_errors.is_empty() { + (result, false) + } else { + (RefAccessType::None, true) + }; + ref_env.set( + instr.lvalue.identifier, + RefAccessType::Structure { + value: None, + fn_type: Some(RefFnType { + read_ref_effect, + return_type: Box::new(return_type), + }), + }, + ); + } + InstructionValue::MethodCall { property, .. } + | InstructionValue::CallExpression { + callee: property, .. + } => { + let callee = property; + let mut return_type = RefAccessType::None; + let fn_type = ref_env.get(callee.identifier).cloned(); + let mut did_error = false; + + if let Some(RefAccessType::Structure { + fn_type: Some(fn_ty), + .. 
+ }) = &fn_type + { + return_type = *fn_ty.return_type.clone(); + if fn_ty.read_ref_effect { + did_error = true; + errors.push( + CompilerDiagnostic::new( + ErrorCategory::Refs, + "Cannot access refs during render", + Some(ERROR_DESCRIPTION.to_string()), + ) + .with_detail( + CompilerDiagnosticDetail::Error { + loc: callee.loc, + message: Some( + "This function accesses a ref value".to_string(), + ), + identifier_name: None, + }, + ), + ); + } + } + + /* + * If we already reported an error on this instruction, don't report + * duplicate errors + */ + if !did_error { + let is_ref_lvalue = + is_ref_type(instr.lvalue.identifier, identifiers, types); + let callee_identifier = &identifiers[callee.identifier.0 as usize]; + let callee_type = &types[callee_identifier.type_.0 as usize]; + let hook_kind = env.get_hook_kind_for_type(callee_type).ok().flatten(); + + if is_ref_lvalue + || (hook_kind.is_some() + && !matches!(hook_kind, Some(&HookKind::UseState)) + && !matches!(hook_kind, Some(&HookKind::UseReducer))) + { + for operand in + &canonical_each_instruction_value_operand(&instr.value, env) + { + /* + * Allow passing refs or ref-accessing functions when: + * 1. lvalue is a ref (mergeRefs pattern) + * 2. calling hooks (independently validated) + */ + validate_no_direct_ref_value_access(errors, operand, ref_env); + } + } else if interpolated_as_jsx.contains(&instr.lvalue.identifier) { + for operand in + &canonical_each_instruction_value_operand(&instr.value, env) + { + /* + * Special case: the lvalue is passed as a jsx child + */ + validate_no_ref_value_access(errors, ref_env, operand); + } + } else if hook_kind.is_none() { + if let Some(ref effects) = instr.effects { + /* + * For non-hook functions with known aliasing effects, + * use the effects to determine what validation to apply. + * Track visited id:kind pairs to avoid duplicate errors. 
+ */ + let mut visited_effects: HashSet = HashSet::new(); + for effect in effects { + let (place, validation) = match effect { + AliasingEffect::Freeze { value, .. } => { + (Some(value), "direct-ref") + } + AliasingEffect::Mutate { value, .. } + | AliasingEffect::MutateTransitive { value, .. } + | AliasingEffect::MutateConditionally { + value, .. + } + | AliasingEffect::MutateTransitiveConditionally { + value, + .. + } => (Some(value), "ref-passed"), + AliasingEffect::Render { place, .. } => { + (Some(place), "ref-passed") + } + AliasingEffect::Capture { from, .. } + | AliasingEffect::Alias { from, .. } + | AliasingEffect::MaybeAlias { from, .. } + | AliasingEffect::Assign { from, .. } + | AliasingEffect::CreateFrom { from, .. } => { + (Some(from), "ref-passed") + } + AliasingEffect::ImmutableCapture { from, .. } => { + /* + * ImmutableCapture: check whether the same + * operand also has a Freeze effect to + * distinguish known signatures from + * downgraded defaults. + */ + let is_frozen = effects.iter().any(|e| { + matches!( + e, + AliasingEffect::Freeze { value, .. 
} + if value.identifier == from.identifier + ) + }); + ( + Some(from), + if is_frozen { + "direct-ref" + } else { + "ref-passed" + }, + ) + } + _ => (None, "none"), + }; + if let Some(place) = place { + if validation != "none" { + let key = format!( + "{}:{}", + place.identifier.0, validation + ); + if visited_effects.insert(key) { + if validation == "direct-ref" { + validate_no_direct_ref_value_access( + errors, place, ref_env, + ); + } else { + validate_no_ref_passed_to_function( + errors, ref_env, place, place.loc, + ); + } + } + } + } + } + } else { + for operand in + &canonical_each_instruction_value_operand(&instr.value, env) + { + validate_no_ref_passed_to_function( + errors, + ref_env, + operand, + operand.loc, + ); + } + } + } else { + for operand in + &canonical_each_instruction_value_operand(&instr.value, env) + { + validate_no_ref_passed_to_function( + errors, + ref_env, + operand, + operand.loc, + ); + } + } + } + ref_env.set(instr.lvalue.identifier, return_type); + } + InstructionValue::ObjectExpression { .. } + | InstructionValue::ArrayExpression { .. } => { + let operands = canonical_each_instruction_value_operand(&instr.value, env); + let mut types_vec: Vec = Vec::new(); + for operand in &operands { + validate_no_direct_ref_value_access(errors, operand, ref_env); + types_vec.push( + ref_env + .get(operand.identifier) + .cloned() + .unwrap_or(RefAccessType::None), + ); + } + let value = join_ref_access_types_many(&types_vec); + match &value { + RefAccessType::None + | RefAccessType::Guard { .. } + | RefAccessType::Nullable => { + ref_env.set(instr.lvalue.identifier, RefAccessType::None); + } + _ => { + ref_env.set( + instr.lvalue.identifier, + RefAccessType::Structure { + value: value.to_ref_type().map(Box::new), + fn_type: None, + }, + ); + } + } + } + InstructionValue::PropertyDelete { object, .. } + | InstructionValue::PropertyStore { object, .. } + | InstructionValue::ComputedDelete { object, .. } + | InstructionValue::ComputedStore { object, .. 
} => { + let target = ref_env.get(object.identifier).cloned(); + let mut found_safe = false; + if matches!(&instr.value, InstructionValue::PropertyStore { .. }) { + if let Some(RefAccessType::Ref { ref_id }) = &target { + if let Some(pos) = safe_blocks.iter().position(|(_, r)| r == ref_id) + { + safe_blocks.remove(pos); + found_safe = true; + } + } + } + if !found_safe { + validate_no_ref_update(errors, ref_env, object, instr.loc); + } + match &instr.value { + InstructionValue::ComputedDelete { property, .. } + | InstructionValue::ComputedStore { property, .. } => { + validate_no_ref_value_access(errors, ref_env, property); + } + _ => {} + } + match &instr.value { + InstructionValue::ComputedStore { value, .. } + | InstructionValue::PropertyStore { value, .. } => { + validate_no_direct_ref_value_access(errors, value, ref_env); + let value_type = ref_env.get(value.identifier).cloned(); + if let Some(RefAccessType::Structure { .. }) = &value_type { + let mut object_type = value_type.unwrap(); + if let Some(t) = &target { + object_type = join_ref_access_types(&object_type, t); + } + ref_env.set(object.identifier, object_type); + } + } + _ => {} + } + } + InstructionValue::StartMemoize { .. } + | InstructionValue::FinishMemoize { .. } => {} + InstructionValue::LoadGlobal { binding, .. } => { + if binding.name() == "undefined" { + ref_env.set(instr.lvalue.identifier, RefAccessType::Nullable); + } + } + InstructionValue::Primitive { value, .. } => { + if matches!(value, PrimitiveValue::Null | PrimitiveValue::Undefined) { + ref_env.set(instr.lvalue.identifier, RefAccessType::Nullable); + } + } + InstructionValue::UnaryExpression { + operator, value, .. + } => { + if *operator == UnaryOperator::Not { + if let Some(RefAccessType::RefValue { + ref_id: Some(ref_id), + .. 
+ }) = ref_env.get(value.identifier).cloned().as_ref() + { + /* + * Record an error suggesting the `if (ref.current == null)` + * pattern, but also record the + * lvalue as a guard so that we don't emit a + * second error for the write to the ref + */ + ref_env.set( + instr.lvalue.identifier, + RefAccessType::Guard { ref_id: *ref_id }, + ); + errors.push( + CompilerDiagnostic::new( + ErrorCategory::Refs, + "Cannot access refs during render", + Some(ERROR_DESCRIPTION.to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: value.loc, + message: Some( + "Cannot access ref value during render".to_string(), + ), + identifier_name: None, + }) + .with_detail( + CompilerDiagnosticDetail::Hint { + message: "To initialize a ref only once, check that \ + the ref is null with the pattern `if \ + (ref.current == null) { ref.current = ... }`" + .to_string(), + }, + ), + ); + } else { + validate_no_ref_value_access(errors, ref_env, value); + } + } else { + validate_no_ref_value_access(errors, ref_env, value); + } + } + InstructionValue::BinaryExpression { left, right, .. } => { + let left_type = ref_env.get(left.identifier).cloned(); + let right_type = ref_env.get(right.identifier).cloned(); + let mut nullish = false; + let mut found_ref_id: Option = None; + + if let Some(RefAccessType::RefValue { + ref_id: Some(id), .. + }) = &left_type + { + found_ref_id = Some(*id); + } else if let Some(RefAccessType::RefValue { + ref_id: Some(id), .. 
+ }) = &right_type + { + found_ref_id = Some(*id); + } + + if matches!(&left_type, Some(RefAccessType::Nullable)) { + nullish = true; + } else if matches!(&right_type, Some(RefAccessType::Nullable)) { + nullish = true; + } + + if let Some(ref_id) = found_ref_id { + if nullish { + ref_env + .set(instr.lvalue.identifier, RefAccessType::Guard { ref_id }); + } else { + validate_no_ref_value_access(errors, ref_env, left); + validate_no_ref_value_access(errors, ref_env, right); + } + } else { + validate_no_ref_value_access(errors, ref_env, left); + validate_no_ref_value_access(errors, ref_env, right); + } + } + _ => { + for operand in &canonical_each_instruction_value_operand(&instr.value, env) + { + validate_no_ref_value_access(errors, ref_env, operand); + } + } + } + + // Guard values are derived from ref.current, so they can only be used + // in if statement targets + for operand in &canonical_each_instruction_value_operand(&instr.value, env) { + guard_check(errors, operand, ref_env); + } + + if is_ref_type(instr.lvalue.identifier, identifiers, types) + && !matches!( + ref_env.get(instr.lvalue.identifier), + Some(RefAccessType::Ref { .. }) + ) + { + let existing = ref_env + .get(instr.lvalue.identifier) + .cloned() + .unwrap_or(RefAccessType::None); + ref_env.set( + instr.lvalue.identifier, + join_ref_access_types( + &existing, + &RefAccessType::Ref { + ref_id: next_ref_id(), + }, + ), + ); + } + if is_ref_value_type(instr.lvalue.identifier, identifiers, types) + && !matches!( + ref_env.get(instr.lvalue.identifier), + Some(RefAccessType::RefValue { .. }) + ) + { + let existing = ref_env + .get(instr.lvalue.identifier) + .cloned() + .unwrap_or(RefAccessType::None); + ref_env.set( + instr.lvalue.identifier, + join_ref_access_types( + &existing, + &RefAccessType::RefValue { + loc: instr.loc, + ref_id: None, + }, + ), + ); + } + } + + // Check if terminal is an `if` — push safe block for guard + if let Terminal::If { + test, fallthrough, .. 
+ } = &block.terminal + { + if let Some(RefAccessType::Guard { ref_id }) = ref_env.get(test.identifier) { + if !safe_blocks.iter().any(|(_, r)| r == ref_id) { + safe_blocks.push((*fallthrough, *ref_id)); + } + } + } + + // Process terminal operands + for operand in &each_terminal_operand(&block.terminal) { + if !matches!(&block.terminal, Terminal::Return { .. }) { + validate_no_ref_value_access(errors, ref_env, operand); + if !matches!(&block.terminal, Terminal::If { .. }) { + guard_check(errors, operand, ref_env); + } + } else { + // Allow functions containing refs to be returned, but not direct ref values + validate_no_direct_ref_value_access(errors, operand, ref_env); + guard_check(errors, operand, ref_env); + if let Some(ty) = ref_env.get(operand.identifier) { + return_values.push(ty.clone()); + } + } + } + } + + if !errors.is_empty() { + return RefAccessType::None; + } + } + + // Note: the TS asserts convergence here, but the Rust fixpoint loop + // may not converge within MAX_ITERATIONS for some inputs yet. + + join_ref_access_types_many(&return_values) +} diff --git a/crates/react_compiler_validation/src/validate_no_set_state_in_effects.rs b/crates/react_compiler_validation/src/validate_no_set_state_in_effects.rs new file mode 100644 index 000000000000..9da3c247961d --- /dev/null +++ b/crates/react_compiler_validation/src/validate_no_set_state_in_effects.rs @@ -0,0 +1,581 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Validates against calling setState in the body of an effect (useEffect and +//! friends), while allowing calling setState in callbacks scheduled by the +//! effect. +//! +//! Calling setState during execution of a useEffect triggers a re-render, which +//! is often bad for performance and frequently has more efficient and +//! straightforward alternatives. 
See https://react.dev/learn/you-might-not-need-an-effect for examples. +//! +//! Port of ValidateNoSetStateInEffects.ts. + +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, ErrorCategory, +}; +use react_compiler_hir::{ + dominator::{compute_post_dominator_tree, post_dominator_frontier}, + environment::Environment, + is_ref_value_type, is_set_state_type, is_use_effect_event_type, is_use_effect_hook_type, + is_use_insertion_effect_hook_type, is_use_layout_effect_hook_type, is_use_ref_type, visitors, + BlockId, HirFunction, Identifier, IdentifierId, IdentifierName, InstructionValue, + PlaceOrSpread, PropertyLiteral, SourceLocation, Terminal, Type, +}; + +pub fn validate_no_set_state_in_effects( + func: &HirFunction, + env: &Environment, +) -> Result { + let identifiers = &env.identifiers; + let types = &env.types; + let functions = &env.functions; + let enable_verbose = env.config.enable_verbose_no_set_state_in_effect; + let enable_allow_set_state_from_refs = env.config.enable_allow_set_state_from_refs_in_effects; + + // Map from IdentifierId to the Place where the setState originated + let mut set_state_functions: HashMap = HashMap::new(); + let mut errors = CompilerError::new(); + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::LoadLocal { place, .. } => { + if set_state_functions.contains_key(&place.identifier) { + let info = set_state_functions[&place.identifier].clone(); + set_state_functions.insert(instr.lvalue.identifier, info); + } + } + InstructionValue::StoreLocal { lvalue, value, .. 
} => { + if set_state_functions.contains_key(&value.identifier) { + let info = set_state_functions[&value.identifier].clone(); + set_state_functions.insert(lvalue.place.identifier, info.clone()); + set_state_functions.insert(instr.lvalue.identifier, info); + } + } + InstructionValue::FunctionExpression { lowered_func, .. } => { + // Check if any context capture references a setState + let inner_func = &functions[lowered_func.func.0 as usize]; + let has_set_state_operand = inner_func.context.iter().any(|ctx_place| { + is_set_state_type_by_id(ctx_place.identifier, identifiers, types) + || set_state_functions.contains_key(&ctx_place.identifier) + }); + + if has_set_state_operand { + let callee = get_set_state_call( + inner_func, + &mut set_state_functions, + identifiers, + types, + functions, + enable_allow_set_state_from_refs, + env.next_block_id_counter, + env.code.as_deref(), + )?; + if let Some(info) = callee { + set_state_functions.insert(instr.lvalue.identifier, info); + } + } + } + InstructionValue::MethodCall { property, args, .. } => { + let prop_type = + &types[identifiers[property.identifier.0 as usize].type_.0 as usize]; + if is_use_effect_event_type(prop_type) { + if let Some(first_arg) = args.first() { + if let PlaceOrSpread::Place(arg_place) = first_arg { + if let Some(info) = set_state_functions.get(&arg_place.identifier) { + set_state_functions + .insert(instr.lvalue.identifier, info.clone()); + } + } + } + } else if is_use_effect_hook_type(prop_type) + || is_use_layout_effect_hook_type(prop_type) + || is_use_insertion_effect_hook_type(prop_type) + { + if let Some(first_arg) = args.first() { + if let PlaceOrSpread::Place(arg_place) = first_arg { + if let Some(info) = set_state_functions.get(&arg_place.identifier) { + push_error(&mut errors, info, enable_verbose); + } + } + } + } + } + InstructionValue::CallExpression { callee, args, .. 
} => { + let callee_type = + &types[identifiers[callee.identifier.0 as usize].type_.0 as usize]; + if is_use_effect_event_type(callee_type) { + if let Some(first_arg) = args.first() { + if let PlaceOrSpread::Place(arg_place) = first_arg { + if let Some(info) = set_state_functions.get(&arg_place.identifier) { + set_state_functions + .insert(instr.lvalue.identifier, info.clone()); + } + } + } + } else if is_use_effect_hook_type(callee_type) + || is_use_layout_effect_hook_type(callee_type) + || is_use_insertion_effect_hook_type(callee_type) + { + if let Some(first_arg) = args.first() { + if let PlaceOrSpread::Place(arg_place) = first_arg { + if let Some(info) = set_state_functions.get(&arg_place.identifier) { + push_error(&mut errors, info, enable_verbose); + } + } + } + } + } + _ => {} + } + } + } + + Ok(errors) +} + +#[derive(Debug, Clone)] +struct SetStateInfo { + loc: Option, + identifier_name: Option, +} + +/// Get the user-visible name for an identifier, matching Babel's +/// loc.identifierName behavior. First checks the identifier's own name, +/// then falls back to extracting the name from the source code at the +/// given source location (the callee's loc). This handles SSA identifiers +/// whose names were lost during compiler passes. +fn get_identifier_name_with_loc( + id: IdentifierId, + identifiers: &[Identifier], + loc: &Option, + source_code: Option<&str>, +) -> Option { + let ident = &identifiers[id.0 as usize]; + if let Some(IdentifierName::Named(name)) = &ident.name { + return Some(name.clone()); + } + // Fall back to extracting from source code + if let (Some(loc), Some(code)) = (loc, source_code) { + let start_idx = loc.start.index? as usize; + let end_idx = loc.end.index? 
as usize; + if start_idx < code.len() && end_idx <= code.len() && start_idx < end_idx { + let slice = &code[start_idx..end_idx]; + if !slice.is_empty() + && slice + .chars() + .all(|c| c.is_alphanumeric() || c == '_' || c == '$') + { + return Some(slice.to_string()); + } + } + } + None +} + +fn is_set_state_type_by_id( + identifier_id: IdentifierId, + identifiers: &[Identifier], + types: &[Type], +) -> bool { + let ident = &identifiers[identifier_id.0 as usize]; + let ty = &types[ident.type_.0 as usize]; + is_set_state_type(ty) +} + +fn push_error(errors: &mut CompilerError, info: &SetStateInfo, enable_verbose: bool) { + if enable_verbose { + errors.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::EffectSetState, + "Calling setState synchronously within an effect can trigger cascading renders", + Some( + "Effects are intended to synchronize state between React and external systems. \ + Calling setState synchronously causes cascading renders that hurt performance.\n\n\ + This pattern may indicate one of several issues:\n\n\ + **1. Non-local derived data**: If the value being set could be computed from props/state \ + but requires data from a parent component, consider restructuring state ownership so the \ + derivation can happen during render in the component that owns the relevant state.\n\n\ + **2. Derived event pattern**: If you're detecting when a prop changes (e.g., `isPlaying` \ + transitioning from false to true), this often indicates the parent should provide an event \ + callback (like `onPlay`) instead of just the current state. Request access to the original event.\n\n\ + **3. 
Force update / external sync**: If you're forcing a re-render to sync with an external \ + data source (mutable values outside React), use `useSyncExternalStore` to properly subscribe \ + to external state changes.\n\n\ + See: https://react.dev/learn/you-might-not-need-an-effect".to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: info.loc, + message: Some( + "Avoid calling setState() directly within an effect".to_string(), + ), + identifier_name: info.identifier_name.clone(), + }), + ); + } else { + errors.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::EffectSetState, + "Calling setState synchronously within an effect can trigger cascading renders", + Some( + "Effects are intended to synchronize state between React and external systems such as manually updating the DOM, state management libraries, or other platform APIs. \ + In general, the body of an effect should do one or both of the following:\n\ + * Update external systems with the latest state from React.\n\ + * Subscribe for updates from some external system, calling setState in a callback function when external state changes.\n\n\ + Calling setState synchronously within an effect body causes cascading renders that can hurt performance, and is not recommended. \ + (https://react.dev/learn/you-might-not-need-an-effect)".to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: info.loc, + message: Some( + "Avoid calling setState() directly within an effect".to_string(), + ), + identifier_name: info.identifier_name.clone(), + }), + ); + } +} + +/// Recursively collect all Place identifiers from a destructure pattern. 
+fn collect_destructure_places( + pattern: &react_compiler_hir::Pattern, + ref_derived_values: &mut HashSet, +) { + match pattern { + react_compiler_hir::Pattern::Array(arr) => { + for item in &arr.items { + match item { + react_compiler_hir::ArrayPatternElement::Place(p) => { + ref_derived_values.insert(p.identifier); + } + react_compiler_hir::ArrayPatternElement::Spread(s) => { + ref_derived_values.insert(s.place.identifier); + } + react_compiler_hir::ArrayPatternElement::Hole => {} + } + } + } + react_compiler_hir::Pattern::Object(obj) => { + for prop in &obj.properties { + match prop { + react_compiler_hir::ObjectPropertyOrSpread::Property(p) => { + ref_derived_values.insert(p.place.identifier); + } + react_compiler_hir::ObjectPropertyOrSpread::Spread(s) => { + ref_derived_values.insert(s.place.identifier); + } + } + } + } + } +} + +fn is_derived_from_ref( + id: IdentifierId, + ref_derived_values: &HashSet, + identifiers: &[Identifier], + types: &[Type], +) -> bool { + if ref_derived_values.contains(&id) { + return true; + } + let ident = &identifiers[id.0 as usize]; + let ty = &types[ident.type_.0 as usize]; + is_use_ref_type(ty) || is_ref_value_type(ty) +} + +/// Collects all operand IdentifierIds from an instruction value. +/// Uses the canonical `each_instruction_value_operand_with_functions` from +/// visitors. +fn collect_operands(value: &InstructionValue, functions: &[HirFunction]) -> Vec { + visitors::each_instruction_value_operand_with_functions(value, functions) + .into_iter() + .map(|p| p.identifier) + .collect() +} + +/// Creates a function that checks whether a block is "control-dominated" by +/// a ref-derived condition. A block is ref-controlled if its post-dominator +/// frontier contains a block whose terminal tests a ref-derived value. 
+fn create_ref_controlled_block_checker( + func: &HirFunction, + next_block_id_counter: u32, + ref_derived_values: &HashSet, + identifiers: &[Identifier], + types: &[Type], +) -> Result, CompilerDiagnostic> { + let post_dominators = compute_post_dominator_tree(func, next_block_id_counter, false)?; + let mut cache: HashMap = HashMap::new(); + + for (block_id, _block) in &func.body.blocks { + let frontier = post_dominator_frontier(func, &post_dominators, *block_id); + let mut is_controlled = false; + + for frontier_block_id in &frontier { + let control_block = &func.body.blocks[frontier_block_id]; + match &control_block.terminal { + Terminal::If { test, .. } | Terminal::Branch { test, .. } => { + if is_derived_from_ref(test.identifier, ref_derived_values, identifiers, types) + { + is_controlled = true; + break; + } + } + Terminal::Switch { test, cases, .. } => { + if is_derived_from_ref(test.identifier, ref_derived_values, identifiers, types) + { + is_controlled = true; + break; + } + for case in cases { + if let Some(case_test) = &case.test { + if is_derived_from_ref( + case_test.identifier, + ref_derived_values, + identifiers, + types, + ) { + is_controlled = true; + break; + } + } + } + if is_controlled { + break; + } + } + _ => {} + } + } + + cache.insert(*block_id, is_controlled); + } + + Ok(cache) +} + +/// Checks inner function body for direct setState calls. Returns the callee +/// Place info if a setState call is found in the function body. +/// Tracks ref-derived values to allow setState when the value being set comes +/// from a ref. 
+fn get_set_state_call( + func: &HirFunction, + set_state_functions: &mut HashMap, + identifiers: &[Identifier], + types: &[Type], + functions: &[HirFunction], + enable_allow_set_state_from_refs: bool, + next_block_id_counter: u32, + source_code: Option<&str>, +) -> Result, CompilerDiagnostic> { + let mut ref_derived_values: HashSet = HashSet::new(); + + // First pass: collect ref-derived values (needed before building control + // dominator checker) We do a pre-pass to seed ref_derived_values so the + // control dominator checker has them. + if enable_allow_set_state_from_refs { + for (_block_id, block) in &func.body.blocks { + for phi in &block.phis { + let is_phi_derived = phi.operands.values().any(|operand| { + is_derived_from_ref(operand.identifier, &ref_derived_values, identifiers, types) + }); + if is_phi_derived { + ref_derived_values.insert(phi.place.identifier); + } + } + + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + + let operands = collect_operands(&instr.value, functions); + let has_ref_operand = operands.iter().any(|op_id| { + is_derived_from_ref(*op_id, &ref_derived_values, identifiers, types) + }); + + if has_ref_operand { + ref_derived_values.insert(instr.lvalue.identifier); + if let InstructionValue::Destructure { lvalue, .. } = &instr.value { + collect_destructure_places(&lvalue.pattern, &mut ref_derived_values); + } + if let InstructionValue::StoreLocal { lvalue, .. } = &instr.value { + ref_derived_values.insert(lvalue.place.identifier); + } + } + + if let InstructionValue::PropertyLoad { + object, property, .. 
+ } = &instr.value + { + if *property == PropertyLiteral::String("current".to_string()) { + let obj_ident = &identifiers[object.identifier.0 as usize]; + let obj_ty = &types[obj_ident.type_.0 as usize]; + if is_use_ref_type(obj_ty) || is_ref_value_type(obj_ty) { + ref_derived_values.insert(instr.lvalue.identifier); + } + } + } + } + } + } + + // Build control dominator checker after collecting ref-derived values + let ref_controlled_blocks = if enable_allow_set_state_from_refs { + create_ref_controlled_block_checker( + func, + next_block_id_counter, + &ref_derived_values, + identifiers, + types, + )? + } else { + HashMap::new() + }; + + let is_ref_controlled_block = |block_id: BlockId| -> bool { + ref_controlled_blocks + .get(&block_id) + .copied() + .unwrap_or(false) + }; + + // Reset and redo: second pass with control dominator info available + ref_derived_values.clear(); + + for (_block_id, block) in &func.body.blocks { + // Track ref-derived values through phis + if enable_allow_set_state_from_refs { + for phi in &block.phis { + if is_derived_from_ref( + phi.place.identifier, + &ref_derived_values, + identifiers, + types, + ) { + continue; + } + let is_phi_derived = phi.operands.values().any(|operand| { + is_derived_from_ref(operand.identifier, &ref_derived_values, identifiers, types) + }); + if is_phi_derived { + ref_derived_values.insert(phi.place.identifier); + } else { + // Fallback: check if any predecessor block is ref-controlled + let mut found = false; + for pred in phi.operands.keys() { + if is_ref_controlled_block(*pred) { + ref_derived_values.insert(phi.place.identifier); + found = true; + break; + } + } + if found { + continue; + } + } + } + } + + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + + // Track ref-derived values through instructions + if enable_allow_set_state_from_refs { + let operands = collect_operands(&instr.value, functions); + let has_ref_operand = operands.iter().any(|op_id| { + 
is_derived_from_ref(*op_id, &ref_derived_values, identifiers, types) + }); + + if has_ref_operand { + ref_derived_values.insert(instr.lvalue.identifier); + // For Destructure, also mark all pattern places as ref-derived + if let InstructionValue::Destructure { lvalue, .. } = &instr.value { + collect_destructure_places(&lvalue.pattern, &mut ref_derived_values); + } + // For StoreLocal, propagate to the local variable + if let InstructionValue::StoreLocal { lvalue, .. } = &instr.value { + ref_derived_values.insert(lvalue.place.identifier); + } + } + + // Special case: PropertyLoad of .current on ref/refValue + if let InstructionValue::PropertyLoad { + object, property, .. + } = &instr.value + { + if *property == PropertyLiteral::String("current".to_string()) { + let obj_ident = &identifiers[object.identifier.0 as usize]; + let obj_ty = &types[obj_ident.type_.0 as usize]; + if is_use_ref_type(obj_ty) || is_ref_value_type(obj_ty) { + ref_derived_values.insert(instr.lvalue.identifier); + } + } + } + } + + match &instr.value { + InstructionValue::LoadLocal { place, .. } => { + if set_state_functions.contains_key(&place.identifier) { + let info = set_state_functions[&place.identifier].clone(); + set_state_functions.insert(instr.lvalue.identifier, info); + } + } + InstructionValue::StoreLocal { lvalue, value, .. } => { + if set_state_functions.contains_key(&value.identifier) { + let info = set_state_functions[&value.identifier].clone(); + set_state_functions.insert(lvalue.place.identifier, info.clone()); + set_state_functions.insert(instr.lvalue.identifier, info); + } + } + InstructionValue::CallExpression { callee, args, .. 
} => { + if is_set_state_type_by_id(callee.identifier, identifiers, types) + || set_state_functions.contains_key(&callee.identifier) + { + if enable_allow_set_state_from_refs { + // Check if the first argument is ref-derived + if let Some(first_arg) = args.first() { + if let PlaceOrSpread::Place(arg_place) = first_arg { + if is_derived_from_ref( + arg_place.identifier, + &ref_derived_values, + identifiers, + types, + ) { + // Allow setState when value is derived from ref + return Ok(None); + } + } + } + // Check if the current block is controlled by a ref-derived condition + if is_ref_controlled_block(block.id) { + continue; + } + } + // Get the user-visible identifier name, matching Babel's + // loc.identifierName behavior. Uses declaration_id to find + // the original named identifier when SSA creates unnamed copies. + let callee_name = get_identifier_name_with_loc( + callee.identifier, + identifiers, + &callee.loc, + source_code, + ); + return Ok(Some(SetStateInfo { + loc: callee.loc, + identifier_name: callee_name, + })); + } + } + _ => {} + } + } + } + Ok(None) +} diff --git a/crates/react_compiler_validation/src/validate_no_set_state_in_render.rs b/crates/react_compiler_validation/src/validate_no_set_state_in_render.rs new file mode 100644 index 000000000000..b99deff50613 --- /dev/null +++ b/crates/react_compiler_validation/src/validate_no_set_state_in_render.rs @@ -0,0 +1,190 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Validates that the function does not unconditionally call setState during +//! render. +//! +//! Port of ValidateNoSetStateInRender.ts. 
+ +use std::collections::HashSet; + +use react_compiler_diagnostics::{CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory}; +use react_compiler_hir::{ + dominator::compute_unconditional_blocks, environment::Environment, BlockId, HirFunction, + Identifier, IdentifierId, InstructionValue, Type, +}; + +pub fn validate_no_set_state_in_render( + func: &HirFunction, + env: &mut Environment, +) -> Result<(), CompilerDiagnostic> { + let mut unconditional_set_state_functions: HashSet = HashSet::new(); + let next_block_id = env.next_block_id().0; + let diagnostics = validate_impl( + func, + &env.identifiers, + &env.types, + &env.functions, + next_block_id, + env.config.enable_use_keyed_state, + &mut unconditional_set_state_functions, + )?; + for diag in diagnostics { + env.record_diagnostic(diag); + } + Ok(()) +} + +fn is_set_state_id( + identifier_id: IdentifierId, + identifiers: &[Identifier], + types: &[Type], +) -> bool { + let ident = &identifiers[identifier_id.0 as usize]; + let ty = &types[ident.type_.0 as usize]; + react_compiler_hir::is_set_state_type(ty) +} + +fn validate_impl( + func: &HirFunction, + identifiers: &[Identifier], + types: &[Type], + functions: &[HirFunction], + next_block_id_counter: u32, + enable_use_keyed_state: bool, + unconditional_set_state_functions: &mut HashSet, +) -> Result, CompilerDiagnostic> { + let unconditional_blocks: HashSet = + compute_unconditional_blocks(func, next_block_id_counter)?; + let mut active_manual_memo_id: Option = None; + let mut errors: Vec = Vec::new(); + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + match &instr.value { + InstructionValue::LoadLocal { place, .. } => { + if unconditional_set_state_functions.contains(&place.identifier) { + unconditional_set_state_functions.insert(instr.lvalue.identifier); + } + } + InstructionValue::StoreLocal { lvalue, value, .. 
} => { + if unconditional_set_state_functions.contains(&value.identifier) { + unconditional_set_state_functions.insert(lvalue.place.identifier); + unconditional_set_state_functions.insert(instr.lvalue.identifier); + } + } + InstructionValue::ObjectMethod { lowered_func, .. } + | InstructionValue::FunctionExpression { lowered_func, .. } => { + let inner_func = &functions[lowered_func.func.0 as usize]; + + // Check if any operand references a setState. + // For FunctionExpression/ObjectMethod, operands are the context captures. + let has_set_state_operand = inner_func.context.iter().any(|ctx_place| { + is_set_state_id(ctx_place.identifier, identifiers, types) + || unconditional_set_state_functions.contains(&ctx_place.identifier) + }); + + if has_set_state_operand { + let inner_errors = validate_impl( + inner_func, + identifiers, + types, + functions, + next_block_id_counter, + enable_use_keyed_state, + unconditional_set_state_functions, + )?; + if !inner_errors.is_empty() { + unconditional_set_state_functions.insert(instr.lvalue.identifier); + } + } + } + InstructionValue::StartMemoize { manual_memo_id, .. } => { + assert!( + active_manual_memo_id.is_none(), + "Unexpected nested StartMemoize instructions" + ); + active_manual_memo_id = Some(*manual_memo_id); + } + InstructionValue::FinishMemoize { manual_memo_id, .. } => { + assert!( + active_manual_memo_id == Some(*manual_memo_id), + "Expected FinishMemoize to align with previous StartMemoize instruction" + ); + active_manual_memo_id = None; + } + InstructionValue::CallExpression { callee, .. } => { + if is_set_state_id(callee.identifier, identifiers, types) + || unconditional_set_state_functions.contains(&callee.identifier) + { + if active_manual_memo_id.is_some() { + errors.push( + CompilerDiagnostic::new( + ErrorCategory::RenderSetState, + "Calling setState from useMemo may trigger an infinite loop", + Some( + "Each time the memo callback is evaluated it will change state. 
This can cause a memoization dependency to change, running the memo function again and causing an infinite loop. Instead of setting state in useMemo(), prefer deriving the value during render. (https://react.dev/reference/react/useState)".to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: callee.loc, + message: Some("Found setState() within useMemo()".to_string()), + identifier_name: None, + }), + ); + } else if unconditional_blocks.contains(&block.id) { + if enable_use_keyed_state { + errors.push( + CompilerDiagnostic::new( + ErrorCategory::RenderSetState, + "Cannot call setState during render", + Some( + "Calling setState during render may trigger an \ + infinite loop.\n* To reset state when other \ + state/props change, use `const [state, setState] = \ + useKeyedState(initialState, key)` to reset `state` \ + when `key` changes.\n* To derive data from other \ + state/props, compute the derived data during render \ + without using state" + .to_string(), + ), + ) + .with_detail( + CompilerDiagnosticDetail::Error { + loc: callee.loc, + message: Some("Found setState() in render".to_string()), + identifier_name: None, + }, + ), + ); + } else { + errors.push( + CompilerDiagnostic::new( + ErrorCategory::RenderSetState, + "Cannot call setState during render", + Some( + "Calling setState during render may trigger an infinite loop.\n\ + * To reset state when other state/props change, store the previous value in state and update conditionally: https://react.dev/reference/react/useState#storing-information-from-previous-renders\n\ + * To derive data from other state/props, compute the derived data during render without using state".to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: callee.loc, + message: Some("Found setState() in render".to_string()), + identifier_name: None, + }), + ); + } + } + } + } + _ => {} + } + } + } + + Ok(errors) +} diff --git 
a/crates/react_compiler_validation/src/validate_preserved_manual_memoization.rs b/crates/react_compiler_validation/src/validate_preserved_manual_memoization.rs new file mode 100644 index 000000000000..59e46e5b9b53 --- /dev/null +++ b/crates/react_compiler_validation/src/validate_preserved_manual_memoization.rs @@ -0,0 +1,782 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Port of ValidatePreservedManualMemoization.ts +//! +//! Validates that all explicit manual memoization (useMemo/useCallback) was +//! accurately preserved, and that no originally memoized values became +//! unmemoized in the output. + +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, ErrorCategory, SourceLocation, +}; +use react_compiler_hir::{ + environment::Environment, DeclarationId, DependencyPathEntry, Identifier, IdentifierId, + IdentifierName, InstructionKind, InstructionValue, ManualMemoDependency, + ManualMemoDependencyRoot, Place, ReactiveBlock, ReactiveFunction, ReactiveInstruction, + ReactiveScopeBlock, ReactiveStatement, ReactiveValue, ScopeId, +}; + +/// State tracked during manual memo validation within a +/// StartMemoize..FinishMemoize range. +struct ManualMemoBlockState { + /// Reassigned temporaries (declaration_id -> set of identifier ids that + /// were reassigned to it). + reassignments: HashMap>, + /// Source location of the StartMemoize instruction. + loc: Option, + /// Declarations produced within this manual memo block. + decls: HashSet, + /// Normalized deps from source (useMemo/useCallback dep array). + deps_from_source: Option>, + /// Manual memo id from StartMemoize. + manual_memo_id: u32, +} + +/// Top-level visitor state. +struct VisitorState<'a> { + env: &'a mut Environment, + manual_memo_state: Option, + /// Completed (non-pruned) scope IDs. 
+ scopes: HashSet, + /// Completed pruned scope IDs. + pruned_scopes: HashSet, + /// Map from identifier ID to its normalized manual memo dependency. + temporaries: HashMap, +} + +/// Validate that manual memoization (useMemo/useCallback) is preserved. +/// +/// Walks the reactive function looking for StartMemoize/FinishMemoize +/// instructions and checks that: +/// 1. Dependencies' scopes have completed before the memo block starts +/// 2. Memoized values are actually within scopes (not unmemoized) +/// 3. Inferred scope dependencies match the source dependencies +pub fn validate_preserved_manual_memoization(func: &ReactiveFunction, env: &mut Environment) { + let mut state = VisitorState { + env, + manual_memo_state: None, + scopes: HashSet::new(), + pruned_scopes: HashSet::new(), + temporaries: HashMap::new(), + }; + visit_block(&func.body, &mut state); +} + +fn is_named(ident: &Identifier) -> bool { + matches!(ident.name, Some(IdentifierName::Named(_))) +} + +fn visit_block(block: &ReactiveBlock, state: &mut VisitorState) { + for stmt in block { + visit_statement(stmt, state); + } +} + +fn visit_statement(stmt: &ReactiveStatement, state: &mut VisitorState) { + match stmt { + ReactiveStatement::Instruction(instr) => { + visit_instruction(instr, state); + } + ReactiveStatement::Terminal(terminal) => { + visit_terminal(terminal, state); + } + ReactiveStatement::Scope(scope_block) => { + visit_scope(scope_block, state); + } + ReactiveStatement::PrunedScope(pruned) => { + visit_pruned_scope(pruned, state); + } + } +} + +fn visit_terminal( + terminal: &react_compiler_hir::ReactiveTerminalStatement, + state: &mut VisitorState, +) { + use react_compiler_hir::ReactiveTerminal; + match &terminal.terminal { + ReactiveTerminal::If { + consequent, + alternate, + .. + } => { + visit_block(consequent, state); + if let Some(alt) = alternate { + visit_block(alt, state); + } + } + ReactiveTerminal::Switch { cases, .. 
} => { + for case in cases { + if let Some(ref block) = case.block { + visit_block(block, state); + } + } + } + ReactiveTerminal::For { loop_block, .. } + | ReactiveTerminal::ForOf { loop_block, .. } + | ReactiveTerminal::ForIn { loop_block, .. } + | ReactiveTerminal::While { loop_block, .. } + | ReactiveTerminal::DoWhile { loop_block, .. } => { + visit_block(loop_block, state); + } + ReactiveTerminal::Label { block, .. } => { + visit_block(block, state); + } + ReactiveTerminal::Try { block, handler, .. } => { + visit_block(block, state); + visit_block(handler, state); + } + _ => {} + } +} + +fn visit_scope(scope_block: &ReactiveScopeBlock, state: &mut VisitorState) { + // Traverse the scope's instructions first + visit_block(&scope_block.instructions, state); + + // After traversing, validate scope dependencies against manual memo deps + if let Some(ref memo_state) = state.manual_memo_state { + if let Some(ref deps_from_source) = memo_state.deps_from_source { + let scope = &state.env.scopes[scope_block.scope.0 as usize]; + let deps = scope.dependencies.clone(); + let memo_loc = memo_state.loc; + let decls = memo_state.decls.clone(); + let deps_from_source = deps_from_source.clone(); + let temporaries = state.temporaries.clone(); + for dep in &deps { + validate_inferred_dep( + dep.identifier, + &dep.path, + &temporaries, + &decls, + &deps_from_source, + state.env, + memo_loc, + ); + } + } + } + + // Mark scope and merged scopes as completed + let scope = &state.env.scopes[scope_block.scope.0 as usize]; + let merged = scope.merged.clone(); + state.scopes.insert(scope_block.scope); + for merged_id in merged { + state.scopes.insert(merged_id); + } +} + +fn visit_pruned_scope( + pruned: &react_compiler_hir::PrunedReactiveScopeBlock, + state: &mut VisitorState, +) { + visit_block(&pruned.instructions, state); + state.pruned_scopes.insert(pruned.scope); +} + +fn visit_instruction(instr: &ReactiveInstruction, state: &mut VisitorState) { + // Record temporaries and deps in 
the instruction's value + record_temporaries(instr, state); + + match &instr.value { + ReactiveValue::Instruction(InstructionValue::StartMemoize { + manual_memo_id, + deps, + has_invalid_deps, + .. + }) => { + // TS: CompilerError.invariant(state.manualMemoState == null, ...) + assert!( + state.manual_memo_state.is_none(), + "Unexpected nested StartMemoize instructions" + ); + + // TS: if (value.hasInvalidDeps === true) { return; } + if *has_invalid_deps { + return; + } + + let deps_from_source = deps.clone(); + + state.manual_memo_state = Some(ManualMemoBlockState { + loc: instr.loc, + decls: HashSet::new(), + deps_from_source, + manual_memo_id: *manual_memo_id, + reassignments: HashMap::new(), + }); + + // Check that each dependency's scope has completed before the memo + // TS: for (const {identifier, loc} of eachInstructionValueOperand(value)) + let operand_places = start_memoize_operands(deps); + for place in &operand_places { + let ident = &state.env.identifiers[place.identifier.0 as usize]; + if let Some(scope_id) = ident.scope { + if !state.scopes.contains(&scope_id) && !state.pruned_scopes.contains(&scope_id) + { + let diag = CompilerDiagnostic::new( + ErrorCategory::PreserveManualMemo, + "Existing memoization could not be preserved", + Some( + "React Compiler has skipped optimizing this component because the \ + existing manual memoization could not be preserved. This \ + dependency may be mutated later, which could cause the value to \ + change unexpectedly" + .to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: place.loc, + message: Some("This dependency may be modified later".to_string()), + identifier_name: None, + }); + state.env.record_diagnostic(diag); + } + } + } + } + ReactiveValue::Instruction(InstructionValue::FinishMemoize { + decl, + pruned, + manual_memo_id, + .. 
+ }) => { + if state.manual_memo_state.is_none() { + // StartMemoize had invalid deps, skip validation + return; + } + + // TS: CompilerError.invariant(state.manualMemoState.manualMemoId === + // value.manualMemoId, ...) + assert!( + state.manual_memo_state.as_ref().unwrap().manual_memo_id == *manual_memo_id, + "Unexpected mismatch between StartMemoize and FinishMemoize" + ); + + let memo_state = state.manual_memo_state.take().unwrap(); + + if !pruned { + // Check if the declared value is unmemoized + let decl_ident = &state.env.identifiers[decl.identifier.0 as usize]; + + if decl_ident.scope.is_none() { + // If the manual memo was inlined (useMemo -> IIFE), check reassignments + let decls_to_check = memo_state + .reassignments + .get(&decl_ident.declaration_id) + .map(|ids| ids.iter().copied().collect::>()) + .unwrap_or_else(|| vec![decl.identifier]); + + for id in decls_to_check { + if is_unmemoized(id, &state.scopes, &state.env.identifiers) { + record_unmemoized_error(decl.loc, state.env); + } + } + } else { + // Single identifier with scope + if is_unmemoized(decl.identifier, &state.scopes, &state.env.identifiers) { + record_unmemoized_error(decl.loc, state.env); + } + } + } + } + ReactiveValue::Instruction(InstructionValue::StoreLocal { lvalue, value, .. }) => { + // Track reassignments from inlining of manual memo + if state.manual_memo_state.is_some() && lvalue.kind == InstructionKind::Reassign { + let decl_id = + state.env.identifiers[lvalue.place.identifier.0 as usize].declaration_id; + state + .manual_memo_state + .as_mut() + .unwrap() + .reassignments + .entry(decl_id) + .or_default() + .insert(value.identifier); + } + } + ReactiveValue::Instruction(InstructionValue::LoadLocal { place, .. 
}) => { + if state.manual_memo_state.is_some() { + let place_ident = &state.env.identifiers[place.identifier.0 as usize]; + if let Some(ref lvalue) = instr.lvalue { + let lvalue_ident = &state.env.identifiers[lvalue.identifier.0 as usize]; + if place_ident.scope.is_some() && lvalue_ident.scope.is_none() { + state + .manual_memo_state + .as_mut() + .unwrap() + .reassignments + .entry(lvalue_ident.declaration_id) + .or_default() + .insert(place.identifier); + } + } + } + } + _ => {} + } +} + +fn record_unmemoized_error(loc: Option, env: &mut Environment) { + let diag = CompilerDiagnostic::new( + ErrorCategory::PreserveManualMemo, + "Existing memoization could not be preserved", + Some( + "React Compiler has skipped optimizing this component because the existing manual \ + memoization could not be preserved. This value was memoized in source but not in \ + compilation output" + .to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc, + message: Some("Could not preserve existing memoization".to_string()), + identifier_name: None, + }); + env.record_diagnostic(diag); +} + +/// Record temporaries from an instruction. 
+/// TS: `recordTemporaries` +fn record_temporaries(instr: &ReactiveInstruction, state: &mut VisitorState) { + let lvalue = &instr.lvalue; + let lv_id = lvalue.as_ref().map(|lv| lv.identifier); + if let Some(id) = lv_id { + if state.temporaries.contains_key(&id) { + return; + } + } + + if let Some(ref lvalue) = instr.lvalue { + let lv_ident = &state.env.identifiers[lvalue.identifier.0 as usize]; + if is_named(lv_ident) && state.manual_memo_state.is_some() { + state + .manual_memo_state + .as_mut() + .unwrap() + .decls + .insert(lv_ident.declaration_id); + } + } + + // Record deps from the instruction value first (before setting lvalue + // temporary) + record_deps_in_value(&instr.value, state); + + // Then set the lvalue temporary (TS always sets this, even for unnamed lvalues) + if let Some(ref lvalue) = instr.lvalue { + state.temporaries.insert( + lvalue.identifier, + ManualMemoDependency { + root: ManualMemoDependencyRoot::NamedLocal { + value: lvalue.clone(), + constant: false, + }, + path: Vec::new(), + loc: lvalue.loc, + }, + ); + } +} + +/// Record dependencies from a reactive value. +/// TS: `recordDepsInValue` +fn record_deps_in_value(value: &ReactiveValue, state: &mut VisitorState) { + match value { + ReactiveValue::SequenceExpression { + instructions, + value, + .. + } => { + for instr in instructions { + visit_instruction(instr, state); + } + record_deps_in_value(value, state); + } + ReactiveValue::OptionalExpression { value: inner, .. } => { + record_deps_in_value(inner, state); + } + ReactiveValue::ConditionalExpression { + test, + consequent, + alternate, + .. + } => { + record_deps_in_value(test, state); + record_deps_in_value(consequent, state); + record_deps_in_value(alternate, state); + } + ReactiveValue::LogicalExpression { left, right, .. 
} => { + record_deps_in_value(left, state); + record_deps_in_value(right, state); + } + ReactiveValue::Instruction(iv) => { + // TS: collectMaybeMemoDependencies(value, this.temporaries, false) + // Called for side-effect of building up the dependency chain through + // LoadGlobal -> PropertyLoad -> ... The return value is discarded here + // (only used in DropManualMemoization's caller), but we need to store + // the result in temporaries for the lvalue of the enclosing instruction. + // That storage is handled by record_temporaries after this function returns. + + // Track store targets within manual memo blocks + // TS: if (value.kind === 'StoreLocal' || value.kind === 'StoreContext' || + // value.kind === 'Destructure') + match iv { + InstructionValue::StoreLocal { lvalue, .. } + | InstructionValue::StoreContext { lvalue, .. } => { + if let Some(ref mut memo_state) = state.manual_memo_state { + let ident = &state.env.identifiers[lvalue.place.identifier.0 as usize]; + memo_state.decls.insert(ident.declaration_id); + if is_named(ident) { + state.temporaries.insert( + lvalue.place.identifier, + ManualMemoDependency { + root: ManualMemoDependencyRoot::NamedLocal { + value: lvalue.place.clone(), + constant: false, + }, + path: Vec::new(), + loc: lvalue.place.loc, + }, + ); + } + } + } + InstructionValue::Destructure { lvalue, .. } => { + if let Some(ref mut memo_state) = state.manual_memo_state { + for place in destructure_lvalue_places(&lvalue.pattern) { + let ident = &state.env.identifiers[place.identifier.0 as usize]; + memo_state.decls.insert(ident.declaration_id); + if is_named(ident) { + state.temporaries.insert( + place.identifier, + ManualMemoDependency { + root: ManualMemoDependencyRoot::NamedLocal { + value: place.clone(), + constant: false, + }, + path: Vec::new(), + loc: place.loc, + }, + ); + } + } + } + } + _ => {} + } + } + } +} + +/// Get operand places from a StartMemoize instruction's deps. 
+fn start_memoize_operands(deps: &Option>) -> Vec { + let mut result = Vec::new(); + if let Some(deps) = deps { + for dep in deps { + if let ManualMemoDependencyRoot::NamedLocal { value, .. } = &dep.root { + result.push(value.clone()); + } + } + } + result +} + +/// Get lvalue places from a Destructure pattern. +fn destructure_lvalue_places(pattern: &react_compiler_hir::Pattern) -> Vec<&Place> { + let mut result = Vec::new(); + match pattern { + react_compiler_hir::Pattern::Array(arr) => { + for item in &arr.items { + match item { + react_compiler_hir::ArrayPatternElement::Place(place) => { + result.push(place); + } + react_compiler_hir::ArrayPatternElement::Spread(spread) => { + result.push(&spread.place); + } + react_compiler_hir::ArrayPatternElement::Hole => {} + } + } + } + react_compiler_hir::Pattern::Object(obj) => { + for entry in &obj.properties { + match entry { + react_compiler_hir::ObjectPropertyOrSpread::Property(prop) => { + result.push(&prop.place); + } + react_compiler_hir::ObjectPropertyOrSpread::Spread(spread) => { + result.push(&spread.place); + } + } + } + } + } + result +} + +/// Check if an identifier is unmemoized (has a scope that hasn't completed). 
+fn is_unmemoized( + id: IdentifierId, + completed_scopes: &HashSet, + identifiers: &[Identifier], +) -> bool { + let ident = &identifiers[id.0 as usize]; + if let Some(scope_id) = ident.scope { + !completed_scopes.contains(&scope_id) + } else { + false + } +} + +// ============================================================================= +// Dependency comparison (port of compareDeps / validateInferredDep) +// ============================================================================= + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +enum CompareDependencyResult { + Ok = 0, + RootDifference = 1, + PathDifference = 2, + Subpath = 3, + RefAccessDifference = 4, +} + +fn compare_deps( + inferred: &ManualMemoDependency, + source: &ManualMemoDependency, +) -> CompareDependencyResult { + let roots_equal = match (&inferred.root, &source.root) { + ( + ManualMemoDependencyRoot::Global { identifier_name: a }, + ManualMemoDependencyRoot::Global { identifier_name: b }, + ) => a == b, + ( + ManualMemoDependencyRoot::NamedLocal { value: a, .. }, + ManualMemoDependencyRoot::NamedLocal { value: b, .. 
}, + ) => a.identifier == b.identifier, + _ => false, + }; + if !roots_equal { + return CompareDependencyResult::RootDifference; + } + + let min_len = inferred.path.len().min(source.path.len()); + let mut is_subpath = true; + for i in 0..min_len { + if inferred.path[i].property != source.path[i].property { + is_subpath = false; + break; + } else if inferred.path[i].optional != source.path[i].optional { + return CompareDependencyResult::PathDifference; + } + } + + if is_subpath + && (source.path.len() == inferred.path.len() + || (inferred.path.len() >= source.path.len() + && !inferred.path.iter().any(|t| { + t.property == react_compiler_hir::PropertyLiteral::String("current".to_string()) + }))) + { + CompareDependencyResult::Ok + } else if is_subpath { + if source.path.iter().any(|t| { + t.property == react_compiler_hir::PropertyLiteral::String("current".to_string()) + }) || inferred.path.iter().any(|t| { + t.property == react_compiler_hir::PropertyLiteral::String("current".to_string()) + }) { + CompareDependencyResult::RefAccessDifference + } else { + CompareDependencyResult::Subpath + } + } else { + CompareDependencyResult::PathDifference + } +} + +/// Pretty-print a reactive scope dependency (e.g., `x.a.b?.c`) +fn pretty_print_scope_dependency( + dep_id: IdentifierId, + dep_path: &[DependencyPathEntry], + identifiers: &[react_compiler_hir::Identifier], +) -> String { + let ident = &identifiers[dep_id.0 as usize]; + let root_str = match &ident.name { + Some(react_compiler_hir::IdentifierName::Named(n)) => n.clone(), + Some(react_compiler_hir::IdentifierName::Promoted(n)) => n.clone(), + None => "[unnamed]".to_string(), + }; + let path_str: String = dep_path + .iter() + .map(|entry| { + let prop = match &entry.property { + react_compiler_hir::PropertyLiteral::String(s) => s.clone(), + react_compiler_hir::PropertyLiteral::Number(n) => format!("{}", n.value()), + }; + if entry.optional { + format!("?.{}", prop) + } else { + format!(".{}", prop) + } + }) + .collect(); 
+ format!("{}{}", root_str, path_str) +} + +/// Pretty-print a manual memo dependency for error messages. +fn print_manual_memo_dependency( + dep: &ManualMemoDependency, + identifiers: &[react_compiler_hir::Identifier], + with_optional: bool, +) -> String { + let root_str = match &dep.root { + ManualMemoDependencyRoot::NamedLocal { value, .. } => { + let ident = &identifiers[value.identifier.0 as usize]; + match &ident.name { + Some(react_compiler_hir::IdentifierName::Named(n)) => n.clone(), + Some(react_compiler_hir::IdentifierName::Promoted(n)) => n.clone(), + None => "[unnamed]".to_string(), + } + } + ManualMemoDependencyRoot::Global { identifier_name } => identifier_name.clone(), + }; + let path_str: String = dep + .path + .iter() + .map(|entry| { + let prop = match &entry.property { + react_compiler_hir::PropertyLiteral::String(s) => s.clone(), + react_compiler_hir::PropertyLiteral::Number(n) => format!("{}", n.value()), + }; + if with_optional && entry.optional { + format!("?.{}", prop) + } else { + format!(".{}", prop) + } + }) + .collect(); + format!("{}{}", root_str, path_str) +} + +fn get_compare_dependency_result_description(result: CompareDependencyResult) -> &'static str { + match result { + CompareDependencyResult::Ok => "Dependencies equal", + CompareDependencyResult::RootDifference | CompareDependencyResult::PathDifference => { + "Inferred different dependency than source" + } + CompareDependencyResult::RefAccessDifference => "Differences in ref.current access", + CompareDependencyResult::Subpath => "Inferred less specific property than source", + } +} + +/// Validate that an inferred dependency matches a source dependency or was +/// produced within the manual memo block. 
+fn validate_inferred_dep( + dep_id: IdentifierId, + dep_path: &[DependencyPathEntry], + temporaries: &HashMap, + decls_within_memo_block: &HashSet, + valid_deps_in_memo_block: &[ManualMemoDependency], + env: &mut Environment, + memo_location: Option, +) { + // Normalize the dependency through temporaries + let normalized_dep = if let Some(temp) = temporaries.get(&dep_id) { + let mut path = temp.path.clone(); + path.extend_from_slice(dep_path); + ManualMemoDependency { + root: temp.root.clone(), + path, + loc: temp.loc, + } + } else { + let ident = &env.identifiers[dep_id.0 as usize]; + // TS: CompilerError.invariant(dep.identifier.name?.kind === 'named', ...) + assert!( + is_named(ident), + "ValidatePreservedManualMemoization: expected scope dependency to be named" + ); + ManualMemoDependency { + root: ManualMemoDependencyRoot::NamedLocal { + value: Place { + identifier: dep_id, + effect: react_compiler_hir::Effect::Read, + reactive: false, + loc: ident.loc, + }, + constant: false, + }, + path: dep_path.to_vec(), + loc: ident.loc, + } + }; + + // Check if the dep was declared within the memo block + if let ManualMemoDependencyRoot::NamedLocal { value, .. 
} = &normalized_dep.root { + let ident = &env.identifiers[value.identifier.0 as usize]; + if decls_within_memo_block.contains(&ident.declaration_id) { + return; + } + } + + // Compare against each valid source dependency + let mut error_diagnostic: Option = None; + for source_dep in valid_deps_in_memo_block { + let result = compare_deps(&normalized_dep, source_dep); + if result == CompareDependencyResult::Ok { + return; + } + error_diagnostic = Some(match error_diagnostic { + Some(prev) => prev.max(result), + None => result, + }); + } + + let ident = &env.identifiers[dep_id.0 as usize]; + + let extra = if is_named(ident) { + // Use the original dep_id/dep_path (matching TS + // prettyPrintScopeDependency(dep)) + let dep_str = pretty_print_scope_dependency(dep_id, dep_path, &env.identifiers); + let source_deps_str: String = valid_deps_in_memo_block + .iter() + .map(|d| print_manual_memo_dependency(d, &env.identifiers, true)) + .collect::>() + .join(", "); + let result_desc = error_diagnostic + .map(|d| get_compare_dependency_result_description(d).to_string()) + .unwrap_or_else(|| "Inferred dependency not present in source".to_string()); + format!( + "The inferred dependency was `{}`, but the source dependencies were [{}]. {}", + dep_str, source_deps_str, result_desc + ) + } else { + String::new() + }; + + let description = format!( + "React Compiler has skipped optimizing this component because the existing manual \ + memoization could not be preserved. The inferred dependencies did not match the manually \ + specified dependencies, which could cause the value to change more or less frequently \ + than expected. 
{}", + extra + ); + + let diag = CompilerDiagnostic::new( + ErrorCategory::PreserveManualMemo, + "Existing memoization could not be preserved", + Some(description.trim().to_string()), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: memo_location, + message: Some("Could not preserve existing manual memoization".to_string()), + identifier_name: None, + }); + env.record_diagnostic(diag); +} diff --git a/crates/react_compiler_validation/src/validate_static_components.rs b/crates/react_compiler_validation/src/validate_static_components.rs new file mode 100644 index 000000000000..55cc4973784f --- /dev/null +++ b/crates/react_compiler_validation/src/validate_static_components.rs @@ -0,0 +1,107 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Validates against components that are created dynamically and whose identity +//! is not guaranteed to be stable (which would cause the component to reset on +//! each re-render). +//! +//! Port of ValidateStaticComponents.ts. + +use std::collections::HashMap; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, ErrorCategory, SourceLocation, +}; +use react_compiler_hir::{HirFunction, IdentifierId, InstructionValue, JsxTag}; + +/// Validates that components used in JSX are not dynamically created during +/// render. +/// +/// Returns a CompilerError containing all diagnostics found (may be empty). +/// Called via `env.logErrors()` pattern in Pipeline.ts. 
+pub fn validate_static_components(func: &HirFunction) -> CompilerError { + let mut error = CompilerError::new(); + let mut known_dynamic_components: HashMap> = + HashMap::new(); + + for (_block_id, block) in &func.body.blocks { + // Process phis: propagate dynamic component knowledge through phi nodes + 'phis: for phi in &block.phis { + for (_pred, operand) in &phi.operands { + if let Some(loc) = known_dynamic_components.get(&operand.identifier) { + known_dynamic_components.insert(phi.place.identifier, *loc); + continue 'phis; + } + } + } + + // Process instructions + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue_id = instr.lvalue.identifier; + let value = &instr.value; + + match value { + InstructionValue::FunctionExpression { loc, .. } + | InstructionValue::NewExpression { loc, .. } + | InstructionValue::MethodCall { loc, .. } + | InstructionValue::CallExpression { loc, .. } => { + known_dynamic_components.insert(lvalue_id, *loc); + } + InstructionValue::LoadLocal { place, .. } => { + if let Some(loc) = known_dynamic_components.get(&place.identifier) { + known_dynamic_components.insert(lvalue_id, *loc); + } + } + InstructionValue::StoreLocal { + lvalue, value: val, .. + } => { + if let Some(loc) = known_dynamic_components.get(&val.identifier) { + let loc = *loc; + known_dynamic_components.insert(lvalue_id, loc); + known_dynamic_components.insert(lvalue.place.identifier, loc); + } + } + InstructionValue::JsxExpression { tag, .. } => { + if let JsxTag::Place(tag_place) = tag { + if let Some(location) = known_dynamic_components.get(&tag_place.identifier) + { + let location = *location; + let diagnostic = CompilerDiagnostic::new( + ErrorCategory::StaticComponents, + "Cannot create components during render", + Some( + "Components created during render will reset their state each \ + time they are created. 
Declare components outside of render" + .to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: tag_place.loc, + message: Some( + "This component is created during render".to_string(), + ), + identifier_name: None, + }) + .with_detail( + CompilerDiagnosticDetail::Error { + loc: location, + message: Some( + "The component is created during render here".to_string(), + ), + identifier_name: None, + }, + ); + error.push_diagnostic(diagnostic); + } + } + } + _ => {} + } + } + } + + error +} diff --git a/crates/react_compiler_validation/src/validate_use_memo.rs b/crates/react_compiler_validation/src/validate_use_memo.rs new file mode 100644 index 000000000000..8e2078dbf5dc --- /dev/null +++ b/crates/react_compiler_validation/src/validate_use_memo.rs @@ -0,0 +1,326 @@ +use std::collections::{HashMap, HashSet}; + +use react_compiler_diagnostics::{ + CompilerDiagnostic, CompilerDiagnosticDetail, CompilerError, ErrorCategory, SourceLocation, +}; +use react_compiler_hir::{ + environment::Environment, + visitors::{each_instruction_value_operand_with_functions, each_terminal_operand}, + FunctionId, HirFunction, IdentifierId, InstructionValue, ParamPattern, Place, PlaceOrSpread, + ReturnVariant, Terminal, +}; + +/// Validates useMemo() usage patterns. +/// +/// Port of ValidateUseMemo.ts. +/// Returns VoidUseMemo errors separately (for logging via logErrors, not as +/// compile errors). +pub fn validate_use_memo(func: &HirFunction, env: &mut Environment) -> CompilerError { + validate_use_memo_impl( + func, + &env.functions, + &mut env.errors, + env.config.validate_no_void_use_memo, + ) +} + +/// Information about a FunctionExpression needed for validation. 
+struct FuncExprInfo { + func_id: FunctionId, + loc: Option, +} + +fn validate_use_memo_impl( + func: &HirFunction, + functions: &[HirFunction], + errors: &mut CompilerError, + validate_no_void_use_memo: bool, +) -> CompilerError { + let mut void_memo_errors = CompilerError::new(); + let mut use_memos: HashSet = HashSet::new(); + let mut react: HashSet = HashSet::new(); + let mut func_exprs: HashMap = HashMap::new(); + let mut unused_use_memos: HashMap)> = + HashMap::new(); + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + let lvalue = &instr.lvalue; + let value = &instr.value; + + // Remove used operands from unused_use_memos + if !unused_use_memos.is_empty() { + for operand_id in each_instruction_value_operand_ids(value, functions) { + unused_use_memos.remove(&operand_id); + } + } + + match value { + InstructionValue::LoadGlobal { binding, .. } => { + let name = binding.name(); + if name == "useMemo" { + use_memos.insert(lvalue.identifier); + } else if name == "React" { + react.insert(lvalue.identifier); + } + } + InstructionValue::PropertyLoad { + object, property, .. + } => { + if react.contains(&object.identifier) { + if let react_compiler_hir::PropertyLiteral::String(prop_name) = property { + if prop_name == "useMemo" { + use_memos.insert(lvalue.identifier); + } + } + } + } + InstructionValue::FunctionExpression { + lowered_func, loc, .. + } => { + func_exprs.insert( + lvalue.identifier, + FuncExprInfo { + func_id: lowered_func.func, + loc: *loc, + }, + ); + } + InstructionValue::CallExpression { callee, args, .. } => { + handle_possible_use_memo_call( + functions, + errors, + &mut void_memo_errors, + &use_memos, + &func_exprs, + &mut unused_use_memos, + callee, + args, + lvalue, + validate_no_void_use_memo, + ); + } + InstructionValue::MethodCall { property, args, .. 
} => { + handle_possible_use_memo_call( + functions, + errors, + &mut void_memo_errors, + &use_memos, + &func_exprs, + &mut unused_use_memos, + property, + args, + lvalue, + validate_no_void_use_memo, + ); + } + _ => {} + } + } + + // Check terminal operands for unused_use_memos + if !unused_use_memos.is_empty() { + for operand_id in each_terminal_operand_ids(&block.terminal) { + unused_use_memos.remove(&operand_id); + } + } + } + + // Report unused useMemo results + if !unused_use_memos.is_empty() { + for (loc, ident_name) in unused_use_memos.values() { + void_memo_errors.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::VoidUseMemo, + "useMemo() result is unused", + Some( + "This useMemo() value is unused. useMemo() is for computing and caching \ + values, not for arbitrary side effects" + .to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: Some(*loc), + message: Some("useMemo() result is unused".to_string()), + identifier_name: ident_name.clone(), + }), + ); + } + } + + void_memo_errors +} + +#[allow(clippy::too_many_arguments)] +fn handle_possible_use_memo_call( + functions: &[HirFunction], + errors: &mut CompilerError, + void_memo_errors: &mut CompilerError, + use_memos: &HashSet, + func_exprs: &HashMap, + unused_use_memos: &mut HashMap)>, + callee: &Place, + args: &[PlaceOrSpread], + lvalue: &Place, + validate_no_void_use_memo: bool, +) { + let is_use_memo = use_memos.contains(&callee.identifier); + if !is_use_memo || args.is_empty() { + return; + } + + let first_arg = match &args[0] { + PlaceOrSpread::Place(place) => place, + PlaceOrSpread::Spread(_) => return, + }; + + let body_info = match func_exprs.get(&first_arg.identifier) { + Some(info) => info, + None => return, + }; + + let body_func = &functions[body_info.func_id.0 as usize]; + + // Validate no parameters + if !body_func.params.is_empty() { + let first_param = &body_func.params[0]; + let loc = match first_param { + ParamPattern::Place(place) => place.loc, + 
ParamPattern::Spread(spread) => spread.place.loc, + }; + errors.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::UseMemo, + "useMemo() callbacks may not accept parameters", + Some( + "useMemo() callbacks are called by React to cache calculations across \ + re-renders. They should not take parameters. Instead, directly reference the \ + props, state, or local variables needed for the computation" + .to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc, + message: Some("Callbacks with parameters are not supported".to_string()), + identifier_name: None, + }), + ); + } + + // Validate not async or generator + if body_func.is_async || body_func.generator { + errors.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::UseMemo, + "useMemo() callbacks may not be async or generator functions", + Some( + "useMemo() callbacks are called once and must synchronously return a value" + .to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: body_info.loc, + message: Some("Async and generator functions are not supported".to_string()), + identifier_name: None, + }), + ); + } + + // Validate no context variable assignment + validate_no_context_variable_assignment(body_func, errors); + + if validate_no_void_use_memo && !has_non_void_return(body_func) { + void_memo_errors.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::VoidUseMemo, + "useMemo() callbacks must return a value", + Some( + "This useMemo() callback doesn't return a value. useMemo() is for computing \ + and caching values, not for arbitrary side effects" + .to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: body_info.loc, + message: Some("useMemo() callbacks must return a value".to_string()), + identifier_name: None, + }), + ); + } else if validate_no_void_use_memo { + if let Some(callee_loc) = callee.loc { + // The callee is always useMemo/React.useMemo since we checked is_use_memo + // above. 
The identifierName in Babel's AST SourceLocation is + // "useMemo". + unused_use_memos.insert(lvalue.identifier, (callee_loc, Some("useMemo".to_string()))); + } + } +} + +fn validate_no_context_variable_assignment(func: &HirFunction, errors: &mut CompilerError) { + let context: HashSet = + func.context.iter().map(|place| place.identifier).collect(); + + for (_block_id, block) in &func.body.blocks { + for &instr_id in &block.instructions { + let instr = &func.instructions[instr_id.0 as usize]; + if let InstructionValue::StoreContext { lvalue, .. } = &instr.value { + if context.contains(&lvalue.place.identifier) { + errors.push_diagnostic( + CompilerDiagnostic::new( + ErrorCategory::UseMemo, + "useMemo() callbacks may not reassign variables declared outside of \ + the callback", + Some( + "useMemo() callbacks must be pure functions and cannot reassign \ + variables defined outside of the callback function" + .to_string(), + ), + ) + .with_detail(CompilerDiagnosticDetail::Error { + loc: lvalue.place.loc, + message: Some("Cannot reassign variable".to_string()), + identifier_name: None, + }), + ); + } + } + } + } +} + +fn has_non_void_return(func: &HirFunction) -> bool { + for (_block_id, block) in &func.body.blocks { + if let Terminal::Return { return_variant, .. } = &block.terminal { + if matches!( + return_variant, + ReturnVariant::Explicit | ReturnVariant::Implicit + ) { + return true; + } + } + } + false +} + +/// Collect all operand IdentifierIds from an InstructionValue. +/// Thin wrapper around canonical +/// `each_instruction_value_operand_with_functions` that maps to ids. +fn each_instruction_value_operand_ids( + value: &InstructionValue, + functions: &[HirFunction], +) -> Vec { + each_instruction_value_operand_with_functions(value, functions) + .into_iter() + .map(|p| p.identifier) + .collect() +} + +/// Collect all operand IdentifierIds from a Terminal. +/// Thin wrapper around canonical `each_terminal_operand` that maps to ids. 
+fn each_terminal_operand_ids(terminal: &Terminal) -> Vec { + each_terminal_operand(terminal) + .into_iter() + .map(|p| p.identifier) + .collect() +} diff --git a/crates/swc_ecma_react_compiler/Cargo.toml b/crates/swc_ecma_react_compiler/Cargo.toml index e60c63e02574..a3ab6391e266 100644 --- a/crates/swc_ecma_react_compiler/Cargo.toml +++ b/crates/swc_ecma_react_compiler/Cargo.toml @@ -1,5 +1,5 @@ [package] -description = "SWC helpers for the React Compiler" +description = "SWC adapter for the vendored React Compiler from facebook/react#36173" documentation = "https://rustdoc.swc.rs/swc_ecma_react_compiler/" edition = { workspace = true } license = { workspace = true } @@ -13,8 +13,19 @@ all-features = true rustdoc-args = ["--cfg", "docsrs"] [dependencies] -swc_ecma_ast = { version = "23.0.0", path = "../swc_ecma_ast" } -swc_ecma_visit = { version = "23.0.0", path = "../swc_ecma_visit" } +indexmap = { workspace = true, features = ["serde"] } +react_compiler = { path = "../react_compiler" } +react_compiler_ast = { path = "../react_compiler_ast" } +react_compiler_diagnostics = { path = "../react_compiler_diagnostics" } +react_compiler_hir = { path = "../react_compiler_hir" } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +swc_atoms = { version = "9.0.0", path = "../swc_atoms" } +swc_common = { version = "21.0.0", path = "../swc_common" } +swc_ecma_ast = { version = "23.0.0", path = "../swc_ecma_ast" } +swc_ecma_codegen = { version = "26.0.0", path = "../swc_ecma_codegen" } +swc_ecma_parser = { version = "38.0.0", path = "../swc_ecma_parser" } +swc_ecma_visit = { version = "23.0.0", path = "../swc_ecma_visit" } [dev-dependencies] swc_common = { version = "21.0.0", path = "../swc_common" } diff --git a/crates/swc_ecma_react_compiler/src/convert_ast.rs b/crates/swc_ecma_react_compiler/src/convert_ast.rs new file mode 100644 index 000000000000..8368086998fe --- /dev/null +++ b/crates/swc_ecma_react_compiler/src/convert_ast.rs @@ 
-0,0 +1,2173 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use react_compiler_ast::{ + common::{BaseNode, Position, SourceLocation}, + declarations::*, + expressions::*, + jsx::*, + literals::*, + operators::*, + patterns::*, + statements::*, + File, Program, SourceType, +}; +use swc_common::{Span, Spanned}; +use swc_ecma_ast as swc; + +/// Helper to convert SWC's Wtf8Atom (which doesn't impl Display) to a String. +fn wtf8_to_string(value: &swc_atoms::Wtf8Atom) -> String { + value.to_string_lossy().into_owned() +} + +/// Converts an SWC Module AST to the React compiler's Babel-compatible AST. +pub fn convert_module(module: &swc::Module, source_text: &str) -> File { + convert_module_with_source_type(module, source_text, SourceType::Module) +} + +/// Converts an SWC Module AST to the React compiler's Babel-compatible AST +/// with an explicit source type. +pub fn convert_module_with_source_type( + module: &swc::Module, + source_text: &str, + source_type: SourceType, +) -> File { + let ctx = ConvertCtx::new(source_text); + let base = ctx.make_base_node(module.span); + + let mut body: Vec = Vec::new(); + let mut directives: Vec = Vec::new(); + let mut past_directives = false; + + for item in &module.body { + if !past_directives { + if let Some(dir) = try_extract_directive(item, &ctx) { + directives.push(dir); + continue; + } + past_directives = true; + } + body.push(ctx.convert_module_item(item)); + } + + // Extract comments from source text for suppression detection + let comments = extract_comments_from_source(source_text); + + File { + base: ctx.make_base_node(module.span), + program: Program { + base, + body, + directives, + source_type, + interpreter: None, + source_file: None, + }, + comments, + errors: vec![], + } +} + +/// Extract comments from source text for suppression detection. 
+/// This uses simple regex-style parsing to find block and line comments. +fn extract_comments_from_source(source: &str) -> Vec { + use react_compiler_ast::common::{Comment, CommentData}; + let mut comments = Vec::new(); + let bytes = source.as_bytes(); + let len = bytes.len(); + let mut i = 0; + + while i < len { + if bytes[i] == b'/' && i + 1 < len { + if bytes[i + 1] == b'/' { + // Line comment + let start = i as u32; + let content_start = i + 2; + let mut end = content_start; + while end < len && bytes[end] != b'\n' { + end += 1; + } + let value = String::from_utf8_lossy(&bytes[content_start..end]).to_string(); + comments.push(Comment::CommentLine(CommentData { + value: value.trim().to_string(), + start: Some(start), + end: Some(end as u32), + loc: None, + })); + i = end; + continue; + } else if bytes[i + 1] == b'*' { + // Block comment + let start = i as u32; + let content_start = i + 2; + let mut end = content_start; + while end + 1 < len { + if bytes[end] == b'*' && bytes[end + 1] == b'/' { + break; + } + end += 1; + } + let value = String::from_utf8_lossy(&bytes[content_start..end]).to_string(); + let comment_end = if end + 1 < len { end + 2 } else { end }; + comments.push(Comment::CommentBlock(CommentData { + value: value.trim().to_string(), + start: Some(start), + end: Some(comment_end as u32), + loc: None, + })); + i = comment_end; + continue; + } + } + // Skip string literals to avoid matching // inside strings + if bytes[i] == b'"' || bytes[i] == b'\'' || bytes[i] == b'`' { + let quote = bytes[i]; + i += 1; + while i < len { + if bytes[i] == b'\\' { + i += 2; // skip escaped char + continue; + } + if bytes[i] == quote { + break; + } + i += 1; + } + } + i += 1; + } + + comments +} + +fn try_extract_directive(item: &swc::ModuleItem, ctx: &ConvertCtx) -> Option { + if let swc::ModuleItem::Stmt(swc::Stmt::Expr(expr_stmt)) = item { + if let swc::Expr::Lit(swc::Lit::Str(s)) = &*expr_stmt.expr { + return Some(Directive { + base: 
ctx.make_base_node(expr_stmt.span), + value: DirectiveLiteral { + base: ctx.make_base_node(s.span), + value: wtf8_to_string(&s.value), + }, + }); + } + } + None +} + +struct ConvertCtx<'a> { + #[allow(dead_code)] + source_text: &'a str, + line_offsets: Vec, +} + +impl<'a> ConvertCtx<'a> { + fn new(source_text: &'a str) -> Self { + let mut line_offsets = vec![0u32]; + for (i, ch) in source_text.char_indices() { + if ch == '\n' { + line_offsets.push((i + 1) as u32); + } + } + Self { + source_text, + line_offsets, + } + } + + fn make_base_node(&self, span: Span) -> BaseNode { + BaseNode { + node_type: None, + start: Some(span.lo.0), + end: Some(span.hi.0), + loc: Some(self.source_location(span)), + range: None, + extra: None, + leading_comments: None, + inner_comments: None, + trailing_comments: None, + } + } + + fn position(&self, offset: u32) -> Position { + let line_idx = match self.line_offsets.binary_search(&offset) { + Ok(idx) => idx, + Err(idx) => idx.saturating_sub(1), + }; + let line_start = self.line_offsets[line_idx]; + Position { + line: (line_idx as u32) + 1, + column: offset - line_start, + index: Some(offset), + } + } + + fn source_location(&self, span: Span) -> SourceLocation { + SourceLocation { + start: self.position(span.lo.0), + end: self.position(span.hi.0), + filename: None, + identifier_name: None, + } + } + + fn convert_module_item(&self, item: &swc::ModuleItem) -> Statement { + match item { + swc::ModuleItem::Stmt(stmt) => self.convert_statement(stmt), + swc::ModuleItem::ModuleDecl(decl) => self.convert_module_decl(decl), + } + } + + fn convert_module_decl(&self, decl: &swc::ModuleDecl) -> Statement { + match decl { + swc::ModuleDecl::Import(d) => { + Statement::ImportDeclaration(self.convert_import_declaration(d)) + } + swc::ModuleDecl::ExportDecl(d) => { + Statement::ExportNamedDeclaration(self.convert_export_decl(d)) + } + swc::ModuleDecl::ExportNamed(d) => { + Statement::ExportNamedDeclaration(self.convert_export_named(d)) + } + 
swc::ModuleDecl::ExportDefaultDecl(d) => { + Statement::ExportDefaultDeclaration(self.convert_export_default_decl(d)) + } + swc::ModuleDecl::ExportDefaultExpr(d) => { + Statement::ExportDefaultDeclaration(self.convert_export_default_expr(d)) + } + swc::ModuleDecl::ExportAll(d) => { + Statement::ExportAllDeclaration(self.convert_export_all(d)) + } + swc::ModuleDecl::TsImportEquals(d) => Statement::EmptyStatement(EmptyStatement { + base: self.make_base_node(d.span), + }), + swc::ModuleDecl::TsExportAssignment(d) => Statement::EmptyStatement(EmptyStatement { + base: self.make_base_node(d.span), + }), + swc::ModuleDecl::TsNamespaceExport(d) => Statement::EmptyStatement(EmptyStatement { + base: self.make_base_node(d.span), + }), + } + } + + // ===== Statements ===== + + fn convert_statement(&self, stmt: &swc::Stmt) -> Statement { + match stmt { + swc::Stmt::Block(s) => Statement::BlockStatement(self.convert_block_statement(s)), + swc::Stmt::Break(s) => Statement::BreakStatement(BreakStatement { + base: self.make_base_node(s.span), + label: s + .label + .as_ref() + .map(|l| self.convert_ident_to_identifier(l)), + }), + swc::Stmt::Continue(s) => Statement::ContinueStatement(ContinueStatement { + base: self.make_base_node(s.span), + label: s + .label + .as_ref() + .map(|l| self.convert_ident_to_identifier(l)), + }), + swc::Stmt::Debugger(s) => Statement::DebuggerStatement(DebuggerStatement { + base: self.make_base_node(s.span), + }), + swc::Stmt::DoWhile(s) => Statement::DoWhileStatement(DoWhileStatement { + base: self.make_base_node(s.span), + test: Box::new(self.convert_expression(&s.test)), + body: Box::new(self.convert_statement(&s.body)), + }), + swc::Stmt::Empty(s) => Statement::EmptyStatement(EmptyStatement { + base: self.make_base_node(s.span), + }), + swc::Stmt::Expr(s) => Statement::ExpressionStatement(ExpressionStatement { + base: self.make_base_node(s.span), + expression: Box::new(self.convert_expression(&s.expr)), + }), + swc::Stmt::ForIn(s) => 
Statement::ForInStatement(ForInStatement { + base: self.make_base_node(s.span), + left: Box::new(self.convert_for_head(&s.left)), + right: Box::new(self.convert_expression(&s.right)), + body: Box::new(self.convert_statement(&s.body)), + }), + swc::Stmt::ForOf(s) => Statement::ForOfStatement(ForOfStatement { + base: self.make_base_node(s.span), + left: Box::new(self.convert_for_head(&s.left)), + right: Box::new(self.convert_expression(&s.right)), + body: Box::new(self.convert_statement(&s.body)), + is_await: s.is_await, + }), + swc::Stmt::For(s) => Statement::ForStatement(ForStatement { + base: self.make_base_node(s.span), + init: s + .init + .as_ref() + .map(|i| Box::new(self.convert_var_decl_or_expr_to_for_init(i))), + test: s + .test + .as_ref() + .map(|t| Box::new(self.convert_expression(t))), + update: s + .update + .as_ref() + .map(|u| Box::new(self.convert_expression(u))), + body: Box::new(self.convert_statement(&s.body)), + }), + swc::Stmt::If(s) => Statement::IfStatement(IfStatement { + base: self.make_base_node(s.span), + test: Box::new(self.convert_expression(&s.test)), + consequent: Box::new(self.convert_statement(&s.cons)), + alternate: s.alt.as_ref().map(|a| Box::new(self.convert_statement(a))), + }), + swc::Stmt::Labeled(s) => Statement::LabeledStatement(LabeledStatement { + base: self.make_base_node(s.span), + label: self.convert_ident_to_identifier(&s.label), + body: Box::new(self.convert_statement(&s.body)), + }), + swc::Stmt::Return(s) => Statement::ReturnStatement(ReturnStatement { + base: self.make_base_node(s.span), + argument: s.arg.as_ref().map(|a| Box::new(self.convert_expression(a))), + }), + swc::Stmt::Switch(s) => Statement::SwitchStatement(SwitchStatement { + base: self.make_base_node(s.span), + discriminant: Box::new(self.convert_expression(&s.discriminant)), + cases: s + .cases + .iter() + .map(|c| SwitchCase { + base: self.make_base_node(c.span), + test: c + .test + .as_ref() + .map(|t| Box::new(self.convert_expression(t))), + 
consequent: c.cons.iter().map(|s| self.convert_statement(s)).collect(), + }) + .collect(), + }), + swc::Stmt::Throw(s) => Statement::ThrowStatement(ThrowStatement { + base: self.make_base_node(s.span), + argument: Box::new(self.convert_expression(&s.arg)), + }), + swc::Stmt::Try(s) => Statement::TryStatement(TryStatement { + base: self.make_base_node(s.span), + block: self.convert_block_statement(&s.block), + handler: s.handler.as_ref().map(|h| self.convert_catch_clause(h)), + finalizer: s + .finalizer + .as_ref() + .map(|f| self.convert_block_statement(f)), + }), + swc::Stmt::While(s) => Statement::WhileStatement(WhileStatement { + base: self.make_base_node(s.span), + test: Box::new(self.convert_expression(&s.test)), + body: Box::new(self.convert_statement(&s.body)), + }), + swc::Stmt::With(s) => Statement::WithStatement(WithStatement { + base: self.make_base_node(s.span), + object: Box::new(self.convert_expression(&s.obj)), + body: Box::new(self.convert_statement(&s.body)), + }), + swc::Stmt::Decl(d) => self.convert_decl_to_statement(d), + } + } + + fn convert_decl_to_statement(&self, decl: &swc::Decl) -> Statement { + match decl { + swc::Decl::Var(v) => { + Statement::VariableDeclaration(self.convert_variable_declaration(v)) + } + swc::Decl::Fn(f) => Statement::FunctionDeclaration(self.convert_fn_decl(f)), + swc::Decl::Class(c) => Statement::ClassDeclaration(self.convert_class_decl(c)), + swc::Decl::TsTypeAlias(d) => { + Statement::TSTypeAliasDeclaration(self.convert_ts_type_alias(d)) + } + swc::Decl::TsInterface(d) => { + Statement::TSInterfaceDeclaration(self.convert_ts_interface(d)) + } + swc::Decl::TsEnum(d) => Statement::TSEnumDeclaration(self.convert_ts_enum(d)), + swc::Decl::TsModule(d) => Statement::TSModuleDeclaration(self.convert_ts_module(d)), + swc::Decl::Using(u) => Statement::VariableDeclaration(self.convert_using_decl(u)), + } + } + + fn convert_block_statement(&self, block: &swc::BlockStmt) -> BlockStatement { + let mut body: Vec = Vec::new(); + 
let mut directives: Vec = Vec::new(); + let mut past_directives = false; + + for stmt in &block.stmts { + if !past_directives { + if let Some(dir) = self.try_extract_block_directive(stmt) { + directives.push(dir); + continue; + } + past_directives = true; + } + body.push(self.convert_statement(stmt)); + } + + BlockStatement { + base: self.make_base_node(block.span), + body, + directives, + } + } + + /// Try to extract a directive from a statement in a block body. + /// Directives are expression statements whose expression is a string + /// literal. + fn try_extract_block_directive(&self, stmt: &swc::Stmt) -> Option { + if let swc::Stmt::Expr(expr_stmt) = stmt { + if let swc::Expr::Lit(swc::Lit::Str(s)) = &*expr_stmt.expr { + return Some(Directive { + base: self.make_base_node(expr_stmt.span), + value: DirectiveLiteral { + base: self.make_base_node(s.span), + value: wtf8_to_string(&s.value), + }, + }); + } + } + None + } + + fn convert_catch_clause(&self, clause: &swc::CatchClause) -> CatchClause { + CatchClause { + base: self.make_base_node(clause.span), + param: clause.param.as_ref().map(|p| self.convert_pat(p)), + body: self.convert_block_statement(&clause.body), + } + } + + fn convert_var_decl_or_expr_to_for_init(&self, init: &swc::VarDeclOrExpr) -> ForInit { + match init { + swc::VarDeclOrExpr::VarDecl(v) => { + ForInit::VariableDeclaration(self.convert_variable_declaration(v)) + } + swc::VarDeclOrExpr::Expr(e) => { + ForInit::Expression(Box::new(self.convert_expression(e))) + } + } + } + + fn convert_for_head(&self, head: &swc::ForHead) -> ForInOfLeft { + match head { + swc::ForHead::VarDecl(v) => { + ForInOfLeft::VariableDeclaration(self.convert_variable_declaration(v)) + } + swc::ForHead::Pat(p) => ForInOfLeft::Pattern(Box::new(self.convert_pat(p))), + swc::ForHead::UsingDecl(u) => { + ForInOfLeft::VariableDeclaration(self.convert_using_decl(u)) + } + } + } + + fn convert_variable_declaration(&self, decl: &swc::VarDecl) -> VariableDeclaration { + 
VariableDeclaration { + base: self.make_base_node(decl.span), + declarations: decl + .decls + .iter() + .map(|d| self.convert_variable_declarator(d)) + .collect(), + kind: match decl.kind { + swc::VarDeclKind::Var => VariableDeclarationKind::Var, + swc::VarDeclKind::Let => VariableDeclarationKind::Let, + swc::VarDeclKind::Const => VariableDeclarationKind::Const, + }, + declare: if decl.declare { Some(true) } else { None }, + } + } + + fn convert_using_decl(&self, decl: &swc::UsingDecl) -> VariableDeclaration { + VariableDeclaration { + base: self.make_base_node(decl.span), + declarations: decl + .decls + .iter() + .map(|d| self.convert_variable_declarator(d)) + .collect(), + kind: VariableDeclarationKind::Using, + declare: None, + } + } + + fn convert_variable_declarator(&self, d: &swc::VarDeclarator) -> VariableDeclarator { + VariableDeclarator { + base: self.make_base_node(d.span), + id: self.convert_pat(&d.name), + init: d + .init + .as_ref() + .map(|e| Box::new(self.convert_expression(e))), + definite: if d.definite { Some(true) } else { None }, + } + } + + // ===== Expressions ===== + + fn convert_expression(&self, expr: &swc::Expr) -> Expression { + match expr { + swc::Expr::Lit(lit) => self.convert_lit(lit), + swc::Expr::Ident(id) => Expression::Identifier(self.convert_ident_to_identifier(id)), + swc::Expr::This(t) => Expression::ThisExpression(ThisExpression { + base: self.make_base_node(t.span), + }), + swc::Expr::Array(arr) => { + Expression::ArrayExpression(self.convert_array_expression(arr)) + } + swc::Expr::Object(obj) => { + Expression::ObjectExpression(self.convert_object_expression(obj)) + } + swc::Expr::Fn(f) => Expression::FunctionExpression(self.convert_fn_expr(f)), + swc::Expr::Unary(un) => Expression::UnaryExpression(UnaryExpression { + base: self.make_base_node(un.span), + operator: self.convert_unary_operator(un.op), + prefix: true, + argument: Box::new(self.convert_expression(&un.arg)), + }), + swc::Expr::Update(up) => 
Expression::UpdateExpression(UpdateExpression { + base: self.make_base_node(up.span), + operator: self.convert_update_operator(up.op), + argument: Box::new(self.convert_expression(&up.arg)), + prefix: up.prefix, + }), + swc::Expr::Bin(bin) => { + if let Some(log_op) = self.try_convert_logical_operator(bin.op) { + Expression::LogicalExpression(LogicalExpression { + base: self.make_base_node(bin.span), + operator: log_op, + left: Box::new(self.convert_expression(&bin.left)), + right: Box::new(self.convert_expression(&bin.right)), + }) + } else { + Expression::BinaryExpression(BinaryExpression { + base: self.make_base_node(bin.span), + operator: self.convert_binary_operator(bin.op), + left: Box::new(self.convert_expression(&bin.left)), + right: Box::new(self.convert_expression(&bin.right)), + }) + } + } + swc::Expr::Assign(a) => { + Expression::AssignmentExpression(self.convert_assignment_expression(a)) + } + swc::Expr::Member(m) => Expression::MemberExpression(self.convert_member_expression(m)), + swc::Expr::SuperProp(sp) => { + let (property, computed) = self.convert_super_prop(&sp.prop); + Expression::MemberExpression(MemberExpression { + base: self.make_base_node(sp.span), + object: Box::new(Expression::Super(Super { + base: self.make_base_node(sp.obj.span), + })), + property: Box::new(property), + computed, + }) + } + swc::Expr::Cond(c) => Expression::ConditionalExpression(ConditionalExpression { + base: self.make_base_node(c.span), + test: Box::new(self.convert_expression(&c.test)), + consequent: Box::new(self.convert_expression(&c.cons)), + alternate: Box::new(self.convert_expression(&c.alt)), + }), + swc::Expr::Call(call) => Expression::CallExpression(self.convert_call_expression(call)), + swc::Expr::New(n) => Expression::NewExpression(NewExpression { + base: self.make_base_node(n.span), + callee: Box::new(self.convert_expression(&n.callee)), + arguments: n.args.as_ref().map_or_else(Vec::new, |args| { + args.iter() + .map(|a| self.convert_expr_or_spread(a)) + 
.collect() + }), + type_parameters: None, + type_arguments: None, + }), + swc::Expr::Seq(seq) => Expression::SequenceExpression(SequenceExpression { + base: self.make_base_node(seq.span), + expressions: seq + .exprs + .iter() + .map(|e| self.convert_expression(e)) + .collect(), + }), + swc::Expr::Arrow(arrow) => { + Expression::ArrowFunctionExpression(self.convert_arrow_function(arrow)) + } + swc::Expr::Class(class) => { + Expression::ClassExpression(self.convert_class_expression(class)) + } + swc::Expr::Yield(y) => Expression::YieldExpression(YieldExpression { + base: self.make_base_node(y.span), + argument: y.arg.as_ref().map(|a| Box::new(self.convert_expression(a))), + delegate: y.delegate, + }), + swc::Expr::Await(a) => Expression::AwaitExpression(AwaitExpression { + base: self.make_base_node(a.span), + argument: Box::new(self.convert_expression(&a.arg)), + }), + swc::Expr::MetaProp(mp) => { + let (meta_name, prop_name) = match mp.kind { + swc::MetaPropKind::NewTarget => ("new", "target"), + swc::MetaPropKind::ImportMeta => ("import", "meta"), + }; + Expression::MetaProperty(MetaProperty { + base: self.make_base_node(mp.span), + meta: Identifier { + base: self.make_base_node(mp.span), + name: meta_name.to_string(), + type_annotation: None, + optional: None, + decorators: None, + }, + property: Identifier { + base: self.make_base_node(mp.span), + name: prop_name.to_string(), + type_annotation: None, + optional: None, + decorators: None, + }, + }) + } + swc::Expr::Tpl(tpl) => Expression::TemplateLiteral(self.convert_template_literal(tpl)), + swc::Expr::TaggedTpl(tag) => { + Expression::TaggedTemplateExpression(TaggedTemplateExpression { + base: self.make_base_node(tag.span), + tag: Box::new(self.convert_expression(&tag.tag)), + quasi: self.convert_template_literal(&tag.tpl), + type_parameters: None, + }) + } + swc::Expr::Paren(p) => Expression::ParenthesizedExpression(ParenthesizedExpression { + base: self.make_base_node(p.span), + expression: 
Box::new(self.convert_expression(&p.expr)), + }), + swc::Expr::OptChain(chain) => self.convert_opt_chain_expression(chain), + swc::Expr::PrivateName(p) => Expression::PrivateName(PrivateName { + base: self.make_base_node(p.span), + id: Identifier { + base: self.make_base_node(p.span), + name: p.name.to_string(), + type_annotation: None, + optional: None, + decorators: None, + }, + }), + swc::Expr::JSXElement(el) => { + Expression::JSXElement(Box::new(self.convert_jsx_element(el))) + } + swc::Expr::JSXFragment(frag) => { + Expression::JSXFragment(self.convert_jsx_fragment(frag)) + } + swc::Expr::JSXEmpty(e) => Expression::Identifier(Identifier { + base: self.make_base_node(e.span), + name: "undefined".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + swc::Expr::JSXMember(m) => Expression::Identifier(Identifier { + base: self.make_base_node(m.prop.span), + name: m.prop.sym.to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + swc::Expr::JSXNamespacedName(n) => Expression::Identifier(Identifier { + base: self.make_base_node(n.name.span), + name: format!("{}:{}", n.ns.sym, n.name.sym), + type_annotation: None, + optional: None, + decorators: None, + }), + swc::Expr::TsAs(e) => Expression::TSAsExpression(TSAsExpression { + base: self.make_base_node(e.span), + expression: Box::new(self.convert_expression(&e.expr)), + type_annotation: Box::new(serde_json::Value::Null), + }), + swc::Expr::TsSatisfies(e) => Expression::TSSatisfiesExpression(TSSatisfiesExpression { + base: self.make_base_node(e.span), + expression: Box::new(self.convert_expression(&e.expr)), + type_annotation: Box::new(serde_json::Value::Null), + }), + swc::Expr::TsTypeAssertion(e) => Expression::TSTypeAssertion(TSTypeAssertion { + base: self.make_base_node(e.span), + expression: Box::new(self.convert_expression(&e.expr)), + type_annotation: Box::new(serde_json::Value::Null), + }), + swc::Expr::TsNonNull(e) => 
Expression::TSNonNullExpression(TSNonNullExpression { + base: self.make_base_node(e.span), + expression: Box::new(self.convert_expression(&e.expr)), + }), + swc::Expr::TsInstantiation(e) => { + Expression::TSInstantiationExpression(TSInstantiationExpression { + base: self.make_base_node(e.span), + expression: Box::new(self.convert_expression(&e.expr)), + type_parameters: Box::new(serde_json::Value::Null), + }) + } + swc::Expr::TsConstAssertion(e) => { + // "as const" → TSAsExpression with typeAnnotation: TSTypeReference { typeName: + // Identifier { name: "const" } } This matches Babel's AST + // representation of `as const`. + let type_ann = serde_json::json!({ + "type": "TSTypeReference", + "typeName": { + "type": "Identifier", + "name": "const" + } + }); + Expression::TSAsExpression(TSAsExpression { + base: self.make_base_node(e.span), + expression: Box::new(self.convert_expression(&e.expr)), + type_annotation: Box::new(type_ann), + }) + } + swc::Expr::Invalid(i) => Expression::Identifier(Identifier { + base: self.make_base_node(i.span), + name: "__invalid__".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + } + } + + fn convert_lit(&self, lit: &swc::Lit) -> Expression { + match lit { + swc::Lit::Str(s) => Expression::StringLiteral(StringLiteral { + base: self.make_base_node(s.span), + value: wtf8_to_string(&s.value), + }), + swc::Lit::Bool(b) => Expression::BooleanLiteral(BooleanLiteral { + base: self.make_base_node(b.span), + value: b.value, + }), + swc::Lit::Null(n) => Expression::NullLiteral(NullLiteral { + base: self.make_base_node(n.span), + }), + swc::Lit::Num(n) => Expression::NumericLiteral(NumericLiteral { + base: self.make_base_node(n.span), + value: n.value, + }), + swc::Lit::BigInt(b) => Expression::BigIntLiteral(BigIntLiteral { + base: self.make_base_node(b.span), + value: b.value.to_string(), + }), + swc::Lit::Regex(r) => Expression::RegExpLiteral(RegExpLiteral { + base: self.make_base_node(r.span), + pattern: 
r.exp.to_string(), + flags: r.flags.to_string(), + }), + swc::Lit::JSXText(t) => Expression::StringLiteral(StringLiteral { + base: self.make_base_node(t.span), + value: t.value.to_string(), + }), + } + } + + // ===== Optional chaining ===== + + fn convert_opt_chain_expression(&self, chain: &swc::OptChainExpr) -> Expression { + match &*chain.base { + swc::OptChainBase::Member(m) => { + let (property, computed) = self.convert_member_prop(&m.prop); + Expression::OptionalMemberExpression(OptionalMemberExpression { + base: self.make_base_node(chain.span), + object: Box::new(self.convert_opt_chain_callee(&m.obj)), + property: Box::new(property), + computed, + optional: chain.optional, + }) + } + swc::OptChainBase::Call(call) => { + Expression::OptionalCallExpression(OptionalCallExpression { + base: self.make_base_node(chain.span), + callee: Box::new(self.convert_opt_chain_callee(&call.callee)), + arguments: call + .args + .iter() + .map(|a| self.convert_expr_or_spread(a)) + .collect(), + optional: chain.optional, + type_parameters: None, + type_arguments: None, + }) + } + } + } + + fn convert_opt_chain_callee(&self, expr: &swc::Expr) -> Expression { + if let swc::Expr::OptChain(chain) = expr { + return self.convert_opt_chain_expression(chain); + } + self.convert_expression(expr) + } + + // ===== Member expression ===== + + fn convert_member_expression(&self, m: &swc::MemberExpr) -> MemberExpression { + let (property, computed) = self.convert_member_prop(&m.prop); + MemberExpression { + base: self.make_base_node(m.span), + object: Box::new(self.convert_expression(&m.obj)), + property: Box::new(property), + computed, + } + } + + fn convert_member_prop(&self, prop: &swc::MemberProp) -> (Expression, bool) { + match prop { + swc::MemberProp::Ident(id) => ( + Expression::Identifier(Identifier { + base: self.make_base_node(id.span), + name: id.sym.to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + false, + ), + swc::MemberProp::Computed(c) => 
(self.convert_expression(&c.expr), true), + swc::MemberProp::PrivateName(p) => ( + Expression::PrivateName(PrivateName { + base: self.make_base_node(p.span), + id: Identifier { + base: self.make_base_node(p.span), + name: p.name.to_string(), + type_annotation: None, + optional: None, + decorators: None, + }, + }), + false, + ), + } + } + + fn convert_super_prop(&self, prop: &swc::SuperProp) -> (Expression, bool) { + match prop { + swc::SuperProp::Ident(id) => ( + Expression::Identifier(Identifier { + base: self.make_base_node(id.span), + name: id.sym.to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + false, + ), + swc::SuperProp::Computed(c) => (self.convert_expression(&c.expr), true), + } + } + + // ===== Call expression ===== + + fn convert_call_expression(&self, call: &swc::CallExpr) -> CallExpression { + CallExpression { + base: self.make_base_node(call.span), + callee: Box::new(self.convert_callee(&call.callee)), + arguments: call + .args + .iter() + .map(|a| self.convert_expr_or_spread(a)) + .collect(), + type_parameters: None, + type_arguments: None, + optional: None, + } + } + + fn convert_callee(&self, callee: &swc::Callee) -> Expression { + match callee { + swc::Callee::Expr(e) => self.convert_expression(e), + swc::Callee::Super(s) => Expression::Super(Super { + base: self.make_base_node(s.span), + }), + swc::Callee::Import(i) => Expression::Import(Import { + base: self.make_base_node(i.span), + }), + } + } + + fn convert_expr_or_spread(&self, arg: &swc::ExprOrSpread) -> Expression { + if let Some(spread_span) = arg.spread { + Expression::SpreadElement(SpreadElement { + base: self.make_base_node(Span::new(spread_span.lo, arg.expr.span().hi)), + argument: Box::new(self.convert_expression(&arg.expr)), + }) + } else { + self.convert_expression(&arg.expr) + } + } + + // ===== Function helpers ===== + + fn convert_fn_decl(&self, func: &swc::FnDecl) -> FunctionDeclaration { + let f = &func.function; + let body = f + .body + 
.as_ref() + .map(|b| self.convert_block_statement(b)) + .unwrap_or_else(|| BlockStatement { + base: self.make_base_node(f.span), + body: vec![], + directives: vec![], + }); + FunctionDeclaration { + base: self.make_base_node(f.span), + id: Some(self.convert_ident_to_identifier(&func.ident)), + params: self.convert_params(&f.params), + body, + generator: f.is_generator, + is_async: f.is_async, + declare: if func.declare { Some(true) } else { None }, + return_type: f + .return_type + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + type_parameters: f + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + predicate: None, + component_declaration: false, + hook_declaration: false, + } + } + + fn convert_fn_expr(&self, func: &swc::FnExpr) -> FunctionExpression { + let f = &func.function; + let body = f + .body + .as_ref() + .map(|b| self.convert_block_statement(b)) + .unwrap_or_else(|| BlockStatement { + base: self.make_base_node(f.span), + body: vec![], + directives: vec![], + }); + FunctionExpression { + base: self.make_base_node(f.span), + id: func + .ident + .as_ref() + .map(|id| self.convert_ident_to_identifier(id)), + params: self.convert_params(&f.params), + body, + generator: f.is_generator, + is_async: f.is_async, + return_type: f + .return_type + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + type_parameters: f + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + } + } + + fn convert_arrow_function(&self, arrow: &swc::ArrowExpr) -> ArrowFunctionExpression { + let is_expression = matches!(&*arrow.body, swc::BlockStmtOrExpr::Expr(_)); + let body = match &*arrow.body { + swc::BlockStmtOrExpr::BlockStmt(block) => { + ArrowFunctionBody::BlockStatement(self.convert_block_statement(block)) + } + swc::BlockStmtOrExpr::Expr(expr) => { + ArrowFunctionBody::Expression(Box::new(self.convert_expression(expr))) + } + }; + ArrowFunctionExpression { + base: self.make_base_node(arrow.span), + params: 
arrow.params.iter().map(|p| self.convert_pat(p)).collect(), + body: Box::new(body), + id: None, + generator: arrow.is_generator, + is_async: arrow.is_async, + expression: Some(is_expression), + return_type: arrow + .return_type + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + type_parameters: arrow + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + predicate: None, + } + } + + fn convert_params(&self, params: &[swc::Param]) -> Vec { + params.iter().map(|p| self.convert_pat(&p.pat)).collect() + } + + // ===== Patterns ===== + + fn convert_pat(&self, pat: &swc::Pat) -> PatternLike { + match pat { + swc::Pat::Ident(id) => PatternLike::Identifier(self.convert_binding_ident(id)), + swc::Pat::Array(arr) => PatternLike::ArrayPattern(self.convert_array_pattern(arr)), + swc::Pat::Object(obj) => PatternLike::ObjectPattern(self.convert_object_pattern(obj)), + swc::Pat::Assign(a) => PatternLike::AssignmentPattern(AssignmentPattern { + base: self.make_base_node(a.span), + left: Box::new(self.convert_pat(&a.left)), + right: Box::new(self.convert_expression(&a.right)), + type_annotation: None, + decorators: None, + }), + swc::Pat::Rest(r) => PatternLike::RestElement(RestElement { + base: self.make_base_node(r.span), + argument: Box::new(self.convert_pat(&r.arg)), + type_annotation: None, + decorators: None, + }), + swc::Pat::Expr(e) => self.convert_expression_to_pattern(e), + swc::Pat::Invalid(i) => PatternLike::Identifier(Identifier { + base: self.make_base_node(i.span), + name: "__invalid__".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + } + } + + fn convert_expression_to_pattern(&self, expr: &swc::Expr) -> PatternLike { + match expr { + swc::Expr::Ident(id) => PatternLike::Identifier(self.convert_ident_to_identifier(id)), + swc::Expr::Member(m) => { + PatternLike::MemberExpression(self.convert_member_expression(m)) + } + _ => PatternLike::Identifier(Identifier { + base: self.make_base_node(expr.span()), 
+ name: "__unknown_target__".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + } + } + + fn convert_object_pattern(&self, obj: &swc::ObjectPat) -> ObjectPattern { + let properties = obj + .props + .iter() + .map(|p| match p { + swc::ObjectPatProp::KeyValue(kv) => { + ObjectPatternProperty::ObjectProperty(ObjectPatternProp { + base: self.make_base_node(kv.span()), + key: Box::new(self.convert_prop_name(&kv.key)), + value: Box::new(self.convert_pat(&kv.value)), + computed: matches!(kv.key, swc::PropName::Computed(_)), + shorthand: false, + decorators: None, + method: None, + }) + } + swc::ObjectPatProp::Assign(a) => { + let id = self.convert_ident_to_identifier(&a.key.id); + let (value, shorthand) = if let Some(ref init) = a.value { + ( + Box::new(PatternLike::AssignmentPattern(AssignmentPattern { + base: self.make_base_node(a.span), + left: Box::new(PatternLike::Identifier(id.clone())), + right: Box::new(self.convert_expression(init)), + type_annotation: None, + decorators: None, + })), + true, + ) + } else { + (Box::new(PatternLike::Identifier(id.clone())), true) + }; + ObjectPatternProperty::ObjectProperty(ObjectPatternProp { + base: self.make_base_node(a.span), + key: Box::new(Expression::Identifier(id)), + value, + computed: false, + shorthand, + decorators: None, + method: None, + }) + } + swc::ObjectPatProp::Rest(r) => ObjectPatternProperty::RestElement(RestElement { + base: self.make_base_node(r.span), + argument: Box::new(self.convert_pat(&r.arg)), + type_annotation: None, + decorators: None, + }), + }) + .collect(); + ObjectPattern { + base: self.make_base_node(obj.span), + properties, + type_annotation: obj + .type_ann + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + decorators: None, + } + } + + fn convert_array_pattern(&self, arr: &swc::ArrayPat) -> ArrayPattern { + ArrayPattern { + base: self.make_base_node(arr.span), + elements: arr + .elems + .iter() + .map(|e| e.as_ref().map(|p| self.convert_pat(p))) + 
.collect(), + type_annotation: arr + .type_ann + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + decorators: None, + } + } + + // ===== AssignmentTarget ===== + + fn convert_assign_target(&self, target: &swc::AssignTarget) -> PatternLike { + match target { + swc::AssignTarget::Simple(s) => self.convert_simple_assign_target(s), + swc::AssignTarget::Pat(p) => self.convert_assign_target_pat(p), + } + } + + fn convert_simple_assign_target(&self, target: &swc::SimpleAssignTarget) -> PatternLike { + match target { + swc::SimpleAssignTarget::Ident(id) => { + PatternLike::Identifier(self.convert_binding_ident(id)) + } + swc::SimpleAssignTarget::Member(m) => { + PatternLike::MemberExpression(self.convert_member_expression(m)) + } + swc::SimpleAssignTarget::SuperProp(sp) => { + let (property, computed) = self.convert_super_prop(&sp.prop); + PatternLike::MemberExpression(MemberExpression { + base: self.make_base_node(sp.span), + object: Box::new(Expression::Super(Super { + base: self.make_base_node(sp.obj.span), + })), + property: Box::new(property), + computed, + }) + } + swc::SimpleAssignTarget::Paren(p) => self.convert_expression_to_pattern(&p.expr), + swc::SimpleAssignTarget::OptChain(o) => PatternLike::Identifier(Identifier { + base: self.make_base_node(o.span), + name: "__unknown_target__".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + swc::SimpleAssignTarget::TsAs(e) => self.convert_expression_to_pattern(&e.expr), + swc::SimpleAssignTarget::TsSatisfies(e) => self.convert_expression_to_pattern(&e.expr), + swc::SimpleAssignTarget::TsNonNull(e) => self.convert_expression_to_pattern(&e.expr), + swc::SimpleAssignTarget::TsTypeAssertion(e) => { + self.convert_expression_to_pattern(&e.expr) + } + swc::SimpleAssignTarget::TsInstantiation(e) => { + self.convert_expression_to_pattern(&e.expr) + } + swc::SimpleAssignTarget::Invalid(i) => PatternLike::Identifier(Identifier { + base: self.make_base_node(i.span), + name: 
"__invalid__".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + } + } + + fn convert_assign_target_pat(&self, target: &swc::AssignTargetPat) -> PatternLike { + match target { + swc::AssignTargetPat::Array(a) => { + PatternLike::ArrayPattern(self.convert_array_pattern(a)) + } + swc::AssignTargetPat::Object(o) => { + PatternLike::ObjectPattern(self.convert_object_pattern(o)) + } + swc::AssignTargetPat::Invalid(i) => PatternLike::Identifier(Identifier { + base: self.make_base_node(i.span), + name: "__invalid__".to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + } + } + + fn convert_assignment_expression(&self, assign: &swc::AssignExpr) -> AssignmentExpression { + AssignmentExpression { + base: self.make_base_node(assign.span), + operator: self.convert_assignment_operator(assign.op), + left: Box::new(self.convert_assign_target(&assign.left)), + right: Box::new(self.convert_expression(&assign.right)), + } + } + + // ===== Object expression ===== + + fn convert_object_expression(&self, obj: &swc::ObjectLit) -> ObjectExpression { + ObjectExpression { + base: self.make_base_node(obj.span), + properties: obj + .props + .iter() + .map(|p| self.convert_prop_or_spread(p)) + .collect(), + } + } + + fn convert_prop_or_spread(&self, prop: &swc::PropOrSpread) -> ObjectExpressionProperty { + match prop { + swc::PropOrSpread::Spread(s) => { + ObjectExpressionProperty::SpreadElement(SpreadElement { + base: self.make_base_node(s.span()), + argument: Box::new(self.convert_expression(&s.expr)), + }) + } + swc::PropOrSpread::Prop(p) => self.convert_prop(p), + } + } + + fn convert_prop(&self, prop: &swc::Prop) -> ObjectExpressionProperty { + match prop { + swc::Prop::Shorthand(id) => { + let ident = self.convert_ident_to_identifier(id); + ObjectExpressionProperty::ObjectProperty(ObjectProperty { + base: self.make_base_node(id.span), + key: Box::new(Expression::Identifier(ident.clone())), + value: 
Box::new(Expression::Identifier(ident)), + computed: false, + shorthand: true, + decorators: None, + method: Some(false), + }) + } + swc::Prop::KeyValue(kv) => ObjectExpressionProperty::ObjectProperty(ObjectProperty { + base: self.make_base_node(kv.span()), + key: Box::new(self.convert_prop_name(&kv.key)), + value: Box::new(self.convert_expression(&kv.value)), + computed: matches!(kv.key, swc::PropName::Computed(_)), + shorthand: false, + decorators: None, + method: Some(false), + }), + swc::Prop::Getter(g) => ObjectExpressionProperty::ObjectMethod(ObjectMethod { + base: self.make_base_node(g.span), + method: false, + kind: ObjectMethodKind::Get, + key: Box::new(self.convert_prop_name(&g.key)), + params: vec![], + body: g + .body + .as_ref() + .map(|b| self.convert_block_statement(b)) + .unwrap_or_else(|| BlockStatement { + base: self.make_base_node(g.span), + body: vec![], + directives: vec![], + }), + computed: matches!(g.key, swc::PropName::Computed(_)), + id: None, + generator: false, + is_async: false, + decorators: None, + return_type: g + .type_ann + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + type_parameters: None, + }), + swc::Prop::Setter(s) => ObjectExpressionProperty::ObjectMethod(ObjectMethod { + base: self.make_base_node(s.span), + method: false, + kind: ObjectMethodKind::Set, + key: Box::new(self.convert_prop_name(&s.key)), + params: vec![self.convert_pat(&s.param)], + body: s + .body + .as_ref() + .map(|b| self.convert_block_statement(b)) + .unwrap_or_else(|| BlockStatement { + base: self.make_base_node(s.span), + body: vec![], + directives: vec![], + }), + computed: matches!(s.key, swc::PropName::Computed(_)), + id: None, + generator: false, + is_async: false, + decorators: None, + return_type: None, + type_parameters: None, + }), + swc::Prop::Method(m) => ObjectExpressionProperty::ObjectMethod(ObjectMethod { + base: self.make_base_node(m.span()), + method: true, + kind: ObjectMethodKind::Method, + key: 
Box::new(self.convert_prop_name(&m.key)), + params: self.convert_params(&m.function.params), + body: m + .function + .body + .as_ref() + .map(|b| self.convert_block_statement(b)) + .unwrap_or_else(|| BlockStatement { + base: self.make_base_node(m.function.span), + body: vec![], + directives: vec![], + }), + computed: matches!(m.key, swc::PropName::Computed(_)), + id: None, + generator: m.function.is_generator, + is_async: m.function.is_async, + decorators: None, + return_type: m + .function + .return_type + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + type_parameters: m + .function + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + }), + swc::Prop::Assign(a) => { + let ident = self.convert_ident_to_identifier(&a.key); + ObjectExpressionProperty::ObjectProperty(ObjectProperty { + base: self.make_base_node(a.span), + key: Box::new(Expression::Identifier(ident.clone())), + value: Box::new(Expression::AssignmentExpression(AssignmentExpression { + base: self.make_base_node(a.span), + operator: AssignmentOperator::Assign, + left: Box::new(PatternLike::Identifier(ident)), + right: Box::new(self.convert_expression(&a.value)), + })), + computed: false, + shorthand: true, + decorators: None, + method: Some(false), + }) + } + } + } + + fn convert_array_expression(&self, arr: &swc::ArrayLit) -> ArrayExpression { + ArrayExpression { + base: self.make_base_node(arr.span), + elements: arr + .elems + .iter() + .map(|e| e.as_ref().map(|elem| self.convert_expr_or_spread(elem))) + .collect(), + } + } + + fn convert_template_literal(&self, tpl: &swc::Tpl) -> TemplateLiteral { + TemplateLiteral { + base: self.make_base_node(tpl.span), + quasis: tpl + .quasis + .iter() + .map(|q| TemplateElement { + base: self.make_base_node(q.span), + value: TemplateElementValue { + raw: q.raw.to_string(), + cooked: q.cooked.as_ref().map(wtf8_to_string), + }, + tail: q.tail, + }) + .collect(), + expressions: tpl + .exprs + .iter() + .map(|e| 
self.convert_expression(e)) + .collect(), + } + } + + // ===== Class ===== + + fn convert_class_decl(&self, class: &swc::ClassDecl) -> ClassDeclaration { + let c = &class.class; + ClassDeclaration { + base: self.make_base_node(c.span), + id: Some(self.convert_ident_to_identifier(&class.ident)), + super_class: c + .super_class + .as_ref() + .map(|s| Box::new(self.convert_expression(s))), + body: ClassBody { + base: self.make_base_node(c.span), + body: vec![], + }, + decorators: None, + is_abstract: if c.is_abstract { Some(true) } else { None }, + declare: if class.declare { Some(true) } else { None }, + implements: None, + super_type_parameters: None, + type_parameters: c + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + mixins: None, + } + } + + fn convert_class_expression(&self, class: &swc::ClassExpr) -> ClassExpression { + let c = &class.class; + ClassExpression { + base: self.make_base_node(c.span), + id: class + .ident + .as_ref() + .map(|id| self.convert_ident_to_identifier(id)), + super_class: c + .super_class + .as_ref() + .map(|s| Box::new(self.convert_expression(s))), + body: ClassBody { + base: self.make_base_node(c.span), + body: vec![], + }, + decorators: None, + implements: None, + super_type_parameters: None, + type_parameters: c + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + } + } + + // ===== JSX ===== + + fn convert_jsx_element(&self, el: &swc::JSXElement) -> JSXElement { + let self_closing = el.closing.is_none(); + JSXElement { + base: self.make_base_node(el.span), + opening_element: self.convert_jsx_opening_element(&el.opening, self_closing), + closing_element: el + .closing + .as_ref() + .map(|c| self.convert_jsx_closing_element(c)), + children: el + .children + .iter() + .map(|c| self.convert_jsx_child(c)) + .collect(), + self_closing: Some(self_closing), + } + } + + fn convert_jsx_opening_element( + &self, + el: &swc::JSXOpeningElement, + self_closing: bool, + ) -> JSXOpeningElement { + 
JSXOpeningElement { + base: self.make_base_node(el.span), + name: self.convert_jsx_element_name(&el.name), + attributes: el + .attrs + .iter() + .map(|a| self.convert_jsx_attr_or_spread(a)) + .collect(), + self_closing, + type_parameters: el + .type_args + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + } + } + + fn convert_jsx_closing_element(&self, el: &swc::JSXClosingElement) -> JSXClosingElement { + JSXClosingElement { + base: self.make_base_node(el.span), + name: self.convert_jsx_element_name(&el.name), + } + } + + fn convert_jsx_element_name(&self, name: &swc::JSXElementName) -> JSXElementName { + match name { + swc::JSXElementName::Ident(id) => JSXElementName::JSXIdentifier(JSXIdentifier { + base: self.make_base_node(id.span), + name: id.sym.to_string(), + }), + swc::JSXElementName::JSXMemberExpr(m) => { + JSXElementName::JSXMemberExpression(self.convert_jsx_member_expression(m)) + } + swc::JSXElementName::JSXNamespacedName(ns) => { + JSXElementName::JSXNamespacedName(JSXNamespacedName { + base: self.make_base_node(ns.span()), + namespace: JSXIdentifier { + base: self.make_base_node(ns.ns.span), + name: ns.ns.sym.to_string(), + }, + name: JSXIdentifier { + base: self.make_base_node(ns.name.span), + name: ns.name.sym.to_string(), + }, + }) + } + } + } + + fn convert_jsx_member_expression(&self, m: &swc::JSXMemberExpr) -> JSXMemberExpression { + JSXMemberExpression { + base: self.make_base_node(m.span()), + object: Box::new(self.convert_jsx_object(&m.obj)), + property: JSXIdentifier { + base: self.make_base_node(m.prop.span), + name: m.prop.sym.to_string(), + }, + } + } + + fn convert_jsx_object(&self, obj: &swc::JSXObject) -> JSXMemberExprObject { + match obj { + swc::JSXObject::Ident(id) => JSXMemberExprObject::JSXIdentifier(JSXIdentifier { + base: self.make_base_node(id.span), + name: id.sym.to_string(), + }), + swc::JSXObject::JSXMemberExpr(m) => JSXMemberExprObject::JSXMemberExpression(Box::new( + self.convert_jsx_member_expression(m), + )), + 
} + } + + fn convert_jsx_attr_or_spread(&self, attr: &swc::JSXAttrOrSpread) -> JSXAttributeItem { + match attr { + swc::JSXAttrOrSpread::JSXAttr(a) => { + JSXAttributeItem::JSXAttribute(self.convert_jsx_attribute(a)) + } + swc::JSXAttrOrSpread::SpreadElement(s) => { + JSXAttributeItem::JSXSpreadAttribute(JSXSpreadAttribute { + base: self.make_base_node(s.span()), + argument: Box::new(self.convert_expression(&s.expr)), + }) + } + } + } + + fn convert_jsx_attribute(&self, attr: &swc::JSXAttr) -> JSXAttribute { + JSXAttribute { + base: self.make_base_node(attr.span), + name: self.convert_jsx_attr_name(&attr.name), + value: attr.value.as_ref().map(|v| self.convert_jsx_attr_value(v)), + } + } + + fn convert_jsx_attr_name(&self, name: &swc::JSXAttrName) -> JSXAttributeName { + match name { + swc::JSXAttrName::Ident(id) => JSXAttributeName::JSXIdentifier(JSXIdentifier { + base: self.make_base_node(id.span), + name: id.sym.to_string(), + }), + swc::JSXAttrName::JSXNamespacedName(ns) => { + JSXAttributeName::JSXNamespacedName(JSXNamespacedName { + base: self.make_base_node(ns.span()), + namespace: JSXIdentifier { + base: self.make_base_node(ns.ns.span), + name: ns.ns.sym.to_string(), + }, + name: JSXIdentifier { + base: self.make_base_node(ns.name.span), + name: ns.name.sym.to_string(), + }, + }) + } + } + } + + fn convert_jsx_attr_value(&self, value: &swc::JSXAttrValue) -> JSXAttributeValue { + match value { + swc::JSXAttrValue::Str(s) => JSXAttributeValue::StringLiteral(StringLiteral { + base: self.make_base_node(s.span), + value: wtf8_to_string(&s.value), + }), + swc::JSXAttrValue::JSXExprContainer(ec) => { + JSXAttributeValue::JSXExpressionContainer(self.convert_jsx_expr_container(ec)) + } + swc::JSXAttrValue::JSXElement(el) => { + JSXAttributeValue::JSXElement(Box::new(self.convert_jsx_element(el))) + } + swc::JSXAttrValue::JSXFragment(frag) => { + JSXAttributeValue::JSXFragment(self.convert_jsx_fragment(frag)) + } + } + } + + fn convert_jsx_expr_container(&self, ec: 
&swc::JSXExprContainer) -> JSXExpressionContainer { + JSXExpressionContainer { + base: self.make_base_node(ec.span), + expression: match &ec.expr { + swc::JSXExpr::JSXEmptyExpr(e) => { + JSXExpressionContainerExpr::JSXEmptyExpression(JSXEmptyExpression { + base: self.make_base_node(e.span), + }) + } + swc::JSXExpr::Expr(e) => { + JSXExpressionContainerExpr::Expression(Box::new(self.convert_expression(e))) + } + }, + } + } + + fn convert_jsx_child(&self, child: &swc::JSXElementChild) -> JSXChild { + match child { + swc::JSXElementChild::JSXText(t) => JSXChild::JSXText(JSXText { + base: self.make_base_node(t.span), + value: t.value.to_string(), + }), + swc::JSXElementChild::JSXExprContainer(ec) => { + JSXChild::JSXExpressionContainer(self.convert_jsx_expr_container(ec)) + } + swc::JSXElementChild::JSXSpreadChild(s) => JSXChild::JSXSpreadChild(JSXSpreadChild { + base: self.make_base_node(s.span), + expression: Box::new(self.convert_expression(&s.expr)), + }), + swc::JSXElementChild::JSXElement(el) => { + JSXChild::JSXElement(Box::new(self.convert_jsx_element(el))) + } + swc::JSXElementChild::JSXFragment(frag) => { + JSXChild::JSXFragment(self.convert_jsx_fragment(frag)) + } + } + } + + fn convert_jsx_fragment(&self, frag: &swc::JSXFragment) -> JSXFragment { + JSXFragment { + base: self.make_base_node(frag.span), + opening_fragment: JSXOpeningFragment { + base: self.make_base_node(frag.opening.span), + }, + closing_fragment: JSXClosingFragment { + base: self.make_base_node(frag.closing.span), + }, + children: frag + .children + .iter() + .map(|c| self.convert_jsx_child(c)) + .collect(), + } + } + + // ===== Import/Export ===== + + fn convert_import_declaration(&self, decl: &swc::ImportDecl) -> ImportDeclaration { + ImportDeclaration { + base: self.make_base_node(decl.span), + specifiers: decl + .specifiers + .iter() + .map(|s| self.convert_import_specifier(s)) + .collect(), + source: StringLiteral { + base: self.make_base_node(decl.src.span), + value: 
wtf8_to_string(&decl.src.value), + }, + import_kind: if decl.type_only { + Some(ImportKind::Type) + } else { + Some(ImportKind::Value) + }, + assertions: None, + attributes: decl + .with + .as_ref() + .map(|with| self.convert_object_lit_to_import_attributes(with)), + } + } + + fn convert_object_lit_to_import_attributes( + &self, + obj: &swc::ObjectLit, + ) -> Vec { + obj.props + .iter() + .filter_map(|prop| { + if let swc::PropOrSpread::Prop(p) = prop { + if let swc::Prop::KeyValue(kv) = &**p { + let (key_name, key_span) = match &kv.key { + swc::PropName::Ident(id) => (id.sym.to_string(), id.span), + swc::PropName::Str(s) => (wtf8_to_string(&s.value), s.span), + swc::PropName::Num(n) => (n.value.to_string(), n.span), + _ => return None, + }; + if let swc::Expr::Lit(swc::Lit::Str(s)) = &*kv.value { + return Some(ImportAttribute { + base: self.make_base_node(kv.span()), + key: Identifier { + base: self.make_base_node(key_span), + name: key_name, + type_annotation: None, + optional: None, + decorators: None, + }, + value: StringLiteral { + base: self.make_base_node(s.span), + value: wtf8_to_string(&s.value), + }, + }); + } + } + } + None + }) + .collect() + } + + fn convert_import_specifier(&self, spec: &swc::ImportSpecifier) -> ImportSpecifier { + match spec { + swc::ImportSpecifier::Named(s) => { + let local = self.convert_ident_to_identifier(&s.local); + let imported = s + .imported + .as_ref() + .map(|i| match i { + swc::ModuleExportName::Ident(id) => { + ModuleExportName::Identifier(self.convert_ident_to_identifier(id)) + } + swc::ModuleExportName::Str(s) => { + ModuleExportName::StringLiteral(StringLiteral { + base: self.make_base_node(s.span), + value: wtf8_to_string(&s.value), + }) + } + }) + .unwrap_or_else(|| ModuleExportName::Identifier(local.clone())); + ImportSpecifier::ImportSpecifier(ImportSpecifierData { + base: self.make_base_node(s.span), + local, + imported, + import_kind: if s.is_type_only { + Some(ImportKind::Type) + } else { + 
Some(ImportKind::Value) + }, + }) + } + swc::ImportSpecifier::Default(s) => { + ImportSpecifier::ImportDefaultSpecifier(ImportDefaultSpecifierData { + base: self.make_base_node(s.span), + local: self.convert_ident_to_identifier(&s.local), + }) + } + swc::ImportSpecifier::Namespace(s) => { + ImportSpecifier::ImportNamespaceSpecifier(ImportNamespaceSpecifierData { + base: self.make_base_node(s.span), + local: self.convert_ident_to_identifier(&s.local), + }) + } + } + } + + fn convert_export_decl(&self, decl: &swc::ExportDecl) -> ExportNamedDeclaration { + ExportNamedDeclaration { + base: self.make_base_node(decl.span), + declaration: Some(Box::new(self.convert_decl_to_declaration(&decl.decl))), + specifiers: vec![], + source: None, + export_kind: Some(ExportKind::Value), + assertions: None, + attributes: None, + } + } + + fn convert_export_named(&self, decl: &swc::NamedExport) -> ExportNamedDeclaration { + ExportNamedDeclaration { + base: self.make_base_node(decl.span), + declaration: None, + specifiers: decl + .specifiers + .iter() + .map(|s| self.convert_export_specifier(s)) + .collect(), + source: decl.src.as_ref().map(|s| StringLiteral { + base: self.make_base_node(s.span), + value: wtf8_to_string(&s.value), + }), + export_kind: if decl.type_only { + Some(ExportKind::Type) + } else { + Some(ExportKind::Value) + }, + assertions: None, + attributes: decl + .with + .as_ref() + .map(|with| self.convert_object_lit_to_import_attributes(with)), + } + } + + fn convert_export_default_decl( + &self, + decl: &swc::ExportDefaultDecl, + ) -> ExportDefaultDeclaration { + let declaration = match &decl.decl { + swc::DefaultDecl::Fn(f) => { + let func = &f.function; + let body = func + .body + .as_ref() + .map(|b| self.convert_block_statement(b)) + .unwrap_or_else(|| BlockStatement { + base: self.make_base_node(func.span), + body: vec![], + directives: vec![], + }); + ExportDefaultDecl::FunctionDeclaration(FunctionDeclaration { + base: self.make_base_node(func.span), + id: f + 
.ident + .as_ref() + .map(|id| self.convert_ident_to_identifier(id)), + params: self.convert_params(&func.params), + body, + generator: func.is_generator, + is_async: func.is_async, + declare: None, + return_type: func + .return_type + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + type_parameters: func + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + predicate: None, + component_declaration: false, + hook_declaration: false, + }) + } + swc::DefaultDecl::Class(c) => { + let class = &c.class; + ExportDefaultDecl::ClassDeclaration(ClassDeclaration { + base: self.make_base_node(class.span), + id: c + .ident + .as_ref() + .map(|id| self.convert_ident_to_identifier(id)), + super_class: class + .super_class + .as_ref() + .map(|s| Box::new(self.convert_expression(s))), + body: ClassBody { + base: self.make_base_node(class.span), + body: vec![], + }, + decorators: None, + is_abstract: if class.is_abstract { Some(true) } else { None }, + declare: None, + implements: None, + super_type_parameters: None, + type_parameters: class + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + mixins: None, + }) + } + swc::DefaultDecl::TsInterfaceDecl(_) => { + ExportDefaultDecl::Expression(Box::new(Expression::NullLiteral(NullLiteral { + base: self.make_base_node(decl.span), + }))) + } + }; + ExportDefaultDeclaration { + base: self.make_base_node(decl.span), + declaration: Box::new(declaration), + export_kind: None, + } + } + + fn convert_export_default_expr( + &self, + decl: &swc::ExportDefaultExpr, + ) -> ExportDefaultDeclaration { + ExportDefaultDeclaration { + base: self.make_base_node(decl.span), + declaration: Box::new(ExportDefaultDecl::Expression(Box::new( + self.convert_expression(&decl.expr), + ))), + export_kind: None, + } + } + + fn convert_export_all(&self, decl: &swc::ExportAll) -> ExportAllDeclaration { + ExportAllDeclaration { + base: self.make_base_node(decl.span), + source: StringLiteral { + base: 
self.make_base_node(decl.src.span), + value: wtf8_to_string(&decl.src.value), + }, + export_kind: if decl.type_only { + Some(ExportKind::Type) + } else { + Some(ExportKind::Value) + }, + assertions: None, + attributes: decl + .with + .as_ref() + .map(|with| self.convert_object_lit_to_import_attributes(with)), + } + } + + fn convert_decl_to_declaration(&self, decl: &swc::Decl) -> Declaration { + match decl { + swc::Decl::Var(v) => { + Declaration::VariableDeclaration(self.convert_variable_declaration(v)) + } + swc::Decl::Fn(f) => Declaration::FunctionDeclaration(self.convert_fn_decl(f)), + swc::Decl::Class(c) => Declaration::ClassDeclaration(self.convert_class_decl(c)), + swc::Decl::TsTypeAlias(d) => { + Declaration::TSTypeAliasDeclaration(self.convert_ts_type_alias(d)) + } + swc::Decl::TsInterface(d) => { + Declaration::TSInterfaceDeclaration(self.convert_ts_interface(d)) + } + swc::Decl::TsEnum(d) => Declaration::TSEnumDeclaration(self.convert_ts_enum(d)), + swc::Decl::TsModule(d) => Declaration::TSModuleDeclaration(self.convert_ts_module(d)), + swc::Decl::Using(u) => Declaration::VariableDeclaration(self.convert_using_decl(u)), + } + } + + fn convert_export_specifier(&self, spec: &swc::ExportSpecifier) -> ExportSpecifier { + match spec { + swc::ExportSpecifier::Named(s) => { + let local = self.convert_module_export_name(&s.orig); + let exported = s + .exported + .as_ref() + .map(|e| self.convert_module_export_name(e)) + .unwrap_or_else(|| local.clone()); + ExportSpecifier::ExportSpecifier(ExportSpecifierData { + base: self.make_base_node(s.span), + local, + exported, + export_kind: if s.is_type_only { + Some(ExportKind::Type) + } else { + Some(ExportKind::Value) + }, + }) + } + swc::ExportSpecifier::Default(s) => { + ExportSpecifier::ExportDefaultSpecifier(ExportDefaultSpecifierData { + base: self.make_base_node(s.exported.span), + exported: self.convert_ident_to_identifier(&s.exported), + }) + } + swc::ExportSpecifier::Namespace(s) => { + 
ExportSpecifier::ExportNamespaceSpecifier(ExportNamespaceSpecifierData { + base: self.make_base_node(s.span), + exported: self.convert_module_export_name(&s.name), + }) + } + } + } + + fn convert_module_export_name(&self, name: &swc::ModuleExportName) -> ModuleExportName { + match name { + swc::ModuleExportName::Ident(id) => { + ModuleExportName::Identifier(self.convert_ident_to_identifier(id)) + } + swc::ModuleExportName::Str(s) => ModuleExportName::StringLiteral(StringLiteral { + base: self.make_base_node(s.span), + value: wtf8_to_string(&s.value), + }), + } + } + + // ===== TS declarations ===== + + fn convert_ts_type_alias(&self, d: &swc::TsTypeAliasDecl) -> TSTypeAliasDeclaration { + TSTypeAliasDeclaration { + base: self.make_base_node(d.span), + id: self.convert_ident_to_identifier(&d.id), + type_annotation: Box::new(serde_json::Value::Null), + type_parameters: d + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + declare: if d.declare { Some(true) } else { None }, + } + } + + fn convert_ts_interface(&self, d: &swc::TsInterfaceDecl) -> TSInterfaceDeclaration { + TSInterfaceDeclaration { + base: self.make_base_node(d.span), + id: self.convert_ident_to_identifier(&d.id), + body: Box::new(serde_json::Value::Null), + type_parameters: d + .type_params + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + extends: if d.extends.is_empty() { + None + } else { + Some(vec![]) + }, + declare: if d.declare { Some(true) } else { None }, + } + } + + fn convert_ts_enum(&self, d: &swc::TsEnumDecl) -> TSEnumDeclaration { + TSEnumDeclaration { + base: self.make_base_node(d.span), + id: self.convert_ident_to_identifier(&d.id), + members: vec![], + declare: if d.declare { Some(true) } else { None }, + is_const: if d.is_const { Some(true) } else { None }, + } + } + + fn convert_ts_module(&self, d: &swc::TsModuleDecl) -> TSModuleDeclaration { + TSModuleDeclaration { + base: self.make_base_node(d.span), + id: Box::new(serde_json::Value::Null), + body: 
Box::new(serde_json::Value::Null), + declare: if d.declare { Some(true) } else { None }, + global: if d.global { Some(true) } else { None }, + } + } + + // ===== Identifiers ===== + + fn convert_ident_to_identifier(&self, id: &swc::Ident) -> Identifier { + Identifier { + base: self.make_base_node(id.span), + name: id.sym.to_string(), + type_annotation: None, + optional: if id.optional { Some(true) } else { None }, + decorators: None, + } + } + + fn convert_binding_ident(&self, id: &swc::BindingIdent) -> Identifier { + Identifier { + base: self.make_base_node(id.id.span), + name: id.id.sym.to_string(), + type_annotation: id + .type_ann + .as_ref() + .map(|_| Box::new(serde_json::Value::Null)), + optional: if id.id.optional { Some(true) } else { None }, + decorators: None, + } + } + + fn convert_prop_name(&self, key: &swc::PropName) -> Expression { + match key { + swc::PropName::Ident(id) => Expression::Identifier(Identifier { + base: self.make_base_node(id.span), + name: id.sym.to_string(), + type_annotation: None, + optional: None, + decorators: None, + }), + swc::PropName::Str(s) => Expression::StringLiteral(StringLiteral { + base: self.make_base_node(s.span), + value: wtf8_to_string(&s.value), + }), + swc::PropName::Num(n) => Expression::NumericLiteral(NumericLiteral { + base: self.make_base_node(n.span), + value: n.value, + }), + swc::PropName::Computed(c) => self.convert_expression(&c.expr), + swc::PropName::BigInt(b) => Expression::BigIntLiteral(BigIntLiteral { + base: self.make_base_node(b.span), + value: b.value.to_string(), + }), + } + } + + // ===== Operators ===== + + fn convert_binary_operator(&self, op: swc::BinaryOp) -> BinaryOperator { + match op { + swc::BinaryOp::EqEq => BinaryOperator::Eq, + swc::BinaryOp::NotEq => BinaryOperator::Neq, + swc::BinaryOp::EqEqEq => BinaryOperator::StrictEq, + swc::BinaryOp::NotEqEq => BinaryOperator::StrictNeq, + swc::BinaryOp::Lt => BinaryOperator::Lt, + swc::BinaryOp::LtEq => BinaryOperator::Lte, + swc::BinaryOp::Gt 
=> BinaryOperator::Gt, + swc::BinaryOp::GtEq => BinaryOperator::Gte, + swc::BinaryOp::LShift => BinaryOperator::Shl, + swc::BinaryOp::RShift => BinaryOperator::Shr, + swc::BinaryOp::ZeroFillRShift => BinaryOperator::UShr, + swc::BinaryOp::Add => BinaryOperator::Add, + swc::BinaryOp::Sub => BinaryOperator::Sub, + swc::BinaryOp::Mul => BinaryOperator::Mul, + swc::BinaryOp::Div => BinaryOperator::Div, + swc::BinaryOp::Mod => BinaryOperator::Rem, + swc::BinaryOp::Exp => BinaryOperator::Exp, + swc::BinaryOp::BitOr => BinaryOperator::BitOr, + swc::BinaryOp::BitXor => BinaryOperator::BitXor, + swc::BinaryOp::BitAnd => BinaryOperator::BitAnd, + swc::BinaryOp::In => BinaryOperator::In, + swc::BinaryOp::InstanceOf => BinaryOperator::Instanceof, + swc::BinaryOp::LogicalOr + | swc::BinaryOp::LogicalAnd + | swc::BinaryOp::NullishCoalescing => BinaryOperator::Eq, + } + } + + fn try_convert_logical_operator(&self, op: swc::BinaryOp) -> Option { + match op { + swc::BinaryOp::LogicalOr => Some(LogicalOperator::Or), + swc::BinaryOp::LogicalAnd => Some(LogicalOperator::And), + swc::BinaryOp::NullishCoalescing => Some(LogicalOperator::NullishCoalescing), + _ => None, + } + } + + fn convert_unary_operator(&self, op: swc::UnaryOp) -> UnaryOperator { + match op { + swc::UnaryOp::Minus => UnaryOperator::Neg, + swc::UnaryOp::Plus => UnaryOperator::Plus, + swc::UnaryOp::Bang => UnaryOperator::Not, + swc::UnaryOp::Tilde => UnaryOperator::BitNot, + swc::UnaryOp::TypeOf => UnaryOperator::TypeOf, + swc::UnaryOp::Void => UnaryOperator::Void, + swc::UnaryOp::Delete => UnaryOperator::Delete, + } + } + + fn convert_update_operator(&self, op: swc::UpdateOp) -> UpdateOperator { + match op { + swc::UpdateOp::PlusPlus => UpdateOperator::Increment, + swc::UpdateOp::MinusMinus => UpdateOperator::Decrement, + } + } + + fn convert_assignment_operator(&self, op: swc::AssignOp) -> AssignmentOperator { + match op { + swc::AssignOp::Assign => AssignmentOperator::Assign, + swc::AssignOp::AddAssign => 
AssignmentOperator::AddAssign, + swc::AssignOp::SubAssign => AssignmentOperator::SubAssign, + swc::AssignOp::MulAssign => AssignmentOperator::MulAssign, + swc::AssignOp::DivAssign => AssignmentOperator::DivAssign, + swc::AssignOp::ModAssign => AssignmentOperator::RemAssign, + swc::AssignOp::ExpAssign => AssignmentOperator::ExpAssign, + swc::AssignOp::LShiftAssign => AssignmentOperator::ShlAssign, + swc::AssignOp::RShiftAssign => AssignmentOperator::ShrAssign, + swc::AssignOp::ZeroFillRShiftAssign => AssignmentOperator::UShrAssign, + swc::AssignOp::BitOrAssign => AssignmentOperator::BitOrAssign, + swc::AssignOp::BitXorAssign => AssignmentOperator::BitXorAssign, + swc::AssignOp::BitAndAssign => AssignmentOperator::BitAndAssign, + swc::AssignOp::OrAssign => AssignmentOperator::OrAssign, + swc::AssignOp::AndAssign => AssignmentOperator::AndAssign, + swc::AssignOp::NullishAssign => AssignmentOperator::NullishAssign, + } + } +} diff --git a/crates/swc_ecma_react_compiler/src/convert_ast_reverse.rs b/crates/swc_ecma_react_compiler/src/convert_ast_reverse.rs new file mode 100644 index 000000000000..6b364fe76381 --- /dev/null +++ b/crates/swc_ecma_react_compiler/src/convert_ast_reverse.rs @@ -0,0 +1,2536 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +//! Reverse AST converter: react_compiler_ast (Babel format) → SWC AST. +//! +//! This is the inverse of `convert_ast.rs`. It takes a +//! `react_compiler_ast::File` (which represents the compiler's Babel-compatible +//! output) and produces SWC AST nodes suitable for code generation via +//! `swc_codegen`. 
+ +use react_compiler_ast::{ + common::{BaseNode, Comment as BabelComment}, + declarations::{ + ExportAllDeclaration, ExportDefaultDecl as BabelExportDefaultDecl, + ExportDefaultDeclaration, ExportKind, ExportNamedDeclaration, ImportDeclaration, + ImportKind, + }, + expressions::{self as babel_expr, Expression as BabelExpr}, + operators::*, + patterns::*, + statements::{self as babel_stmt, Statement as BabelStmt}, +}; +use swc_atoms::{Atom, Wtf8Atom}; +use swc_common::{ + comments::{Comment as SwcComment, CommentKind, Comments, SingleThreadedComments}, + BytePos, Span, Spanned, SyntaxContext, DUMMY_SP, +}; +use swc_ecma_ast::*; + +/// Result of converting a Babel AST back to SWC, including extracted comments. +pub struct SwcConversionResult { + pub module: Module, + pub comments: SingleThreadedComments, +} + +/// Convert a `react_compiler_ast::File` into an SWC `Module` and extracted +/// comments. +pub fn convert_program_to_swc(file: &react_compiler_ast::File) -> SwcConversionResult { + convert_program_to_swc_with_source(file, None) +} + +/// Convert a `react_compiler_ast::File` into an SWC `Module` and extracted +/// comments. When `source_text` is provided, type declarations can be extracted +/// from the original source for perfect fidelity. +pub fn convert_program_to_swc_with_source( + file: &react_compiler_ast::File, + source_text: Option<&str>, +) -> SwcConversionResult { + let ctx = ReverseCtx { + comments: SingleThreadedComments::default(), + source_text: source_text.map(|s| s.to_string()), + }; + let module = ctx.convert_program(&file.program); + SwcConversionResult { + module, + comments: ctx.comments, + } +} + +struct ReverseCtx { + comments: SingleThreadedComments, + source_text: Option, +} + +impl ReverseCtx { + /// Convert a BaseNode's start/end to an SWC Span, and extract any comments. 
+ fn span(&self, base: &BaseNode) -> Span { + let span = match (base.start, base.end) { + (Some(start), Some(end)) => Span::new(BytePos(start), BytePos(end)), + _ => DUMMY_SP, + }; + self.extract_comments(base, span); + span + } + + /// Convert a BaseNode's start/end to an SWC Span without extracting + /// comments. Use this for sub-nodes where comments should not be + /// duplicated. + fn span_no_comments(&self, base: &BaseNode) -> Span { + match (base.start, base.end) { + (Some(start), Some(end)) => Span::new(BytePos(start), BytePos(end)), + _ => DUMMY_SP, + } + } + + /// Convert a Babel comment to an SWC comment. + fn convert_babel_comment(babel_comment: &BabelComment) -> SwcComment { + let (kind, text) = match babel_comment { + BabelComment::CommentBlock(data) => (CommentKind::Block, &data.value), + BabelComment::CommentLine(data) => (CommentKind::Line, &data.value), + }; + SwcComment { + kind, + span: DUMMY_SP, + text: Atom::from(text.as_str()), + } + } + + /// Extract comments from a BaseNode and register them with the SWC comments + /// store. + fn extract_comments(&self, base: &BaseNode, span: Span) { + if let Some(ref leading) = base.leading_comments { + let pos = span.lo; + for c in leading { + self.comments + .add_leading(pos, Self::convert_babel_comment(c)); + } + } + if let Some(ref trailing) = base.trailing_comments { + let pos = span.hi; + for c in trailing { + self.comments + .add_trailing(pos, Self::convert_babel_comment(c)); + } + } + if let Some(ref inner) = base.inner_comments { + // Inner comments are typically leading comments of the next token + let pos = span.lo; + for c in inner { + self.comments + .add_leading(pos, Self::convert_babel_comment(c)); + } + } + } + + fn atom(&self, s: &str) -> Atom { + Atom::from(s) + } + + fn wtf8(&self, s: &str) -> Wtf8Atom { + Wtf8Atom::from(s) + } + + /// Escape non-ASCII characters and special characters (like tab) in a + /// string value to \uXXXX or \xXX sequences, matching Babel's codegen + /// output. 
Returns the raw string representation wrapped in double + /// quotes. + fn escape_string_raw(&self, value: &str) -> Option { + let mut needs_escape = false; + for ch in value.chars() { + if !ch.is_ascii() || ch == '\t' || ch == '\'' || ch == '"' || ch == '\\' { + needs_escape = true; + break; + } + } + if !needs_escape { + return None; + } + let mut escaped = String::with_capacity(value.len() + 16); + escaped.push('"'); + for ch in value.chars() { + match ch { + '"' => escaped.push_str("\\\""), + '\\' => escaped.push_str("\\\\"), + '\n' => escaped.push_str("\\n"), + '\r' => escaped.push_str("\\r"), + '\t' => escaped.push_str("\\t"), + c if !c.is_ascii() => { + // Encode using \uXXXX (or surrogate pairs for chars > U+FFFF) + let mut buf = [0u16; 2]; + let encoded = c.encode_utf16(&mut buf); + for unit in encoded { + escaped.push_str(&format!("\\u{unit:04X}")); + } + } + c => escaped.push(c), + } + } + escaped.push('"'); + Some(Atom::from(escaped.as_str())) + } + + /// Extract the original source text for a node and re-parse it as a + /// statement using SWC's TypeScript parser. This is used for type + /// declarations (type aliases, interfaces, enums) that the compiler + /// preserves verbatim from the original source. + fn extract_source_stmt(&self, base: &react_compiler_ast::common::BaseNode) -> Option { + let source = self.source_text.as_deref()?; + let start = base.start? as usize; + let end = base.end? as usize; + // SWC BytePos is 1-based + let start_idx = start.saturating_sub(1); + let end_idx = end.saturating_sub(1); + if start_idx >= source.len() || end_idx > source.len() || start_idx >= end_idx { + return None; + } + let text = &source[start_idx..end_idx]; + self.parse_ts_stmt(text, base) + } + + /// Parse a string as a TypeScript statement using SWC's parser. 
+ fn parse_ts_stmt( + &self, + text: &str, + base: &react_compiler_ast::common::BaseNode, + ) -> Option { + let cm = swc_common::sync::Lrc::new(swc_common::SourceMap::default()); + let fm = cm.new_source_file( + swc_common::sync::Lrc::new(swc_common::FileName::Anon), + text.to_string(), + ); + let mut errors = vec![]; + let module = swc_ecma_parser::parse_file_as_module( + &fm, + swc_ecma_parser::Syntax::Typescript(swc_ecma_parser::TsSyntax { + tsx: true, + ..Default::default() + }), + swc_ecma_ast::EsVersion::latest(), + None, + &mut errors, + ) + .ok()?; + + if let Some(item) = module.body.into_iter().next() { + match item { + ModuleItem::Stmt(stmt) => { + // Assign the original span so blank line computation works + let span = self.span(base); + return Some(self.assign_span_to_stmt(stmt, span)); + } + ModuleItem::ModuleDecl(_) => {} + } + } + None + } + + /// Assign a span to a statement's outermost node. + fn assign_span_to_stmt(&self, stmt: Stmt, span: Span) -> Stmt { + match stmt { + Stmt::Decl(Decl::TsTypeAlias(mut d)) => { + d.span = span; + Stmt::Decl(Decl::TsTypeAlias(d)) + } + Stmt::Decl(Decl::TsInterface(mut d)) => { + d.span = span; + Stmt::Decl(Decl::TsInterface(d)) + } + Stmt::Decl(Decl::TsEnum(mut d)) => { + d.span = span; + Stmt::Decl(Decl::TsEnum(d)) + } + other => other, + } + } + + fn ident(&self, name: &str, span: Span) -> Ident { + Ident { + sym: self.atom(name), + span, + ctxt: SyntaxContext::empty(), + optional: false, + } + } + + fn ident_name(&self, name: &str, span: Span) -> IdentName { + IdentName { + sym: self.atom(name), + span, + } + } + + fn binding_ident(&self, name: &str, span: Span) -> BindingIdent { + BindingIdent { + id: self.ident(name, span), + type_ann: None, + } + } + + // ===== Program ===== + + fn convert_program(&self, program: &react_compiler_ast::Program) -> Module { + let mut body: Vec = Vec::new(); + + // Convert directives to expression statements at the beginning + for dir in &program.directives { + let span = 
self.span(&dir.base); + let str_span = self.span(&dir.value.base); + body.push(ModuleItem::Stmt(Stmt::Expr(ExprStmt { + span, + expr: Box::new(Expr::Lit(Lit::Str(Str { + span: str_span, + value: self.wtf8(&dir.value.value), + raw: None, + }))), + }))); + } + + for s in &program.body { + body.push(self.convert_statement_to_module_item(s)); + } + + Module { + span: DUMMY_SP, + body, + shebang: None, + } + } + + fn convert_statement_to_module_item(&self, stmt: &BabelStmt) -> ModuleItem { + match stmt { + BabelStmt::ImportDeclaration(d) => { + ModuleItem::ModuleDecl(ModuleDecl::Import(self.convert_import_declaration(d))) + } + BabelStmt::ExportNamedDeclaration(d) => self.convert_export_named_to_module_item(d), + BabelStmt::ExportDefaultDeclaration(d) => self.convert_export_default_to_module_item(d), + BabelStmt::ExportAllDeclaration(d) => ModuleItem::ModuleDecl(ModuleDecl::ExportAll( + self.convert_export_all_declaration(d), + )), + _ => ModuleItem::Stmt(self.convert_statement(stmt)), + } + } + + // ===== Statements ===== + + fn convert_statement(&self, stmt: &BabelStmt) -> Stmt { + match stmt { + BabelStmt::BlockStatement(s) => Stmt::Block(self.convert_block_statement(s)), + BabelStmt::ReturnStatement(s) => Stmt::Return(ReturnStmt { + span: self.span(&s.base), + arg: s + .argument + .as_ref() + .map(|a| Box::new(self.convert_expression(a))), + }), + BabelStmt::ExpressionStatement(s) => { + let expr = self.convert_expression(&s.expression); + // Wrap in parens if the expression starts with `{` (object pattern + // in assignment) or `function` (IIFE), which would be ambiguous + // with a block statement or function declaration. 
+ let needs_paren = match &expr { + Expr::Assign(a) => { + matches!(&a.left, AssignTarget::Pat(AssignTargetPat::Object(_))) + } + Expr::Call(c) => match &c.callee { + Callee::Expr(e) => matches!(e.as_ref(), Expr::Fn(_)), + _ => false, + }, + _ => false, + }; + let expr = if needs_paren { + Expr::Paren(ParenExpr { + span: self.span_no_comments(&s.base), + expr: Box::new(expr), + }) + } else { + expr + }; + Stmt::Expr(ExprStmt { + span: self.span(&s.base), + expr: Box::new(expr), + }) + } + BabelStmt::IfStatement(s) => Stmt::If(IfStmt { + span: self.span(&s.base), + test: Box::new(self.convert_expression(&s.test)), + cons: Box::new(self.convert_statement(&s.consequent)), + alt: s + .alternate + .as_ref() + .map(|a| Box::new(self.convert_statement(a))), + }), + BabelStmt::ForStatement(s) => { + let init = s.init.as_ref().map(|i| self.convert_for_init(i)); + let test = s + .test + .as_ref() + .map(|t| Box::new(self.convert_expression(t))); + let update = s + .update + .as_ref() + .map(|u| Box::new(self.convert_expression(u))); + let body = Box::new(self.convert_statement(&s.body)); + Stmt::For(ForStmt { + span: self.span(&s.base), + init, + test, + update, + body, + }) + } + BabelStmt::WhileStatement(s) => Stmt::While(WhileStmt { + span: self.span(&s.base), + test: Box::new(self.convert_expression(&s.test)), + body: Box::new(self.convert_statement(&s.body)), + }), + BabelStmt::DoWhileStatement(s) => Stmt::DoWhile(DoWhileStmt { + span: self.span(&s.base), + test: Box::new(self.convert_expression(&s.test)), + body: Box::new(self.convert_statement(&s.body)), + }), + BabelStmt::ForInStatement(s) => Stmt::ForIn(ForInStmt { + span: self.span(&s.base), + left: self.convert_for_in_of_left(&s.left), + right: Box::new(self.convert_expression(&s.right)), + body: Box::new(self.convert_statement(&s.body)), + }), + BabelStmt::ForOfStatement(s) => Stmt::ForOf(ForOfStmt { + span: self.span(&s.base), + is_await: s.is_await, + left: self.convert_for_in_of_left(&s.left), + right: 
Box::new(self.convert_expression(&s.right)), + body: Box::new(self.convert_statement(&s.body)), + }), + BabelStmt::SwitchStatement(s) => { + let cases = s + .cases + .iter() + .map(|c| SwitchCase { + span: self.span(&c.base), + test: c + .test + .as_ref() + .map(|t| Box::new(self.convert_expression(t))), + cons: c + .consequent + .iter() + .map(|s| self.convert_statement(s)) + .collect(), + }) + .collect(); + Stmt::Switch(SwitchStmt { + span: self.span(&s.base), + discriminant: Box::new(self.convert_expression(&s.discriminant)), + cases, + }) + } + BabelStmt::ThrowStatement(s) => Stmt::Throw(ThrowStmt { + span: self.span(&s.base), + arg: Box::new(self.convert_expression(&s.argument)), + }), + BabelStmt::TryStatement(s) => { + let block = self.convert_block_statement(&s.block); + let handler = s.handler.as_ref().map(|h| self.convert_catch_clause(h)); + let finalizer = s + .finalizer + .as_ref() + .map(|f| self.convert_block_statement(f)); + Stmt::Try(Box::new(TryStmt { + span: self.span(&s.base), + block, + handler, + finalizer, + })) + } + BabelStmt::BreakStatement(s) => Stmt::Break(BreakStmt { + span: self.span(&s.base), + label: s.label.as_ref().map(|l| self.ident(&l.name, DUMMY_SP)), + }), + BabelStmt::ContinueStatement(s) => Stmt::Continue(ContinueStmt { + span: self.span(&s.base), + label: s.label.as_ref().map(|l| self.ident(&l.name, DUMMY_SP)), + }), + BabelStmt::LabeledStatement(s) => Stmt::Labeled(LabeledStmt { + span: self.span(&s.base), + label: self.ident(&s.label.name, DUMMY_SP), + body: Box::new(self.convert_statement(&s.body)), + }), + BabelStmt::EmptyStatement(s) => Stmt::Empty(EmptyStmt { + span: self.span(&s.base), + }), + BabelStmt::DebuggerStatement(s) => Stmt::Debugger(DebuggerStmt { + span: self.span(&s.base), + }), + BabelStmt::WithStatement(s) => Stmt::With(WithStmt { + span: self.span(&s.base), + obj: Box::new(self.convert_expression(&s.object)), + body: Box::new(self.convert_statement(&s.body)), + }), + BabelStmt::VariableDeclaration(d) => 
{ + Stmt::Decl(Decl::Var(Box::new(self.convert_variable_declaration(d)))) + } + BabelStmt::FunctionDeclaration(f) => { + Stmt::Decl(Decl::Fn(self.convert_function_declaration(f))) + } + BabelStmt::ClassDeclaration(c) => { + let ident = + c.id.as_ref() + .map(|id| self.ident(&id.name, self.span(&id.base))) + .unwrap_or_else(|| self.ident("_anonymous", DUMMY_SP)); + let super_class = c + .super_class + .as_ref() + .map(|s| Box::new(self.convert_expression(s))); + Stmt::Decl(Decl::Class(ClassDecl { + ident, + declare: c.declare.unwrap_or(false), + class: Box::new(Class { + span: self.span(&c.base), + ctxt: SyntaxContext::empty(), + decorators: vec![], + body: vec![], + super_class, + is_abstract: false, + type_params: None, + super_type_params: None, + implements: vec![], + }), + })) + } + // Import/export handled in convert_statement_to_module_item + BabelStmt::ImportDeclaration(_) + | BabelStmt::ExportNamedDeclaration(_) + | BabelStmt::ExportDefaultDeclaration(_) + | BabelStmt::ExportAllDeclaration(_) => Stmt::Empty(EmptyStmt { span: DUMMY_SP }), + // TS declarations - extract from source text if available + BabelStmt::TSTypeAliasDeclaration(d) => self + .extract_source_stmt(&d.base) + .unwrap_or(Stmt::Empty(EmptyStmt { span: DUMMY_SP })), + BabelStmt::TSInterfaceDeclaration(d) => self + .extract_source_stmt(&d.base) + .unwrap_or(Stmt::Empty(EmptyStmt { span: DUMMY_SP })), + BabelStmt::TSEnumDeclaration(d) => self + .extract_source_stmt(&d.base) + .unwrap_or(Stmt::Empty(EmptyStmt { span: DUMMY_SP })), + // Flow type declarations - extract from source text if available + BabelStmt::TypeAlias(d) => self + .extract_source_stmt(&d.base) + .unwrap_or(Stmt::Empty(EmptyStmt { span: DUMMY_SP })), + BabelStmt::OpaqueType(d) => self + .extract_source_stmt(&d.base) + .unwrap_or(Stmt::Empty(EmptyStmt { span: DUMMY_SP })), + BabelStmt::InterfaceDeclaration(d) => self + .extract_source_stmt(&d.base) + .unwrap_or(Stmt::Empty(EmptyStmt { span: DUMMY_SP })), + 
BabelStmt::EnumDeclaration(d) => self + .extract_source_stmt(&d.base) + .unwrap_or(Stmt::Empty(EmptyStmt { span: DUMMY_SP })), + // Other TS/Flow declarations + BabelStmt::TSModuleDeclaration(_) + | BabelStmt::TSDeclareFunction(_) + | BabelStmt::DeclareVariable(_) + | BabelStmt::DeclareFunction(_) + | BabelStmt::DeclareClass(_) + | BabelStmt::DeclareModule(_) + | BabelStmt::DeclareModuleExports(_) + | BabelStmt::DeclareExportDeclaration(_) + | BabelStmt::DeclareExportAllDeclaration(_) + | BabelStmt::DeclareInterface(_) + | BabelStmt::DeclareTypeAlias(_) + | BabelStmt::DeclareOpaqueType(_) => Stmt::Empty(EmptyStmt { span: DUMMY_SP }), + } + } + + fn convert_block_statement(&self, block: &babel_stmt::BlockStatement) -> BlockStmt { + let mut stmts: Vec = Vec::new(); + + // Convert directives to expression statements at the beginning + for dir in &block.directives { + let span = self.span(&dir.base); + let str_span = self.span(&dir.value.base); + stmts.push(Stmt::Expr(ExprStmt { + span, + expr: Box::new(Expr::Lit(Lit::Str(Str { + span: str_span, + value: self.wtf8(&dir.value.value), + raw: None, + }))), + })); + } + + for s in &block.body { + stmts.push(self.convert_statement(s)); + } + + BlockStmt { + span: self.span(&block.base), + ctxt: SyntaxContext::empty(), + stmts, + } + } + + fn convert_catch_clause(&self, clause: &babel_stmt::CatchClause) -> CatchClause { + let param = clause.param.as_ref().map(|p| self.convert_pattern(p)); + CatchClause { + span: self.span(&clause.base), + param, + body: self.convert_block_statement(&clause.body), + } + } + + fn convert_for_init(&self, init: &babel_stmt::ForInit) -> VarDeclOrExpr { + match init { + babel_stmt::ForInit::VariableDeclaration(v) => { + VarDeclOrExpr::VarDecl(Box::new(self.convert_variable_declaration(v))) + } + babel_stmt::ForInit::Expression(e) => { + VarDeclOrExpr::Expr(Box::new(self.convert_expression(e))) + } + } + } + + fn convert_for_in_of_left(&self, left: &babel_stmt::ForInOfLeft) -> ForHead { + match 
left { + babel_stmt::ForInOfLeft::VariableDeclaration(v) => { + ForHead::VarDecl(Box::new(self.convert_variable_declaration(v))) + } + babel_stmt::ForInOfLeft::Pattern(p) => ForHead::Pat(Box::new(self.convert_pattern(p))), + } + } + + fn convert_variable_declaration(&self, decl: &babel_stmt::VariableDeclaration) -> VarDecl { + let kind = match decl.kind { + babel_stmt::VariableDeclarationKind::Var => VarDeclKind::Var, + babel_stmt::VariableDeclarationKind::Let => VarDeclKind::Let, + babel_stmt::VariableDeclarationKind::Const => VarDeclKind::Const, + babel_stmt::VariableDeclarationKind::Using => VarDeclKind::Var, /* SWC doesn't have + * Using */ + }; + let decls = decl + .declarations + .iter() + .map(|d| self.convert_variable_declarator(d)) + .collect(); + let declare = decl.declare.unwrap_or(false); + VarDecl { + span: self.span(&decl.base), + ctxt: SyntaxContext::empty(), + kind, + declare, + decls, + } + } + + fn convert_variable_declarator(&self, d: &babel_stmt::VariableDeclarator) -> VarDeclarator { + let name = self.convert_pattern(&d.id); + let init = d + .init + .as_ref() + .map(|e| Box::new(self.convert_expression(e))); + let definite = d.definite.unwrap_or(false); + VarDeclarator { + span: self.span(&d.base), + name, + init, + definite, + } + } + + // ===== Expressions ===== + + fn convert_expression(&self, expr: &BabelExpr) -> Expr { + match expr { + BabelExpr::Identifier(id) => { + let span = self.span(&id.base); + Expr::Ident(self.ident(&id.name, span)) + } + BabelExpr::StringLiteral(lit) => Expr::Lit(Lit::Str(Str { + span: self.span(&lit.base), + value: self.wtf8(&lit.value), + raw: self.escape_string_raw(&lit.value), + })), + BabelExpr::NumericLiteral(lit) => { + // Convert -0.0 to 0.0 to match Babel's codegen behavior. + // Babel outputs `0` for both `-0` and `0`. 
+ let value = if lit.value == 0.0 && lit.value.is_sign_negative() { + 0.0 + } else { + lit.value + }; + Expr::Lit(Lit::Num(Number { + span: self.span(&lit.base), + value, + raw: None, + })) + } + BabelExpr::BooleanLiteral(lit) => Expr::Lit(Lit::Bool(Bool { + span: self.span(&lit.base), + value: lit.value, + })), + BabelExpr::NullLiteral(lit) => Expr::Lit(Lit::Null(Null { + span: self.span(&lit.base), + })), + BabelExpr::BigIntLiteral(lit) => Expr::Lit(Lit::BigInt(BigInt { + span: self.span(&lit.base), + value: Box::new(lit.value.parse().unwrap_or_default()), + raw: None, + })), + BabelExpr::RegExpLiteral(lit) => Expr::Lit(Lit::Regex(Regex { + span: self.span(&lit.base), + exp: self.atom(&lit.pattern), + flags: self.atom(&lit.flags), + })), + BabelExpr::CallExpression(call) => { + let callee = self.convert_expression(&call.callee); + let args = self.convert_arguments(&call.arguments); + // Wrap arrow/function expressions in parens when used as + // call targets (IIFEs). SWC codegen does not add parens for + // `(() => ...)()`, resulting in incorrect code. 
+ let callee = match &callee { + Expr::Arrow(_) | Expr::Fn(_) => Expr::Paren(ParenExpr { + span: callee.span(), + expr: Box::new(callee), + }), + _ => callee, + }; + Expr::Call(CallExpr { + span: self.span(&call.base), + ctxt: SyntaxContext::empty(), + callee: Callee::Expr(Box::new(callee)), + args, + type_args: None, + }) + } + BabelExpr::MemberExpression(m) => self.convert_member_expression(m), + BabelExpr::OptionalCallExpression(call) => { + let callee = self.convert_expression_for_chain(&call.callee); + let args = self.convert_arguments(&call.arguments); + let base = OptChainBase::Call(OptCall { + span: self.span(&call.base), + ctxt: SyntaxContext::empty(), + callee: Box::new(callee), + args, + type_args: None, + }); + Expr::OptChain(OptChainExpr { + span: self.span(&call.base), + optional: call.optional, + base: Box::new(base), + }) + } + BabelExpr::OptionalMemberExpression(m) => { + let base = self.convert_optional_member_to_chain_base(m); + Expr::OptChain(OptChainExpr { + span: self.span(&m.base), + optional: m.optional, + base: Box::new(base), + }) + } + BabelExpr::BinaryExpression(bin) => { + let op = self.convert_binary_operator(&bin.operator); + Expr::Bin(BinExpr { + span: self.span(&bin.base), + op, + left: Box::new(self.convert_expression(&bin.left)), + right: Box::new(self.convert_expression(&bin.right)), + }) + } + BabelExpr::LogicalExpression(log) => { + let op = self.convert_logical_operator(&log.operator); + let span = self.span(&log.base); + let bin = Expr::Bin(BinExpr { + span, + op, + left: Box::new(self.convert_expression(&log.left)), + right: Box::new(self.convert_expression(&log.right)), + }); + // Wrap all logical expressions in parentheses. Logical + // operators (||, &&, ??) have lower precedence than most + // binary operators, but SWC's codegen does not always insert + // parens correctly (e.g., `a + b || c` vs `a + (b || c)`). + // Wrapping unconditionally is safe. 
+ Expr::Paren(ParenExpr { + span, + expr: Box::new(bin), + }) + } + BabelExpr::UnaryExpression(un) => { + let op = self.convert_unary_operator(&un.operator); + Expr::Unary(UnaryExpr { + span: self.span(&un.base), + op, + arg: Box::new(self.convert_expression(&un.argument)), + }) + } + BabelExpr::UpdateExpression(up) => { + let op = self.convert_update_operator(&up.operator); + Expr::Update(UpdateExpr { + span: self.span(&up.base), + op, + prefix: up.prefix, + arg: Box::new(self.convert_expression(&up.argument)), + }) + } + BabelExpr::ConditionalExpression(cond) => { + let span = self.span(&cond.base); + // Wrap conditional expressions in parentheses. SWC's codegen + // does not always insert parens for ternaries inside binary + // or assignment expressions (e.g., `x + cond ? a : b` instead + // of `x + (cond ? a : b)`). + Expr::Paren(ParenExpr { + span, + expr: Box::new(Expr::Cond(CondExpr { + span, + test: Box::new(self.convert_expression(&cond.test)), + cons: Box::new(self.convert_expression(&cond.consequent)), + alt: Box::new(self.convert_expression(&cond.alternate)), + })), + }) + } + BabelExpr::AssignmentExpression(assign) => { + let op = self.convert_assignment_operator(&assign.operator); + let left = self.convert_pattern_to_assign_target(&assign.left); + let span = self.span(&assign.base); + let assign_expr = Expr::Assign(AssignExpr { + span, + op, + left, + right: Box::new(self.convert_expression(&assign.right)), + }); + // Wrap assignment expressions in parentheses. SWC's codegen + // does not always insert necessary parens for assignments + // when they appear as operands of binary/logical expressions + // (e.g., `x + x = 2` instead of `x + (x = 2)`). + Expr::Paren(ParenExpr { + span, + expr: Box::new(assign_expr), + }) + } + BabelExpr::SequenceExpression(seq) => { + let exprs = seq + .expressions + .iter() + .map(|e| Box::new(self.convert_expression(e))) + .collect(); + let span = self.span(&seq.base); + // Wrap sequence expressions in parentheses. 
SWC's codegen + // does not always insert necessary parens for sequence + // expressions (e.g., in ternary consequent position), so + // wrapping unconditionally is safe and prevents parse errors. + Expr::Paren(ParenExpr { + span, + expr: Box::new(Expr::Seq(SeqExpr { span, exprs })), + }) + } + BabelExpr::ArrowFunctionExpression(arrow) => self.convert_arrow_function(arrow), + BabelExpr::FunctionExpression(func) => { + let ident = func + .id + .as_ref() + .map(|id| self.ident(&id.name, self.span(&id.base))); + let params = self.convert_params(&func.params); + let body = Some(self.convert_block_statement(&func.body)); + Expr::Fn(FnExpr { + ident, + function: Box::new(Function { + params, + decorators: vec![], + span: self.span(&func.base), + ctxt: SyntaxContext::empty(), + body, + is_generator: func.generator, + is_async: func.is_async, + type_params: None, + return_type: None, + }), + }) + } + BabelExpr::ObjectExpression(obj) => { + let props = obj + .properties + .iter() + .map(|p| self.convert_object_expression_property(p)) + .collect(); + Expr::Object(ObjectLit { + span: self.span(&obj.base), + props, + }) + } + BabelExpr::ArrayExpression(arr) => { + let elems = arr + .elements + .iter() + .map(|e| self.convert_array_element(e)) + .collect(); + Expr::Array(ArrayLit { + span: self.span(&arr.base), + elems, + }) + } + BabelExpr::NewExpression(n) => { + let callee = Box::new(self.convert_expression(&n.callee)); + let args = Some(self.convert_arguments(&n.arguments)); + Expr::New(NewExpr { + span: self.span(&n.base), + ctxt: SyntaxContext::empty(), + callee, + args, + type_args: None, + }) + } + BabelExpr::TemplateLiteral(tl) => { + let template = self.convert_template_literal(tl); + Expr::Tpl(template) + } + BabelExpr::TaggedTemplateExpression(tag) => { + let t = Box::new(self.convert_expression(&tag.tag)); + let tpl = Box::new(self.convert_template_literal(&tag.quasi)); + Expr::TaggedTpl(TaggedTpl { + span: self.span(&tag.base), + ctxt: SyntaxContext::empty(), + 
tag: t, + type_params: None, + tpl, + }) + } + BabelExpr::AwaitExpression(a) => Expr::Await(AwaitExpr { + span: self.span(&a.base), + arg: Box::new(self.convert_expression(&a.argument)), + }), + BabelExpr::YieldExpression(y) => Expr::Yield(YieldExpr { + span: self.span(&y.base), + delegate: y.delegate, + arg: y + .argument + .as_ref() + .map(|a| Box::new(self.convert_expression(a))), + }), + BabelExpr::SpreadElement(s) => { + // SpreadElement can't be a standalone expression in SWC. + // Return the argument directly as a fallback. + self.convert_expression(&s.argument) + } + BabelExpr::MetaProperty(mp) => Expr::MetaProp(MetaPropExpr { + span: self.span(&mp.base), + kind: match (mp.meta.name.as_str(), mp.property.name.as_str()) { + ("new", "target") => MetaPropKind::NewTarget, + ("import", "meta") => MetaPropKind::ImportMeta, + _ => MetaPropKind::NewTarget, + }, + }), + BabelExpr::ClassExpression(c) => { + let ident = + c.id.as_ref() + .map(|id| self.ident(&id.name, self.span(&id.base))); + let super_class = c + .super_class + .as_ref() + .map(|s| Box::new(self.convert_expression(s))); + Expr::Class(ClassExpr { + ident, + class: Box::new(Class { + span: self.span(&c.base), + ctxt: SyntaxContext::empty(), + decorators: vec![], + body: vec![], + super_class, + is_abstract: false, + type_params: None, + super_type_params: None, + implements: vec![], + }), + }) + } + BabelExpr::PrivateName(p) => Expr::PrivateName(PrivateName { + span: self.span(&p.base), + name: self.atom(&p.id.name), + }), + BabelExpr::Super(s) => Expr::Ident(self.ident("super", self.span(&s.base))), + BabelExpr::Import(i) => Expr::Ident(self.ident("import", self.span(&i.base))), + BabelExpr::ThisExpression(t) => Expr::This(ThisExpr { + span: self.span(&t.base), + }), + BabelExpr::ParenthesizedExpression(p) => Expr::Paren(ParenExpr { + span: self.span(&p.base), + expr: Box::new(self.convert_expression(&p.expression)), + }), + BabelExpr::JSXElement(el) => { + let element = 
self.convert_jsx_element(el.as_ref()); + Expr::JSXElement(Box::new(element)) + } + BabelExpr::JSXFragment(frag) => { + let fragment = self.convert_jsx_fragment(frag); + Expr::JSXFragment(fragment) + } + // TS expressions - preserve as SWC TS nodes + BabelExpr::TSAsExpression(e) => { + let expr = Box::new(self.convert_expression(&e.expression)); + let span = self.span(&e.base); + // Check if this is "as const" — Babel represents it as + // TSAsExpression with typeAnnotation: TSTypeReference { typeName: Identifier { + // name: "const" } } + let is_as_const = e.type_annotation.get("type").and_then(|v| v.as_str()) + == Some("TSTypeReference") + && e.type_annotation + .get("typeName") + .and_then(|tn| tn.get("name")) + .and_then(|n| n.as_str()) + == Some("const"); + + if is_as_const { + Expr::TsConstAssertion(TsConstAssertion { span, expr }) + } else { + let type_ann = self.convert_ts_type_from_json(&e.type_annotation, span); + Expr::TsAs(TsAsExpr { + span, + expr, + type_ann: Box::new(type_ann), + }) + } + } + BabelExpr::TSSatisfiesExpression(e) => self.convert_expression(&e.expression), + BabelExpr::TSNonNullExpression(e) => Expr::TsNonNull(TsNonNullExpr { + span: self.span(&e.base), + expr: Box::new(self.convert_expression(&e.expression)), + }), + BabelExpr::TSTypeAssertion(e) => self.convert_expression(&e.expression), + BabelExpr::TSInstantiationExpression(e) => self.convert_expression(&e.expression), + BabelExpr::TypeCastExpression(e) => self.convert_expression(&e.expression), + BabelExpr::AssignmentPattern(p) => { + let left = self.convert_pattern_to_assign_target(&p.left); + Expr::Assign(AssignExpr { + span: self.span(&p.base), + op: AssignOp::Assign, + left, + right: Box::new(self.convert_expression(&p.right)), + }) + } + } + } + + /// Convert an expression that may be used inside a chain (optional + /// chaining). + /// + /// In Babel, a chain like `a?.b.c()` is represented as nested + /// OptionalMemberExpression / OptionalCallExpression nodes. 
Each node + /// has an `optional` flag indicating whether it uses `?.` at that point. + /// + /// In SWC, each `?.` point is wrapped in an `OptChainExpr`. Nodes in + /// the chain that do NOT have `?.` are plain `MemberExpr` / `CallExpr`. + /// + /// So when `optional: true`, we still need to emit `OptChainExpr`. + /// When `optional: false`, we emit a plain expr (part of the parent chain). + fn convert_expression_for_chain(&self, expr: &BabelExpr) -> Expr { + match expr { + BabelExpr::OptionalMemberExpression(m) => { + if m.optional { + // This node uses `?.`, wrap in OptChainExpr + let base = self.convert_optional_member_to_chain_base(m); + Expr::OptChain(OptChainExpr { + span: self.span(&m.base), + optional: true, + base: Box::new(base), + }) + } else { + // Part of a chain but no `?.` here — plain MemberExpr + self.convert_optional_member_to_member_expr(m) + } + } + BabelExpr::OptionalCallExpression(call) => { + let callee = self.convert_expression_for_chain(&call.callee); + let args = self.convert_arguments(&call.arguments); + if call.optional { + // This node uses `?.()`, wrap in OptChainExpr + let base = OptChainBase::Call(OptCall { + span: self.span(&call.base), + ctxt: SyntaxContext::empty(), + callee: Box::new(callee), + args, + type_args: None, + }); + Expr::OptChain(OptChainExpr { + span: self.span(&call.base), + optional: true, + base: Box::new(base), + }) + } else { + // Part of a chain but no `?.` here — plain CallExpr + Expr::Call(CallExpr { + span: self.span(&call.base), + ctxt: SyntaxContext::empty(), + callee: Callee::Expr(Box::new(callee)), + args, + type_args: None, + }) + } + } + _ => self.convert_expression(expr), + } + } + + fn convert_member_expression(&self, m: &babel_expr::MemberExpression) -> Expr { + let object = self.convert_expression(&m.object); + // When an optional chain expression is used as the object of a + // non-optional member expression (e.g., `(props?.a).b`), wrap it + // in parens to properly terminate the optional chain. 
Without + // parens, SWC codegen emits `props?.a.b` which extends the chain. + let object = match &object { + Expr::OptChain(_) => Box::new(Expr::Paren(ParenExpr { + span: object.span(), + expr: Box::new(object), + })), + _ => Box::new(object), + }; + if m.computed { + let property = self.convert_expression(&m.property); + Expr::Member(MemberExpr { + span: self.span(&m.base), + obj: object, + prop: MemberProp::Computed(ComputedPropName { + span: DUMMY_SP, + expr: Box::new(property), + }), + }) + } else { + let prop_name = self.expression_to_ident_name(&m.property); + Expr::Member(MemberExpr { + span: self.span(&m.base), + obj: object, + prop: MemberProp::Ident(prop_name), + }) + } + } + + fn convert_optional_member_to_chain_base( + &self, + m: &babel_expr::OptionalMemberExpression, + ) -> OptChainBase { + let object = Box::new(self.convert_expression_for_chain(&m.object)); + if m.computed { + let property = self.convert_expression(&m.property); + OptChainBase::Member(MemberExpr { + span: self.span(&m.base), + obj: object, + prop: MemberProp::Computed(ComputedPropName { + span: DUMMY_SP, + expr: Box::new(property), + }), + }) + } else { + let prop_name = self.expression_to_ident_name(&m.property); + OptChainBase::Member(MemberExpr { + span: self.span(&m.base), + obj: object, + prop: MemberProp::Ident(prop_name), + }) + } + } + + fn convert_optional_member_to_member_expr( + &self, + m: &babel_expr::OptionalMemberExpression, + ) -> Expr { + let object = Box::new(self.convert_expression_for_chain(&m.object)); + if m.computed { + let property = self.convert_expression(&m.property); + Expr::Member(MemberExpr { + span: self.span(&m.base), + obj: object, + prop: MemberProp::Computed(ComputedPropName { + span: DUMMY_SP, + expr: Box::new(property), + }), + }) + } else { + let prop_name = self.expression_to_ident_name(&m.property); + Expr::Member(MemberExpr { + span: self.span(&m.base), + obj: object, + prop: MemberProp::Ident(prop_name), + }) + } + } + + fn 
expression_to_ident_name(&self, expr: &BabelExpr) -> IdentName { + match expr { + BabelExpr::Identifier(id) => self.ident_name(&id.name, self.span(&id.base)), + _ => self.ident_name("__unknown__", DUMMY_SP), + } + } + + fn convert_arguments(&self, args: &[BabelExpr]) -> Vec { + args.iter().map(|a| self.convert_argument(a)).collect() + } + + fn convert_argument(&self, arg: &BabelExpr) -> ExprOrSpread { + match arg { + BabelExpr::SpreadElement(s) => ExprOrSpread { + spread: Some(self.span(&s.base)), + expr: Box::new(self.convert_expression(&s.argument)), + }, + _ => ExprOrSpread { + spread: None, + expr: Box::new(self.convert_expression(arg)), + }, + } + } + + fn convert_array_element(&self, elem: &Option) -> Option { + match elem { + None => None, + Some(BabelExpr::SpreadElement(s)) => Some(ExprOrSpread { + spread: Some(self.span(&s.base)), + expr: Box::new(self.convert_expression(&s.argument)), + }), + Some(e) => Some(ExprOrSpread { + spread: None, + expr: Box::new(self.convert_expression(e)), + }), + } + } + + fn convert_object_expression_property( + &self, + prop: &babel_expr::ObjectExpressionProperty, + ) -> PropOrSpread { + match prop { + babel_expr::ObjectExpressionProperty::ObjectProperty(p) => { + let key = if p.computed { + // Computed property key: [expr] + PropName::Computed(ComputedPropName { + span: DUMMY_SP, + expr: Box::new(self.convert_expression(&p.key)), + }) + } else { + self.convert_expression_to_prop_name(&p.key) + }; + let value = self.convert_expression(&p.value); + let method = p.method.unwrap_or(false); + + if p.shorthand { + PropOrSpread::Prop(Box::new(Prop::Shorthand(match &*p.key { + BabelExpr::Identifier(id) => self.ident(&id.name, self.span(&id.base)), + _ => self.ident("__unknown__", DUMMY_SP), + }))) + } else if method { + // Method shorthand: { foo() {} } + // The value should be a function expression + let func = match value { + Expr::Fn(fn_expr) => *fn_expr.function, + _ => { + // Fallback: wrap in a key-value + return 
PropOrSpread::Prop(Box::new(Prop::KeyValue(KeyValueProp { + key, + value: Box::new(value), + }))); + } + }; + PropOrSpread::Prop(Box::new(Prop::Method(MethodProp { + key, + function: Box::new(func), + }))) + } else { + PropOrSpread::Prop(Box::new(Prop::KeyValue(KeyValueProp { + key, + value: Box::new(value), + }))) + } + } + babel_expr::ObjectExpressionProperty::ObjectMethod(m) => { + let key = if m.computed { + PropName::Computed(ComputedPropName { + span: DUMMY_SP, + expr: Box::new(self.convert_expression(&m.key)), + }) + } else { + self.convert_expression_to_prop_name(&m.key) + }; + let func = self.convert_object_method_to_function(m); + match m.kind { + babel_expr::ObjectMethodKind::Get => { + PropOrSpread::Prop(Box::new(Prop::Getter(GetterProp { + span: self.span(&m.base), + key, + type_ann: None, + body: func.body, + }))) + } + babel_expr::ObjectMethodKind::Set => { + let param = func + .params + .into_iter() + .next() + .map(|p| Box::new(p.pat)) + .unwrap_or_else(|| { + Box::new(Pat::Ident(self.binding_ident("_", DUMMY_SP))) + }); + PropOrSpread::Prop(Box::new(Prop::Setter(SetterProp { + span: self.span(&m.base), + key, + this_param: None, + param, + body: func.body, + }))) + } + babel_expr::ObjectMethodKind::Method => { + PropOrSpread::Prop(Box::new(Prop::Method(MethodProp { + key, + function: Box::new(func), + }))) + } + } + } + babel_expr::ObjectExpressionProperty::SpreadElement(s) => { + PropOrSpread::Spread(SpreadElement { + dot3_token: self.span(&s.base), + expr: Box::new(self.convert_expression(&s.argument)), + }) + } + } + } + + fn convert_expression_to_prop_name(&self, expr: &BabelExpr) -> PropName { + match expr { + BabelExpr::Identifier(id) => { + PropName::Ident(self.ident_name(&id.name, self.span(&id.base))) + } + BabelExpr::StringLiteral(s) => PropName::Str(Str { + span: self.span(&s.base), + value: self.wtf8(&s.value), + raw: None, + }), + BabelExpr::NumericLiteral(n) => PropName::Num(Number { + span: self.span(&n.base), + value: n.value, + 
raw: None, + }), + _ => PropName::Computed(ComputedPropName { + span: DUMMY_SP, + expr: Box::new(self.convert_expression(expr)), + }), + } + } + + fn convert_template_literal(&self, tl: &babel_expr::TemplateLiteral) -> Tpl { + let quasis = tl + .quasis + .iter() + .map(|q| { + let cooked = q.value.cooked.as_ref().map(|c| self.wtf8(c)); + TplElement { + span: self.span(&q.base), + tail: q.tail, + cooked, + raw: self.atom(&q.value.raw), + } + }) + .collect(); + let exprs = tl + .expressions + .iter() + .map(|e| Box::new(self.convert_expression(e))) + .collect(); + Tpl { + span: self.span(&tl.base), + exprs, + quasis, + } + } + + // ===== Functions ===== + + fn convert_function_declaration(&self, f: &babel_stmt::FunctionDeclaration) -> FnDecl { + let ident = + f.id.as_ref() + .map(|id| self.ident(&id.name, self.span(&id.base))) + .unwrap_or_else(|| self.ident("_anonymous", DUMMY_SP)); + let params = self.convert_params(&f.params); + let body = Some(self.convert_block_statement(&f.body)); + let declare = f.declare.unwrap_or(false); + FnDecl { + ident, + declare, + function: Box::new(Function { + params, + decorators: vec![], + span: self.span(&f.base), + ctxt: SyntaxContext::empty(), + body, + is_generator: f.generator, + is_async: f.is_async, + type_params: None, + return_type: None, + }), + } + } + + fn convert_object_method_to_function(&self, m: &babel_expr::ObjectMethod) -> Function { + let params = self.convert_params(&m.params); + let body = Some(self.convert_block_statement(&m.body)); + Function { + params, + decorators: vec![], + span: self.span(&m.base), + ctxt: SyntaxContext::empty(), + body, + is_generator: m.generator, + is_async: m.is_async, + type_params: None, + return_type: None, + } + } + + fn convert_arrow_function(&self, arrow: &babel_expr::ArrowFunctionExpression) -> Expr { + let is_expression = arrow.expression.unwrap_or(false); + let params = arrow + .params + .iter() + .map(|p| self.convert_pattern(p)) + .collect(); + + let body: Box = match 
&*arrow.body { + babel_expr::ArrowFunctionBody::BlockStatement(block) => Box::new( + BlockStmtOrExpr::BlockStmt(self.convert_block_statement(block)), + ), + babel_expr::ArrowFunctionBody::Expression(expr) => { + if is_expression { + let converted = self.convert_expression(expr); + // Wrap object expressions in parens to prevent ambiguity + // with block bodies: `() => ({...})` vs `() => {...}` + let converted = if matches!(&converted, Expr::Object(_)) { + Expr::Paren(ParenExpr { + span: converted.span(), + expr: Box::new(converted), + }) + } else { + converted + }; + Box::new(BlockStmtOrExpr::Expr(Box::new(converted))) + } else { + // Wrap in block with return + let ret_stmt = Stmt::Return(ReturnStmt { + span: DUMMY_SP, + arg: Some(Box::new(self.convert_expression(expr))), + }); + Box::new(BlockStmtOrExpr::BlockStmt(BlockStmt { + span: DUMMY_SP, + ctxt: SyntaxContext::empty(), + stmts: vec![ret_stmt], + })) + } + } + }; + + Expr::Arrow(ArrowExpr { + span: self.span(&arrow.base), + ctxt: SyntaxContext::empty(), + params, + body, + is_async: arrow.is_async, + is_generator: arrow.generator, + return_type: None, + type_params: None, + }) + } + + fn convert_params(&self, params: &[PatternLike]) -> Vec { + params + .iter() + .map(|p| Param { + span: DUMMY_SP, + decorators: vec![], + pat: self.convert_pattern(p), + }) + .collect() + } + + // ===== Patterns ===== + + fn convert_pattern(&self, pattern: &PatternLike) -> Pat { + match pattern { + PatternLike::Identifier(id) => { + let mut bi = self.binding_ident(&id.name, self.span(&id.base)); + bi.id.optional = id.optional.unwrap_or(false); + // Preserve type annotations if present + if let Some(ref type_ann) = id.type_annotation { + bi.type_ann = self.convert_ts_type_annotation_from_json(type_ann); + } + Pat::Ident(bi) + } + PatternLike::ObjectPattern(obj) => { + let mut props: Vec = Vec::new(); + + for prop in &obj.properties { + match prop { + ObjectPatternProperty::ObjectProperty(p) => { + if p.shorthand { + // 
Shorthand: { x } or { x = default } + let value = self.convert_pattern(&p.value); + match &*p.key { + BabelExpr::Identifier(id) => { + let key_ident = + self.binding_ident(&id.name, self.span(&id.base)); + match value { + Pat::Assign(assign_pat) => { + props.push(ObjectPatProp::Assign(AssignPatProp { + span: self.span(&p.base), + key: key_ident, + value: Some(assign_pat.right), + })); + } + _ => { + props.push(ObjectPatProp::Assign(AssignPatProp { + span: self.span(&p.base), + key: key_ident, + value: None, + })); + } + } + } + _ => { + // Fallback to key-value + let key = self.convert_expression_to_prop_name(&p.key); + props.push(ObjectPatProp::KeyValue(KeyValuePatProp { + key, + value: Box::new(value), + })); + } + } + } else { + let key = self.convert_expression_to_prop_name(&p.key); + let value = self.convert_pattern(&p.value); + props.push(ObjectPatProp::KeyValue(KeyValuePatProp { + key, + value: Box::new(value), + })); + } + } + ObjectPatternProperty::RestElement(r) => { + let arg = Box::new(self.convert_pattern(&r.argument)); + props.push(ObjectPatProp::Rest(RestPat { + span: self.span(&r.base), + dot3_token: self.span(&r.base), + arg, + type_ann: None, + })); + } + } + } + + Pat::Object(ObjectPat { + span: self.span(&obj.base), + props, + optional: false, + type_ann: None, + }) + } + PatternLike::ArrayPattern(arr) => { + let elems = arr + .elements + .iter() + .map(|e| e.as_ref().map(|p| self.convert_pattern(p))) + .collect(); + Pat::Array(ArrayPat { + span: self.span(&arr.base), + elems, + optional: false, + type_ann: None, + }) + } + PatternLike::AssignmentPattern(ap) => { + let left = Box::new(self.convert_pattern(&ap.left)); + let right = Box::new(self.convert_expression(&ap.right)); + Pat::Assign(AssignPat { + span: self.span(&ap.base), + left, + right, + }) + } + PatternLike::RestElement(r) => { + let arg = Box::new(self.convert_pattern(&r.argument)); + Pat::Rest(RestPat { + span: self.span(&r.base), + dot3_token: self.span(&r.base), + arg, + 
type_ann: None, + }) + } + PatternLike::MemberExpression(m) => { + // MemberExpression in pattern position - convert to an expression pattern + Pat::Expr(Box::new(self.convert_member_expression(m))) + } + } + } + + // ===== Patterns → AssignmentTarget ===== + + fn convert_pattern_to_assign_target(&self, pattern: &PatternLike) -> AssignTarget { + match pattern { + PatternLike::Identifier(id) => AssignTarget::Simple(SimpleAssignTarget::Ident( + self.binding_ident(&id.name, self.span(&id.base)), + )), + PatternLike::MemberExpression(m) => { + let expr = self.convert_member_expression(m); + match expr { + Expr::Member(member) => { + AssignTarget::Simple(SimpleAssignTarget::Member(member)) + } + _ => AssignTarget::Simple(SimpleAssignTarget::Ident( + self.binding_ident("__unknown__", DUMMY_SP), + )), + } + } + PatternLike::ObjectPattern(_obj) => { + let pat = self.convert_pattern(pattern); + match pat { + Pat::Object(obj_pat) => AssignTarget::Pat(AssignTargetPat::Object(obj_pat)), + _ => AssignTarget::Simple(SimpleAssignTarget::Ident( + self.binding_ident("__unknown__", DUMMY_SP), + )), + } + } + PatternLike::ArrayPattern(_arr) => { + let pat = self.convert_pattern(pattern); + match pat { + Pat::Array(arr_pat) => AssignTarget::Pat(AssignTargetPat::Array(arr_pat)), + _ => AssignTarget::Simple(SimpleAssignTarget::Ident( + self.binding_ident("__unknown__", DUMMY_SP), + )), + } + } + PatternLike::AssignmentPattern(ap) => { + // For assignment LHS, use the left side + self.convert_pattern_to_assign_target(&ap.left) + } + PatternLike::RestElement(r) => self.convert_pattern_to_assign_target(&r.argument), + } + } + + // ===== JSX ===== + + fn convert_jsx_element( + &self, + el: &react_compiler_ast::jsx::JSXElement, + ) -> swc_ecma_ast::JSXElement { + let opening = self.convert_jsx_opening_element(&el.opening_element); + let children: Vec = el + .children + .iter() + .map(|c| self.convert_jsx_child(c)) + .collect(); + let closing = el + .closing_element + .as_ref() + .map(|c| 
self.convert_jsx_closing_element(c)); + swc_ecma_ast::JSXElement { + span: self.span(&el.base), + opening, + children, + closing, + } + } + + fn convert_jsx_opening_element( + &self, + el: &react_compiler_ast::jsx::JSXOpeningElement, + ) -> swc_ecma_ast::JSXOpeningElement { + let name = self.convert_jsx_element_name(&el.name); + let attrs = el + .attributes + .iter() + .map(|a| self.convert_jsx_attribute_item(a)) + .collect(); + swc_ecma_ast::JSXOpeningElement { + span: self.span(&el.base), + name, + attrs, + self_closing: el.self_closing, + type_args: None, + } + } + + fn convert_jsx_closing_element( + &self, + el: &react_compiler_ast::jsx::JSXClosingElement, + ) -> swc_ecma_ast::JSXClosingElement { + let name = self.convert_jsx_element_name(&el.name); + swc_ecma_ast::JSXClosingElement { + span: self.span(&el.base), + name, + } + } + + fn convert_jsx_element_name( + &self, + name: &react_compiler_ast::jsx::JSXElementName, + ) -> swc_ecma_ast::JSXElementName { + match name { + react_compiler_ast::jsx::JSXElementName::JSXIdentifier(id) => { + swc_ecma_ast::JSXElementName::Ident(self.ident(&id.name, self.span(&id.base))) + } + react_compiler_ast::jsx::JSXElementName::JSXMemberExpression(m) => { + let member = self.convert_jsx_member_expression(m); + swc_ecma_ast::JSXElementName::JSXMemberExpr(member) + } + react_compiler_ast::jsx::JSXElementName::JSXNamespacedName(ns) => { + let namespace = self.ident_name(&ns.namespace.name, self.span(&ns.namespace.base)); + let name = self.ident_name(&ns.name.name, self.span(&ns.name.base)); + swc_ecma_ast::JSXElementName::JSXNamespacedName(swc_ecma_ast::JSXNamespacedName { + span: DUMMY_SP, + ns: namespace, + name, + }) + } + } + } + + fn convert_jsx_member_expression( + &self, + m: &react_compiler_ast::jsx::JSXMemberExpression, + ) -> swc_ecma_ast::JSXMemberExpr { + let obj = self.convert_jsx_member_expression_object(&m.object); + let prop = self.ident_name(&m.property.name, self.span(&m.property.base)); + 
swc_ecma_ast::JSXMemberExpr { + span: DUMMY_SP, + obj, + prop, + } + } + + fn convert_jsx_member_expression_object( + &self, + obj: &react_compiler_ast::jsx::JSXMemberExprObject, + ) -> swc_ecma_ast::JSXObject { + match obj { + react_compiler_ast::jsx::JSXMemberExprObject::JSXIdentifier(id) => { + swc_ecma_ast::JSXObject::Ident(self.ident(&id.name, self.span(&id.base))) + } + react_compiler_ast::jsx::JSXMemberExprObject::JSXMemberExpression(m) => { + let member = self.convert_jsx_member_expression(m); + swc_ecma_ast::JSXObject::JSXMemberExpr(Box::new(member)) + } + } + } + + fn convert_jsx_attribute_item( + &self, + item: &react_compiler_ast::jsx::JSXAttributeItem, + ) -> swc_ecma_ast::JSXAttrOrSpread { + match item { + react_compiler_ast::jsx::JSXAttributeItem::JSXAttribute(attr) => { + let name = self.convert_jsx_attribute_name(&attr.name); + let value = attr + .value + .as_ref() + .map(|v| self.convert_jsx_attribute_value(v)); + swc_ecma_ast::JSXAttrOrSpread::JSXAttr(swc_ecma_ast::JSXAttr { + span: self.span(&attr.base), + name, + value, + }) + } + react_compiler_ast::jsx::JSXAttributeItem::JSXSpreadAttribute(s) => { + swc_ecma_ast::JSXAttrOrSpread::SpreadElement(SpreadElement { + dot3_token: self.span(&s.base), + expr: Box::new(self.convert_expression(&s.argument)), + }) + } + } + } + + fn convert_jsx_attribute_name( + &self, + name: &react_compiler_ast::jsx::JSXAttributeName, + ) -> swc_ecma_ast::JSXAttrName { + match name { + react_compiler_ast::jsx::JSXAttributeName::JSXIdentifier(id) => { + swc_ecma_ast::JSXAttrName::Ident(self.ident_name(&id.name, self.span(&id.base))) + } + react_compiler_ast::jsx::JSXAttributeName::JSXNamespacedName(ns) => { + let namespace = self.ident_name(&ns.namespace.name, self.span(&ns.namespace.base)); + let name = self.ident_name(&ns.name.name, self.span(&ns.name.base)); + swc_ecma_ast::JSXAttrName::JSXNamespacedName(swc_ecma_ast::JSXNamespacedName { + span: DUMMY_SP, + ns: namespace, + name, + }) + } + } + } + + fn 
convert_jsx_attribute_value( + &self, + value: &react_compiler_ast::jsx::JSXAttributeValue, + ) -> swc_ecma_ast::JSXAttrValue { + match value { + react_compiler_ast::jsx::JSXAttributeValue::StringLiteral(s) => { + // For JSX attributes, if the value contains double quotes, + // use single quotes to avoid escaping issues that prettier + // can't parse (e.g., name="\"user\" name"). + let raw = if s.value.contains('"') { + Some(Atom::from(format!( + "'{}'", + s.value.replace('\\', "\\\\").replace('\'', "\\'") + ))) + } else { + self.escape_string_raw(&s.value) + }; + swc_ecma_ast::JSXAttrValue::Str(Str { + span: self.span(&s.base), + value: self.wtf8(&s.value), + raw, + }) + } + react_compiler_ast::jsx::JSXAttributeValue::JSXExpressionContainer(ec) => { + let expr = self.convert_jsx_expression_container_expr(&ec.expression); + swc_ecma_ast::JSXAttrValue::JSXExprContainer(swc_ecma_ast::JSXExprContainer { + span: self.span(&ec.base), + expr, + }) + } + react_compiler_ast::jsx::JSXAttributeValue::JSXElement(el) => { + let element = self.convert_jsx_element(el.as_ref()); + swc_ecma_ast::JSXAttrValue::JSXElement(Box::new(element)) + } + react_compiler_ast::jsx::JSXAttributeValue::JSXFragment(frag) => { + let fragment = self.convert_jsx_fragment(frag); + swc_ecma_ast::JSXAttrValue::JSXFragment(fragment) + } + } + } + + fn convert_jsx_expression_container_expr( + &self, + expr: &react_compiler_ast::jsx::JSXExpressionContainerExpr, + ) -> swc_ecma_ast::JSXExpr { + match expr { + react_compiler_ast::jsx::JSXExpressionContainerExpr::JSXEmptyExpression(e) => { + swc_ecma_ast::JSXExpr::JSXEmptyExpr(swc_ecma_ast::JSXEmptyExpr { + span: self.span(&e.base), + }) + } + react_compiler_ast::jsx::JSXExpressionContainerExpr::Expression(e) => { + swc_ecma_ast::JSXExpr::Expr(Box::new(self.convert_expression(e))) + } + } + } + + fn convert_jsx_child( + &self, + child: &react_compiler_ast::jsx::JSXChild, + ) -> swc_ecma_ast::JSXElementChild { + match child { + 
react_compiler_ast::jsx::JSXChild::JSXText(t) => { + swc_ecma_ast::JSXElementChild::JSXText(swc_ecma_ast::JSXText { + span: self.span(&t.base), + value: self.atom(&t.value), + raw: self.atom(&t.value), + }) + } + react_compiler_ast::jsx::JSXChild::JSXElement(el) => { + let element = self.convert_jsx_element(el.as_ref()); + swc_ecma_ast::JSXElementChild::JSXElement(Box::new(element)) + } + react_compiler_ast::jsx::JSXChild::JSXFragment(frag) => { + let fragment = self.convert_jsx_fragment(frag); + swc_ecma_ast::JSXElementChild::JSXFragment(fragment) + } + react_compiler_ast::jsx::JSXChild::JSXExpressionContainer(ec) => { + let expr = self.convert_jsx_expression_container_expr(&ec.expression); + swc_ecma_ast::JSXElementChild::JSXExprContainer(swc_ecma_ast::JSXExprContainer { + span: self.span(&ec.base), + expr, + }) + } + react_compiler_ast::jsx::JSXChild::JSXSpreadChild(s) => { + swc_ecma_ast::JSXElementChild::JSXSpreadChild(swc_ecma_ast::JSXSpreadChild { + span: self.span(&s.base), + expr: Box::new(self.convert_expression(&s.expression)), + }) + } + } + } + + fn convert_jsx_fragment( + &self, + frag: &react_compiler_ast::jsx::JSXFragment, + ) -> swc_ecma_ast::JSXFragment { + let children = frag + .children + .iter() + .map(|c| self.convert_jsx_child(c)) + .collect(); + swc_ecma_ast::JSXFragment { + span: self.span(&frag.base), + opening: swc_ecma_ast::JSXOpeningFragment { + span: self.span(&frag.opening_fragment.base), + }, + children, + closing: swc_ecma_ast::JSXClosingFragment { + span: self.span(&frag.closing_fragment.base), + }, + } + } + + // ===== Import/Export ===== + + fn convert_import_declaration(&self, decl: &ImportDeclaration) -> swc_ecma_ast::ImportDecl { + let specifiers = decl + .specifiers + .iter() + .map(|s| self.convert_import_specifier(s)) + .collect(); + let src = Box::new(Str { + span: self.span(&decl.source.base), + value: self.wtf8(&decl.source.value), + raw: None, + }); + let type_only = matches!(decl.import_kind.as_ref(), 
Some(ImportKind::Type)); + swc_ecma_ast::ImportDecl { + span: self.span(&decl.base), + specifiers, + src, + type_only, + with: None, + phase: Default::default(), + } + } + + fn convert_import_specifier( + &self, + spec: &react_compiler_ast::declarations::ImportSpecifier, + ) -> swc_ecma_ast::ImportSpecifier { + match spec { + react_compiler_ast::declarations::ImportSpecifier::ImportSpecifier(s) => { + let local = self.ident(&s.local.name, self.span(&s.local.base)); + // Only set `imported` if it differs from `local` — otherwise + // SWC emits `foo as foo` instead of just `foo`. + let imported_name = match &s.imported { + react_compiler_ast::declarations::ModuleExportName::Identifier(id) => { + Some(&id.name) + } + react_compiler_ast::declarations::ModuleExportName::StringLiteral(_) => None, + }; + let imported = if imported_name == Some(&s.local.name) { + None + } else { + Some(self.convert_module_export_name(&s.imported)) + }; + let is_type_only = matches!(s.import_kind.as_ref(), Some(ImportKind::Type)); + swc_ecma_ast::ImportSpecifier::Named(ImportNamedSpecifier { + span: self.span(&s.base), + local, + imported, + is_type_only, + }) + } + react_compiler_ast::declarations::ImportSpecifier::ImportDefaultSpecifier(s) => { + let local = self.ident(&s.local.name, self.span(&s.local.base)); + swc_ecma_ast::ImportSpecifier::Default(ImportDefaultSpecifier { + span: self.span(&s.base), + local, + }) + } + react_compiler_ast::declarations::ImportSpecifier::ImportNamespaceSpecifier(s) => { + let local = self.ident(&s.local.name, self.span(&s.local.base)); + swc_ecma_ast::ImportSpecifier::Namespace(ImportStarAsSpecifier { + span: self.span(&s.base), + local, + }) + } + } + } + + fn convert_module_export_name( + &self, + name: &react_compiler_ast::declarations::ModuleExportName, + ) -> swc_ecma_ast::ModuleExportName { + match name { + react_compiler_ast::declarations::ModuleExportName::Identifier(id) => { + swc_ecma_ast::ModuleExportName::Ident(self.ident(&id.name, 
self.span(&id.base))) + } + react_compiler_ast::declarations::ModuleExportName::StringLiteral(s) => { + swc_ecma_ast::ModuleExportName::Str(Str { + span: self.span(&s.base), + value: self.wtf8(&s.value), + raw: None, + }) + } + } + } + + fn convert_export_named_to_module_item(&self, decl: &ExportNamedDeclaration) -> ModuleItem { + // If there's a declaration, emit as ExportDecl + if let Some(declaration) = &decl.declaration { + let swc_decl = self.convert_declaration(declaration); + return ModuleItem::ModuleDecl(ModuleDecl::ExportDecl(ExportDecl { + span: self.span(&decl.base), + decl: swc_decl, + })); + } + self.convert_export_named_specifiers(decl) + } + + fn convert_declaration(&self, decl: &react_compiler_ast::declarations::Declaration) -> Decl { + match decl { + react_compiler_ast::declarations::Declaration::FunctionDeclaration(f) => { + Decl::Fn(self.convert_function_declaration(f)) + } + react_compiler_ast::declarations::Declaration::VariableDeclaration(v) => { + Decl::Var(Box::new(self.convert_variable_declaration(v))) + } + react_compiler_ast::declarations::Declaration::ClassDeclaration(c) => { + let ident = + c.id.as_ref() + .map(|id| self.ident(&id.name, self.span(&id.base))) + .unwrap_or_else(|| self.ident("_anonymous", DUMMY_SP)); + let super_class = c + .super_class + .as_ref() + .map(|s| Box::new(self.convert_expression(s))); + Decl::Class(ClassDecl { + ident, + declare: c.declare.unwrap_or(false), + class: Box::new(Class { + span: self.span(&c.base), + ctxt: SyntaxContext::empty(), + decorators: vec![], + body: vec![], + super_class, + is_abstract: false, + type_params: None, + super_type_params: None, + implements: vec![], + }), + }) + } + _ => Decl::Var(Box::new(VarDecl { + span: DUMMY_SP, + ctxt: SyntaxContext::empty(), + kind: VarDeclKind::Const, + declare: true, + decls: vec![], + })), + } + } + + fn convert_export_named_specifiers(&self, decl: &ExportNamedDeclaration) -> ModuleItem { + let specifiers = decl + .specifiers + .iter() + .map(|s| 
self.convert_export_specifier(s)) + .collect(); + let src = decl.source.as_ref().map(|s| { + Box::new(Str { + span: self.span(&s.base), + value: self.wtf8(&s.value), + raw: None, + }) + }); + let type_only = matches!(decl.export_kind.as_ref(), Some(ExportKind::Type)); + + ModuleItem::ModuleDecl(ModuleDecl::ExportNamed(NamedExport { + span: self.span(&decl.base), + specifiers, + src, + type_only, + with: None, + })) + } + + fn convert_export_specifier( + &self, + spec: &react_compiler_ast::declarations::ExportSpecifier, + ) -> swc_ecma_ast::ExportSpecifier { + match spec { + react_compiler_ast::declarations::ExportSpecifier::ExportSpecifier(s) => { + let orig = self.convert_module_export_name(&s.local); + // Only set `exported` if it differs from `local` + let local_name = match &s.local { + react_compiler_ast::declarations::ModuleExportName::Identifier(id) => { + Some(&id.name) + } + _ => None, + }; + let exported_name = match &s.exported { + react_compiler_ast::declarations::ModuleExportName::Identifier(id) => { + Some(&id.name) + } + _ => None, + }; + let exported = if local_name.is_some() && local_name == exported_name { + None + } else { + Some(self.convert_module_export_name(&s.exported)) + }; + let is_type_only = matches!(s.export_kind.as_ref(), Some(ExportKind::Type)); + swc_ecma_ast::ExportSpecifier::Named(ExportNamedSpecifier { + span: self.span(&s.base), + orig, + exported, + is_type_only, + }) + } + react_compiler_ast::declarations::ExportSpecifier::ExportDefaultSpecifier(s) => { + swc_ecma_ast::ExportSpecifier::Default(swc_ecma_ast::ExportDefaultSpecifier { + exported: self.ident(&s.exported.name, self.span(&s.exported.base)), + }) + } + react_compiler_ast::declarations::ExportSpecifier::ExportNamespaceSpecifier(s) => { + let name = self.convert_module_export_name(&s.exported); + swc_ecma_ast::ExportSpecifier::Namespace(ExportNamespaceSpecifier { + span: self.span(&s.base), + name, + }) + } + } + } + + fn convert_export_default_to_module_item(&self, 
decl: &ExportDefaultDeclaration) -> ModuleItem { + let span = self.span(&decl.base); + match &*decl.declaration { + BabelExportDefaultDecl::FunctionDeclaration(f) => { + let fd = self.convert_function_declaration(f); + ModuleItem::ModuleDecl(ModuleDecl::ExportDefaultDecl( + swc_ecma_ast::ExportDefaultDecl { + span, + decl: swc_ecma_ast::DefaultDecl::Fn(FnExpr { + ident: Some(fd.ident), + function: fd.function, + }), + }, + )) + } + BabelExportDefaultDecl::ClassDeclaration(c) => { + let ident = + c.id.as_ref() + .map(|id| self.ident(&id.name, self.span(&id.base))); + let super_class = c + .super_class + .as_ref() + .map(|s| Box::new(self.convert_expression(s))); + ModuleItem::ModuleDecl(ModuleDecl::ExportDefaultDecl( + swc_ecma_ast::ExportDefaultDecl { + span, + decl: swc_ecma_ast::DefaultDecl::Class(ClassExpr { + ident, + class: Box::new(Class { + span, + ctxt: SyntaxContext::empty(), + decorators: vec![], + body: vec![], + super_class, + is_abstract: false, + type_params: None, + super_type_params: None, + implements: vec![], + }), + }), + }, + )) + } + BabelExportDefaultDecl::Expression(e) => { + ModuleItem::ModuleDecl(ModuleDecl::ExportDefaultExpr(ExportDefaultExpr { + span, + expr: Box::new(self.convert_expression(e)), + })) + } + } + } + + fn convert_export_all_declaration( + &self, + decl: &ExportAllDeclaration, + ) -> swc_ecma_ast::ExportAll { + let src = Box::new(Str { + span: self.span(&decl.source.base), + value: self.wtf8(&decl.source.value), + raw: None, + }); + let type_only = matches!(decl.export_kind.as_ref(), Some(ExportKind::Type)); + swc_ecma_ast::ExportAll { + span: self.span(&decl.base), + src, + type_only, + with: None, + } + } + + // ===== TS type helpers ===== + + /// Convert a Babel TSTypeAnnotation JSON to an SWC TsTypeAnnotation. + /// Returns None if the JSON is not a valid type annotation. 
+ fn convert_ts_type_annotation_from_json( + &self, + json: &serde_json::Value, + ) -> Option> { + let type_name = json.get("type")?.as_str()?; + if type_name != "TSTypeAnnotation" && type_name != "TypeAnnotation" { + return None; + } + let type_annotation = json.get("typeAnnotation")?; + let ts_type = self.convert_ts_type_from_json(type_annotation, DUMMY_SP); + Some(Box::new(TsTypeAnn { + span: DUMMY_SP, + type_ann: Box::new(ts_type), + })) + } + + /// Convert a JSON-serialized TypeScript type annotation to an SWC TsType. + /// This handles common cases from the compiler's output. For unrecognized + /// types, it falls back to `any`. + fn convert_ts_type_from_json(&self, json: &serde_json::Value, span: Span) -> TsType { + let type_name = json.get("type").and_then(|v| v.as_str()).unwrap_or(""); + match type_name { + "TSTypeReference" => { + let name = json + .get("typeName") + .and_then(|tn| tn.get("name")) + .and_then(|n| n.as_str()) + .unwrap_or("unknown"); + if name == "const" { + TsType::TsTypeRef(TsTypeRef { + span, + type_name: TsEntityName::Ident(self.ident("const", span)), + type_params: None, + }) + } else { + TsType::TsTypeRef(TsTypeRef { + span, + type_name: TsEntityName::Ident(self.ident(name, span)), + type_params: None, + }) + } + } + "TSNumberKeyword" => TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsNumberKeyword, + }), + "TSStringKeyword" => TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsStringKeyword, + }), + "TSBooleanKeyword" => TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsBooleanKeyword, + }), + "TSVoidKeyword" => TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsVoidKeyword, + }), + "TSNullKeyword" => TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsNullKeyword, + }), + "TSUndefinedKeyword" => TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsUndefinedKeyword, + }), + "TSAnyKeyword" 
=> TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsAnyKeyword, + }), + "TSNeverKeyword" => TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsNeverKeyword, + }), + "TSUnionType" => { + let types = json + .get("types") + .and_then(|t| t.as_array()) + .map(|arr| { + arr.iter() + .map(|t| Box::new(self.convert_ts_type_from_json(t, span))) + .collect::>() + }) + .unwrap_or_default(); + TsType::TsUnionOrIntersectionType(TsUnionOrIntersectionType::TsUnionType( + TsUnionType { span, types }, + )) + } + "TSIntersectionType" => { + let types = json + .get("types") + .and_then(|t| t.as_array()) + .map(|arr| { + arr.iter() + .map(|t| Box::new(self.convert_ts_type_from_json(t, span))) + .collect::>() + }) + .unwrap_or_default(); + TsType::TsUnionOrIntersectionType(TsUnionOrIntersectionType::TsIntersectionType( + TsIntersectionType { span, types }, + )) + } + "TSLiteralType" => { + if let Some(literal) = json.get("literal") { + let lit_type = literal.get("type").and_then(|t| t.as_str()).unwrap_or(""); + match lit_type { + "StringLiteral" => { + let value = literal.get("value").and_then(|v| v.as_str()).unwrap_or(""); + TsType::TsLitType(TsLitType { + span, + lit: TsLit::Str(Str { + span, + value: self.wtf8(value), + raw: None, + }), + }) + } + "NumericLiteral" => { + let value = + literal.get("value").and_then(|v| v.as_f64()).unwrap_or(0.0); + TsType::TsLitType(TsLitType { + span, + lit: TsLit::Number(Number { + span, + value, + raw: None, + }), + }) + } + "BooleanLiteral" => { + let value = literal + .get("value") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + TsType::TsLitType(TsLitType { + span, + lit: TsLit::Bool(Bool { span, value }), + }) + } + _ => TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsAnyKeyword, + }), + } + } else { + TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsAnyKeyword, + }) + } + } + "TSArrayType" => { + let elem = json + .get("elementType") + 
.map(|t| self.convert_ts_type_from_json(t, span)) + .unwrap_or(TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsAnyKeyword, + })); + TsType::TsArrayType(TsArrayType { + span, + elem_type: Box::new(elem), + }) + } + "TSFunctionType" + | "TSTypeLiteral" + | "TSParenthesizedType" + | "TSTupleType" + | "TSOptionalType" + | "TSRestType" + | "TSConditionalType" + | "TSInferType" + | "TSMappedType" + | "TSIndexedAccessType" + | "TSTypeOperator" + | "TSTypePredicate" + | "TSImportType" + | "TSQualifiedName" => { + // For complex types, try to extract from source text + if let (Some(source), Some(start), Some(end)) = ( + self.source_text.as_deref(), + json.get("start").and_then(|v| v.as_u64()), + json.get("end").and_then(|v| v.as_u64()), + ) { + let start_idx = (start as usize).saturating_sub(1); + let end_idx = (end as usize).saturating_sub(1); + if start_idx < source.len() && end_idx <= source.len() && start_idx < end_idx { + let text = &source[start_idx..end_idx]; + // Parse the type using SWC + let wrapper = format!("type __T = {text};"); + let cm = swc_common::sync::Lrc::new(swc_common::SourceMap::default()); + let fm = cm.new_source_file( + swc_common::sync::Lrc::new(swc_common::FileName::Anon), + wrapper, + ); + let mut errors = vec![]; + if let Ok(module) = swc_ecma_parser::parse_file_as_module( + &fm, + swc_ecma_parser::Syntax::Typescript(swc_ecma_parser::TsSyntax { + tsx: true, + ..Default::default() + }), + swc_ecma_ast::EsVersion::latest(), + None, + &mut errors, + ) { + if let Some(ModuleItem::Stmt(Stmt::Decl(Decl::TsTypeAlias(alias)))) = + module.body.into_iter().next() + { + return *alias.type_ann; + } + } + } + } + // Fallback + TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsAnyKeyword, + }) + } + // Flow types + "NumberTypeAnnotation" + | "StringTypeAnnotation" + | "BooleanTypeAnnotation" + | "VoidTypeAnnotation" + | "NullLiteralTypeAnnotation" + | "AnyTypeAnnotation" + | "GenericTypeAnnotation" + | 
"UnionTypeAnnotation" + | "IntersectionTypeAnnotation" + | "NullableTypeAnnotation" + | "FunctionTypeAnnotation" + | "ObjectTypeAnnotation" + | "ArrayTypeAnnotation" + | "TupleTypeAnnotation" + | "TypeofTypeAnnotation" + | "NumberLiteralTypeAnnotation" + | "StringLiteralTypeAnnotation" + | "BooleanLiteralTypeAnnotation" => { + // For Flow types, try to extract from source text + if let (Some(source), Some(start), Some(end)) = ( + self.source_text.as_deref(), + json.get("start").and_then(|v| v.as_u64()), + json.get("end").and_then(|v| v.as_u64()), + ) { + let start_idx = (start as usize).saturating_sub(1); + let end_idx = (end as usize).saturating_sub(1); + if start_idx < source.len() && end_idx <= source.len() && start_idx < end_idx { + let text = &source[start_idx..end_idx]; + // For Flow types, we can use TS parser as many simple types + // have the same syntax + let wrapper = format!("type __T = {text};"); + let cm = swc_common::sync::Lrc::new(swc_common::SourceMap::default()); + let fm = cm.new_source_file( + swc_common::sync::Lrc::new(swc_common::FileName::Anon), + wrapper, + ); + let mut errors = vec![]; + if let Ok(module) = swc_ecma_parser::parse_file_as_module( + &fm, + swc_ecma_parser::Syntax::Typescript(swc_ecma_parser::TsSyntax { + tsx: true, + ..Default::default() + }), + swc_ecma_ast::EsVersion::latest(), + None, + &mut errors, + ) { + if let Some(ModuleItem::Stmt(Stmt::Decl(Decl::TsTypeAlias(alias)))) = + module.body.into_iter().next() + { + return *alias.type_ann; + } + } + } + } + // Fallback + TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsAnyKeyword, + }) + } + _ => { + // Fallback: emit `any` type + TsType::TsKeywordType(TsKeywordType { + span, + kind: TsKeywordTypeKind::TsAnyKeyword, + }) + } + } + } + + // ===== Operators ===== + + fn convert_binary_operator(&self, op: &BinaryOperator) -> BinaryOp { + match op { + BinaryOperator::Add => BinaryOp::Add, + BinaryOperator::Sub => BinaryOp::Sub, + BinaryOperator::Mul => 
BinaryOp::Mul, + BinaryOperator::Div => BinaryOp::Div, + BinaryOperator::Rem => BinaryOp::Mod, + BinaryOperator::Exp => BinaryOp::Exp, + BinaryOperator::Eq => BinaryOp::EqEq, + BinaryOperator::StrictEq => BinaryOp::EqEqEq, + BinaryOperator::Neq => BinaryOp::NotEq, + BinaryOperator::StrictNeq => BinaryOp::NotEqEq, + BinaryOperator::Lt => BinaryOp::Lt, + BinaryOperator::Lte => BinaryOp::LtEq, + BinaryOperator::Gt => BinaryOp::Gt, + BinaryOperator::Gte => BinaryOp::GtEq, + BinaryOperator::Shl => BinaryOp::LShift, + BinaryOperator::Shr => BinaryOp::RShift, + BinaryOperator::UShr => BinaryOp::ZeroFillRShift, + BinaryOperator::BitOr => BinaryOp::BitOr, + BinaryOperator::BitXor => BinaryOp::BitXor, + BinaryOperator::BitAnd => BinaryOp::BitAnd, + BinaryOperator::In => BinaryOp::In, + BinaryOperator::Instanceof => BinaryOp::InstanceOf, + BinaryOperator::Pipeline => BinaryOp::BitOr, // no pipeline in SWC + } + } + + fn convert_logical_operator(&self, op: &LogicalOperator) -> BinaryOp { + match op { + LogicalOperator::Or => BinaryOp::LogicalOr, + LogicalOperator::And => BinaryOp::LogicalAnd, + LogicalOperator::NullishCoalescing => BinaryOp::NullishCoalescing, + } + } + + fn convert_unary_operator(&self, op: &UnaryOperator) -> UnaryOp { + match op { + UnaryOperator::Neg => UnaryOp::Minus, + UnaryOperator::Plus => UnaryOp::Plus, + UnaryOperator::Not => UnaryOp::Bang, + UnaryOperator::BitNot => UnaryOp::Tilde, + UnaryOperator::TypeOf => UnaryOp::TypeOf, + UnaryOperator::Void => UnaryOp::Void, + UnaryOperator::Delete => UnaryOp::Delete, + UnaryOperator::Throw => UnaryOp::Void, // no throw-as-unary in SWC + } + } + + fn convert_update_operator(&self, op: &UpdateOperator) -> UpdateOp { + match op { + UpdateOperator::Increment => UpdateOp::PlusPlus, + UpdateOperator::Decrement => UpdateOp::MinusMinus, + } + } + + fn convert_assignment_operator(&self, op: &AssignmentOperator) -> AssignOp { + match op { + AssignmentOperator::Assign => AssignOp::Assign, + AssignmentOperator::AddAssign 
=> AssignOp::AddAssign, + AssignmentOperator::SubAssign => AssignOp::SubAssign, + AssignmentOperator::MulAssign => AssignOp::MulAssign, + AssignmentOperator::DivAssign => AssignOp::DivAssign, + AssignmentOperator::RemAssign => AssignOp::ModAssign, + AssignmentOperator::ExpAssign => AssignOp::ExpAssign, + AssignmentOperator::ShlAssign => AssignOp::LShiftAssign, + AssignmentOperator::ShrAssign => AssignOp::RShiftAssign, + AssignmentOperator::UShrAssign => AssignOp::ZeroFillRShiftAssign, + AssignmentOperator::BitOrAssign => AssignOp::BitOrAssign, + AssignmentOperator::BitXorAssign => AssignOp::BitXorAssign, + AssignmentOperator::BitAndAssign => AssignOp::BitAndAssign, + AssignmentOperator::OrAssign => AssignOp::OrAssign, + AssignmentOperator::AndAssign => AssignOp::AndAssign, + AssignmentOperator::NullishAssign => AssignOp::NullishAssign, + } + } +} diff --git a/crates/swc_ecma_react_compiler/src/convert_scope.rs b/crates/swc_ecma_react_compiler/src/convert_scope.rs new file mode 100644 index 000000000000..4e6ce97c26f0 --- /dev/null +++ b/crates/swc_ecma_react_compiler/src/convert_scope.rs @@ -0,0 +1,948 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use std::collections::{HashMap, HashSet}; + +use indexmap::IndexMap; +use react_compiler_ast::scope::*; +use swc_ecma_ast::*; +use swc_ecma_visit::{Visit, VisitWith}; + +/// Helper to convert an SWC `Str` node's value to a Rust String. +/// `Str.value` is a `Wtf8Atom` which doesn't implement `Display`, +/// so we go through `Atom` via lossy conversion. +fn str_value_to_string(s: &Str) -> String { + s.value.to_atom_lossy().to_string() +} + +/// Build scope information from an SWC Module AST. +/// +/// This performs two passes over the AST: +/// 1. Build the scope tree and collect all bindings +/// 2. 
Resolve identifier references to their bindings +pub fn build_scope_info(module: &Module) -> ScopeInfo { + // Pass 1: Build scope tree and collect bindings + let mut collector = ScopeCollector::new(); + collector.visit_module(module); + + // Pass 2: Resolve references + // We scope the resolver borrow so we can move out of collector afterwards. + let reference_to_binding = { + let mut resolver = ReferenceResolver::new(&collector); + resolver.visit_module(module); + + // Also map declaration identifiers to their bindings + for binding in &collector.bindings { + if let Some(start) = binding.declaration_start { + resolver + .reference_to_binding + .entry(start) + .or_insert(binding.id); + } + } + + resolver.reference_to_binding + }; + + ScopeInfo { + scopes: collector.scopes, + bindings: collector.bindings, + node_to_scope: collector.node_to_scope, + reference_to_binding, + program_scope: ScopeId(0), + } +} + +// ── Pass 1: Scope tree + binding collection ───────────────────────────────── + +struct ScopeCollector { + scopes: Vec, + bindings: Vec, + node_to_scope: HashMap, + /// Stack of scope IDs representing the current nesting. + scope_stack: Vec, + /// Set of span starts for block statements that are direct function/catch + /// bodies. These should NOT create a separate Block scope. 
+ function_body_spans: HashSet, +} + +impl ScopeCollector { + fn new() -> Self { + Self { + scopes: Vec::new(), + bindings: Vec::new(), + node_to_scope: HashMap::new(), + scope_stack: Vec::new(), + function_body_spans: HashSet::new(), + } + } + + fn current_scope(&self) -> ScopeId { + *self.scope_stack.last().expect("scope stack is empty") + } + + fn push_scope(&mut self, kind: ScopeKind, node_start: u32) -> ScopeId { + let id = ScopeId(self.scopes.len() as u32); + let parent = self.scope_stack.last().copied(); + self.scopes.push(ScopeData { + id, + parent, + kind, + bindings: HashMap::new(), + }); + self.node_to_scope.insert(node_start, id); + self.scope_stack.push(id); + id + } + + fn pop_scope(&mut self) { + self.scope_stack.pop(); + } + + /// Find the nearest enclosing function or program scope (for hoisting `var` + /// and function decls). + fn enclosing_function_scope(&self) -> ScopeId { + for &scope_id in self.scope_stack.iter().rev() { + let scope = &self.scopes[scope_id.0 as usize]; + match scope.kind { + ScopeKind::Function | ScopeKind::Program => return scope_id, + _ => {} + } + } + ScopeId(0) + } + + fn add_binding( + &mut self, + name: String, + kind: BindingKind, + scope: ScopeId, + declaration_type: String, + declaration_start: Option, + import: Option, + ) -> BindingId { + let id = BindingId(self.bindings.len() as u32); + self.bindings.push(BindingData { + id, + name: name.clone(), + kind, + scope, + declaration_type, + declaration_start, + import, + }); + self.scopes[scope.0 as usize].bindings.insert(name, id); + id + } + + /// Extract all binding identifiers from a pattern, adding each as a + /// binding. 
+ fn collect_pat_bindings( + &mut self, + pat: &Pat, + kind: BindingKind, + scope: ScopeId, + declaration_type: &str, + ) { + match pat { + Pat::Ident(binding_ident) => { + let name = binding_ident.id.sym.to_string(); + let start = binding_ident.id.span.lo.0; + self.add_binding( + name, + kind, + scope, + declaration_type.to_string(), + Some(start), + None, + ); + } + Pat::Array(arr) => { + for p in arr.elems.iter().flatten() { + self.collect_pat_bindings(p, kind.clone(), scope, declaration_type); + } + } + Pat::Object(obj) => { + for prop in &obj.props { + match prop { + ObjectPatProp::KeyValue(kv) => { + self.collect_pat_bindings( + &kv.value, + kind.clone(), + scope, + declaration_type, + ); + } + ObjectPatProp::Assign(assign) => { + let name = assign.key.sym.to_string(); + let start = assign.key.span.lo.0; + self.add_binding( + name, + kind.clone(), + scope, + declaration_type.to_string(), + Some(start), + None, + ); + } + ObjectPatProp::Rest(rest) => { + self.collect_pat_bindings( + &rest.arg, + kind.clone(), + scope, + declaration_type, + ); + } + } + } + } + Pat::Rest(rest) => { + self.collect_pat_bindings(&rest.arg, kind, scope, declaration_type); + } + Pat::Assign(assign) => { + self.collect_pat_bindings(&assign.left, kind, scope, declaration_type); + } + Pat::Expr(_) | Pat::Invalid(_) => {} + } + } + + /// Visit a function's internals (params + body), creating the function + /// scope. Used for method definitions and other Function nodes not + /// covered by FnDecl/FnExpr. 
+ fn visit_function_inner(&mut self, function: &Function) { + let func_start = function.span.lo.0; + self.push_scope(ScopeKind::Function, func_start); + + for param in &function.params { + self.collect_pat_bindings( + ¶m.pat, + BindingKind::Param, + self.current_scope(), + "FormalParameter", + ); + } + + if let Some(body) = &function.body { + self.function_body_spans.insert(body.span.lo.0); + body.visit_with(self); + } + + self.pop_scope(); + } +} + +impl Visit for ScopeCollector { + fn visit_module(&mut self, module: &Module) { + self.push_scope(ScopeKind::Program, module.span.lo.0); + module.visit_children_with(self); + self.pop_scope(); + } + + fn visit_import_decl(&mut self, import: &ImportDecl) { + let source = str_value_to_string(&import.src); + let program_scope = ScopeId(0); + + for spec in &import.specifiers { + match spec { + ImportSpecifier::Named(named) => { + let local_name = named.local.sym.to_string(); + let start = named.local.span.lo.0; + let imported_name = match &named.imported { + Some(ModuleExportName::Ident(ident)) => Some(ident.sym.to_string()), + Some(ModuleExportName::Str(s)) => Some(str_value_to_string(s)), + None => Some(local_name.clone()), + }; + self.add_binding( + local_name, + BindingKind::Module, + program_scope, + "ImportSpecifier".to_string(), + Some(start), + Some(ImportBindingData { + source: source.clone(), + kind: ImportBindingKind::Named, + imported: imported_name, + }), + ); + } + ImportSpecifier::Default(default) => { + let local_name = default.local.sym.to_string(); + let start = default.local.span.lo.0; + self.add_binding( + local_name, + BindingKind::Module, + program_scope, + "ImportDefaultSpecifier".to_string(), + Some(start), + Some(ImportBindingData { + source: source.clone(), + kind: ImportBindingKind::Default, + imported: None, + }), + ); + } + ImportSpecifier::Namespace(ns) => { + let local_name = ns.local.sym.to_string(); + let start = ns.local.span.lo.0; + self.add_binding( + local_name, + BindingKind::Module, + 
program_scope, + "ImportNamespaceSpecifier".to_string(), + Some(start), + Some(ImportBindingData { + source: source.clone(), + kind: ImportBindingKind::Namespace, + imported: None, + }), + ); + } + } + } + } + + fn visit_var_decl(&mut self, var_decl: &VarDecl) { + let (kind, declaration_type) = match var_decl.kind { + VarDeclKind::Var => (BindingKind::Var, "VariableDeclarator"), + VarDeclKind::Let => (BindingKind::Let, "VariableDeclarator"), + VarDeclKind::Const => (BindingKind::Const, "VariableDeclarator"), + }; + + let target_scope = match var_decl.kind { + VarDeclKind::Var => self.enclosing_function_scope(), + VarDeclKind::Let | VarDeclKind::Const => self.current_scope(), + }; + + for declarator in &var_decl.decls { + self.collect_pat_bindings( + &declarator.name, + kind.clone(), + target_scope, + declaration_type, + ); + // Visit initializers so nested functions/arrows get their scopes + if let Some(init) = &declarator.init { + init.visit_with(self); + } + } + } + + fn visit_fn_decl(&mut self, fn_decl: &FnDecl) { + // Function declarations are hoisted to the enclosing function/program scope + let hoist_scope = self.enclosing_function_scope(); + let name = fn_decl.ident.sym.to_string(); + let start = fn_decl.ident.span.lo.0; + self.add_binding( + name, + BindingKind::Hoisted, + hoist_scope, + "FunctionDeclaration".to_string(), + Some(start), + None, + ); + + self.visit_function_inner(&fn_decl.function); + } + + fn visit_export_default_decl(&mut self, decl: &ExportDefaultDecl) { + // For `export default function foo(...)`, the function name should be + // hoisted to the enclosing scope (like FnDecl), not bound only in the + // function's own scope (like FnExpr). 
+ match &decl.decl { + DefaultDecl::Fn(fn_expr) => { + if let Some(ident) = &fn_expr.ident { + let hoist_scope = self.enclosing_function_scope(); + let name = ident.sym.to_string(); + let start = ident.span.lo.0; + self.add_binding( + name, + BindingKind::Hoisted, + hoist_scope, + "FunctionDeclaration".to_string(), + Some(start), + None, + ); + } + self.visit_function_inner(&fn_expr.function); + } + DefaultDecl::Class(class_expr) => { + if let Some(ident) = &class_expr.ident { + let name = ident.sym.to_string(); + let start = ident.span.lo.0; + self.add_binding( + name, + BindingKind::Local, + self.current_scope(), + "ClassDeclaration".to_string(), + Some(start), + None, + ); + } + self.push_scope(ScopeKind::Class, class_expr.class.span.lo.0); + class_expr.class.visit_children_with(self); + self.pop_scope(); + } + DefaultDecl::TsInterfaceDecl(d) => { + d.visit_with(self); + } + } + } + + fn visit_fn_expr(&mut self, fn_expr: &FnExpr) { + let func_start = fn_expr.function.span.lo.0; + self.push_scope(ScopeKind::Function, func_start); + + // Named function expressions bind their name in the function scope + if let Some(ident) = &fn_expr.ident { + let name = ident.sym.to_string(); + let start = ident.span.lo.0; + self.add_binding( + name, + BindingKind::Local, + self.current_scope(), + "FunctionExpression".to_string(), + Some(start), + None, + ); + } + + for param in &fn_expr.function.params { + self.collect_pat_bindings( + ¶m.pat, + BindingKind::Param, + self.current_scope(), + "FormalParameter", + ); + } + + if let Some(body) = &fn_expr.function.body { + self.function_body_spans.insert(body.span.lo.0); + body.visit_with(self); + } + + self.pop_scope(); + } + + fn visit_arrow_expr(&mut self, arrow: &ArrowExpr) { + let func_start = arrow.span.lo.0; + self.push_scope(ScopeKind::Function, func_start); + + for param in &arrow.params { + self.collect_pat_bindings( + param, + BindingKind::Param, + self.current_scope(), + "FormalParameter", + ); + } + + match &*arrow.body { 
+ BlockStmtOrExpr::BlockStmt(block) => { + self.function_body_spans.insert(block.span.lo.0); + block.visit_with(self); + } + BlockStmtOrExpr::Expr(expr) => { + expr.visit_with(self); + } + } + + self.pop_scope(); + } + + fn visit_block_stmt(&mut self, block: &BlockStmt) { + if self.function_body_spans.remove(&block.span.lo.0) { + // This block is a function/catch body — don't create a separate scope + block.visit_children_with(self); + } else { + self.push_scope(ScopeKind::Block, block.span.lo.0); + block.visit_children_with(self); + self.pop_scope(); + } + } + + fn visit_for_stmt(&mut self, for_stmt: &ForStmt) { + self.push_scope(ScopeKind::For, for_stmt.span.lo.0); + + if let Some(init) = &for_stmt.init { + init.visit_with(self); + } + if let Some(test) = &for_stmt.test { + test.visit_with(self); + } + if let Some(update) = &for_stmt.update { + update.visit_with(self); + } + for_stmt.body.visit_with(self); + + self.pop_scope(); + } + + fn visit_for_in_stmt(&mut self, for_in: &ForInStmt) { + self.push_scope(ScopeKind::For, for_in.span.lo.0); + for_in.left.visit_with(self); + for_in.right.visit_with(self); + for_in.body.visit_with(self); + self.pop_scope(); + } + + fn visit_for_of_stmt(&mut self, for_of: &ForOfStmt) { + self.push_scope(ScopeKind::For, for_of.span.lo.0); + for_of.left.visit_with(self); + for_of.right.visit_with(self); + for_of.body.visit_with(self); + self.pop_scope(); + } + + fn visit_catch_clause(&mut self, catch: &CatchClause) { + self.push_scope(ScopeKind::Catch, catch.span.lo.0); + + if let Some(param) = &catch.param { + self.collect_pat_bindings(param, BindingKind::Let, self.current_scope(), "CatchClause"); + } + + // Mark catch body as already scoped (the catch scope covers it) + self.function_body_spans.insert(catch.body.span.lo.0); + catch.body.visit_with(self); + + self.pop_scope(); + } + + fn visit_switch_stmt(&mut self, switch: &SwitchStmt) { + // Visit the discriminant in the outer scope + switch.discriminant.visit_with(self); + + 
self.push_scope(ScopeKind::Switch, switch.span.lo.0); + for case in &switch.cases { + case.visit_with(self); + } + self.pop_scope(); + } + + fn visit_class_decl(&mut self, class_decl: &ClassDecl) { + let name = class_decl.ident.sym.to_string(); + let start = class_decl.ident.span.lo.0; + self.add_binding( + name, + BindingKind::Local, + self.current_scope(), + "ClassDeclaration".to_string(), + Some(start), + None, + ); + + self.push_scope(ScopeKind::Class, class_decl.class.span.lo.0); + class_decl.class.visit_children_with(self); + self.pop_scope(); + } + + fn visit_class_expr(&mut self, class_expr: &ClassExpr) { + self.push_scope(ScopeKind::Class, class_expr.class.span.lo.0); + + if let Some(ident) = &class_expr.ident { + let name = ident.sym.to_string(); + let start = ident.span.lo.0; + self.add_binding( + name, + BindingKind::Local, + self.current_scope(), + "ClassExpression".to_string(), + Some(start), + None, + ); + } + + class_expr.class.visit_children_with(self); + self.pop_scope(); + } + + // Method definitions contain a Function node. We intercept here + // so that the Function gets its own scope with params. + fn visit_function(&mut self, f: &Function) { + // This is reached for object/class methods via default traversal. + self.visit_function_inner(f); + } +} + +// ── Pass 2: Reference resolution ──────────────────────────────────────────── + +struct ReferenceResolver<'a> { + scopes: &'a [ScopeData], + #[allow(dead_code)] + bindings: &'a [BindingData], + node_to_scope: &'a HashMap, + reference_to_binding: IndexMap, + /// Stack of scope IDs for resolution + scope_stack: Vec, + /// Declaration positions to skip (these are binding sites, not references) + declaration_starts: HashSet, + /// Span starts for block statements that are direct function/catch bodies. 
+ function_body_spans: HashSet, +} + +impl<'a> ReferenceResolver<'a> { + fn new(collector: &'a ScopeCollector) -> Self { + let mut declaration_starts = HashSet::new(); + for binding in &collector.bindings { + if let Some(start) = binding.declaration_start { + declaration_starts.insert(start); + } + } + Self { + scopes: &collector.scopes, + bindings: &collector.bindings, + node_to_scope: &collector.node_to_scope, + reference_to_binding: IndexMap::new(), + scope_stack: Vec::new(), + declaration_starts, + function_body_spans: HashSet::new(), + } + } + + fn current_scope(&self) -> ScopeId { + *self.scope_stack.last().expect("scope stack is empty") + } + + fn resolve_ident(&mut self, name: &str, start: u32) { + // Skip declaration sites — they'll be added separately + if self.declaration_starts.contains(&start) { + return; + } + + // Walk up the scope chain to find the binding + let mut current = Some(self.current_scope()); + while let Some(scope_id) = current { + let scope = &self.scopes[scope_id.0 as usize]; + if let Some(&binding_id) = scope.bindings.get(name) { + self.reference_to_binding.insert(start, binding_id); + return; + } + current = scope.parent; + } + // Not found — it's a global, don't record it + } + + fn find_scope_at(&self, node_start: u32) -> Option<&ScopeId> { + self.node_to_scope.get(&node_start) + } + + /// Visit a pattern in parameter position: skip binding idents, but visit + /// default values and computed keys as references. 
+ fn visit_param_pattern(&mut self, pat: &Pat) { + match pat { + Pat::Ident(_) => { + // Declaration — skip + } + Pat::Array(arr) => { + for p in arr.elems.iter().flatten() { + self.visit_param_pattern(p); + } + } + Pat::Object(obj) => { + for prop in &obj.props { + match prop { + ObjectPatProp::KeyValue(kv) => { + if let PropName::Computed(computed) = &kv.key { + computed.visit_with(self); + } + self.visit_param_pattern(&kv.value); + } + ObjectPatProp::Assign(assign) => { + if let Some(value) = &assign.value { + value.visit_with(self); + } + } + ObjectPatProp::Rest(rest) => { + self.visit_param_pattern(&rest.arg); + } + } + } + } + Pat::Assign(assign) => { + self.visit_param_pattern(&assign.left); + // Default value IS a reference + assign.right.visit_with(self); + } + Pat::Rest(rest) => { + self.visit_param_pattern(&rest.arg); + } + Pat::Expr(expr) => { + expr.visit_with(self); + } + Pat::Invalid(_) => {} + } + } + + /// Visit function internals for the resolver (params + body), mirroring the + /// collector. 
+ fn visit_function_inner(&mut self, function: &Function) { + let func_start = function.span.lo.0; + if let Some(&scope_id) = self.find_scope_at(func_start) { + self.scope_stack.push(scope_id); + + for param in &function.params { + self.visit_param_pattern(¶m.pat); + } + + if let Some(body) = &function.body { + self.function_body_spans.insert(body.span.lo.0); + body.visit_with(self); + } + + self.scope_stack.pop(); + } + } +} + +impl<'a> Visit for ReferenceResolver<'a> { + fn visit_module(&mut self, module: &Module) { + self.scope_stack.push(ScopeId(0)); + module.visit_children_with(self); + self.scope_stack.pop(); + } + + fn visit_ident(&mut self, ident: &Ident) { + let name = ident.sym.to_string(); + let start = ident.span.lo.0; + self.resolve_ident(&name, start); + } + + fn visit_import_decl(&mut self, _import: &ImportDecl) { + // Don't recurse — import identifiers are declarations + } + + fn visit_var_decl(&mut self, var_decl: &VarDecl) { + // Only visit initializers, not patterns (which are declarations) + for declarator in &var_decl.decls { + if let Some(init) = &declarator.init { + init.visit_with(self); + } + } + } + + fn visit_fn_decl(&mut self, fn_decl: &FnDecl) { + // Don't resolve the function name — it's a declaration + self.visit_function_inner(&fn_decl.function); + } + + fn visit_export_default_decl(&mut self, decl: &ExportDefaultDecl) { + // Mirror the collector: handle exported functions/classes with their own + // scope logic, rather than falling through to the default FnExpr visitor. 
+ match &decl.decl { + DefaultDecl::Fn(fn_expr) => { + // Don't resolve the function name — it's a declaration + self.visit_function_inner(&fn_expr.function); + } + DefaultDecl::Class(class_expr) => { + if let Some(&scope_id) = self.find_scope_at(class_expr.class.span.lo.0) { + self.scope_stack.push(scope_id); + class_expr.class.visit_children_with(self); + self.scope_stack.pop(); + } + } + DefaultDecl::TsInterfaceDecl(d) => { + d.visit_with(self); + } + } + } + + fn visit_fn_expr(&mut self, fn_expr: &FnExpr) { + let func_start = fn_expr.function.span.lo.0; + if let Some(&scope_id) = self.find_scope_at(func_start) { + self.scope_stack.push(scope_id); + + // Don't resolve named fn expr ident — it's a declaration + + for param in &fn_expr.function.params { + self.visit_param_pattern(¶m.pat); + } + + if let Some(body) = &fn_expr.function.body { + self.function_body_spans.insert(body.span.lo.0); + body.visit_with(self); + } + + self.scope_stack.pop(); + } + } + + fn visit_arrow_expr(&mut self, arrow: &ArrowExpr) { + let func_start = arrow.span.lo.0; + if let Some(&scope_id) = self.find_scope_at(func_start) { + self.scope_stack.push(scope_id); + + for param in &arrow.params { + self.visit_param_pattern(param); + } + + match &*arrow.body { + BlockStmtOrExpr::BlockStmt(block) => { + self.function_body_spans.insert(block.span.lo.0); + block.visit_with(self); + } + BlockStmtOrExpr::Expr(expr) => { + expr.visit_with(self); + } + } + + self.scope_stack.pop(); + } + } + + fn visit_block_stmt(&mut self, block: &BlockStmt) { + if self.function_body_spans.remove(&block.span.lo.0) { + // Function/catch body — scope already pushed + block.visit_children_with(self); + } else if let Some(&scope_id) = self.find_scope_at(block.span.lo.0) { + self.scope_stack.push(scope_id); + block.visit_children_with(self); + self.scope_stack.pop(); + } else { + block.visit_children_with(self); + } + } + + fn visit_for_stmt(&mut self, for_stmt: &ForStmt) { + if let Some(&scope_id) = 
self.find_scope_at(for_stmt.span.lo.0) { + self.scope_stack.push(scope_id); + + if let Some(init) = &for_stmt.init { + init.visit_with(self); + } + if let Some(test) = &for_stmt.test { + test.visit_with(self); + } + if let Some(update) = &for_stmt.update { + update.visit_with(self); + } + for_stmt.body.visit_with(self); + + self.scope_stack.pop(); + } + } + + fn visit_for_in_stmt(&mut self, for_in: &ForInStmt) { + if let Some(&scope_id) = self.find_scope_at(for_in.span.lo.0) { + self.scope_stack.push(scope_id); + for_in.left.visit_with(self); + for_in.right.visit_with(self); + for_in.body.visit_with(self); + self.scope_stack.pop(); + } + } + + fn visit_for_of_stmt(&mut self, for_of: &ForOfStmt) { + if let Some(&scope_id) = self.find_scope_at(for_of.span.lo.0) { + self.scope_stack.push(scope_id); + for_of.left.visit_with(self); + for_of.right.visit_with(self); + for_of.body.visit_with(self); + self.scope_stack.pop(); + } + } + + fn visit_catch_clause(&mut self, catch: &CatchClause) { + if let Some(&scope_id) = self.find_scope_at(catch.span.lo.0) { + self.scope_stack.push(scope_id); + // Don't visit catch param — it's a declaration + self.function_body_spans.insert(catch.body.span.lo.0); + catch.body.visit_with(self); + self.scope_stack.pop(); + } + } + + fn visit_switch_stmt(&mut self, switch: &SwitchStmt) { + switch.discriminant.visit_with(self); + + if let Some(&scope_id) = self.find_scope_at(switch.span.lo.0) { + self.scope_stack.push(scope_id); + for case in &switch.cases { + case.visit_with(self); + } + self.scope_stack.pop(); + } + } + + fn visit_class_decl(&mut self, class_decl: &ClassDecl) { + // Don't resolve the class name — it's a declaration + if let Some(&scope_id) = self.find_scope_at(class_decl.class.span.lo.0) { + self.scope_stack.push(scope_id); + class_decl.class.visit_children_with(self); + self.scope_stack.pop(); + } + } + + fn visit_class_expr(&mut self, class_expr: &ClassExpr) { + if let Some(&scope_id) = 
self.find_scope_at(class_expr.class.span.lo.0) { + self.scope_stack.push(scope_id); + // Don't resolve named class expr ident — it's a declaration + class_expr.class.visit_children_with(self); + self.scope_stack.pop(); + } + } + + fn visit_function(&mut self, f: &Function) { + // Reached for object/class methods via default traversal + self.visit_function_inner(f); + } + + // Don't resolve property idents on member expressions as references + fn visit_member_expr(&mut self, member: &MemberExpr) { + member.obj.visit_with(self); + if let MemberProp::Computed(computed) = &member.prop { + computed.visit_with(self); + } + } + + // Handle property definitions — don't resolve non-computed keys + fn visit_prop(&mut self, prop: &Prop) { + match prop { + Prop::Shorthand(ident) => { + // Shorthand property `{ x }` — `x` is a reference + self.visit_ident(ident); + } + Prop::KeyValue(kv) => { + if let PropName::Computed(computed) = &kv.key { + computed.visit_with(self); + } + kv.value.visit_with(self); + } + Prop::Assign(assign) => { + assign.value.visit_with(self); + } + Prop::Getter(getter) => { + if let PropName::Computed(computed) = &getter.key { + computed.visit_with(self); + } + if let Some(body) = &getter.body { + body.visit_with(self); + } + } + Prop::Setter(setter) => { + if let PropName::Computed(computed) = &setter.key { + computed.visit_with(self); + } + setter.param.visit_with(self); + if let Some(body) = &setter.body { + body.visit_with(self); + } + } + Prop::Method(method) => { + if let PropName::Computed(computed) = &method.key { + computed.visit_with(self); + } + method.function.visit_with(self); + } + } + } + + // Don't resolve labels + fn visit_labeled_stmt(&mut self, labeled: &LabeledStmt) { + labeled.body.visit_with(self); + } + + fn visit_break_stmt(&mut self, _break_stmt: &BreakStmt) {} + + fn visit_continue_stmt(&mut self, _continue_stmt: &ContinueStmt) {} +} diff --git a/crates/swc_ecma_react_compiler/src/diagnostics.rs 
b/crates/swc_ecma_react_compiler/src/diagnostics.rs new file mode 100644 index 000000000000..9a038f600f4a --- /dev/null +++ b/crates/swc_ecma_react_compiler/src/diagnostics.rs @@ -0,0 +1,109 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use react_compiler::entrypoint::compile_result::{ + CompileResult, CompilerErrorDetailInfo, CompilerErrorInfo, LoggerEvent, +}; + +#[derive(Debug, Clone)] +pub enum Severity { + Error, + Warning, +} + +#[derive(Debug, Clone)] +pub struct DiagnosticMessage { + pub severity: Severity, + pub message: String, + pub span: Option<(u32, u32)>, +} + +/// Converts a CompileResult into diagnostic messages for display +pub fn compile_result_to_diagnostics(result: &CompileResult) -> Vec { + let mut diagnostics = Vec::new(); + + match result { + CompileResult::Success { events, .. } => { + // Process logger events from successful compilation + for event in events { + if let Some(diag) = event_to_diagnostic(event) { + diagnostics.push(diag); + } + } + } + CompileResult::Error { error, events, .. } => { + // Add the main error + diagnostics.push(error_info_to_diagnostic(error)); + + // Process logger events from failed compilation + for event in events { + if let Some(diag) = event_to_diagnostic(event) { + diagnostics.push(diag); + } + } + } + } + + diagnostics +} + +fn error_info_to_diagnostic(error: &CompilerErrorInfo) -> DiagnosticMessage { + let message = if let Some(description) = &error.description { + format!("[ReactCompiler] {}. 
{}", error.reason, description) + } else { + format!("[ReactCompiler] {}", error.reason) + }; + + DiagnosticMessage { + severity: Severity::Error, + message, + span: None, + } +} + +fn error_detail_to_diagnostic( + detail: &CompilerErrorDetailInfo, + is_error: bool, +) -> DiagnosticMessage { + let message = if let Some(description) = &detail.description { + format!( + "[ReactCompiler] {}: {}. {}", + detail.category, detail.reason, description + ) + } else { + format!("[ReactCompiler] {}: {}", detail.category, detail.reason) + }; + + DiagnosticMessage { + severity: if is_error { + Severity::Error + } else { + Severity::Warning + }, + message, + span: None, + } +} + +fn event_to_diagnostic(event: &LoggerEvent) -> Option { + match event { + LoggerEvent::CompileSuccess { .. } => None, + LoggerEvent::CompileSkip { .. } => None, + LoggerEvent::CompileError { detail, .. } + | LoggerEvent::CompileErrorWithLoc { detail, .. } => { + Some(error_detail_to_diagnostic(detail, false)) + } + LoggerEvent::CompileUnexpectedThrow { data, .. } => Some(DiagnosticMessage { + severity: Severity::Error, + message: format!("[ReactCompiler] Unexpected error: {data}"), + span: None, + }), + LoggerEvent::PipelineError { data, .. 
} => Some(DiagnosticMessage { + severity: Severity::Error, + message: format!("[ReactCompiler] Pipeline error: {data}"), + span: None, + }), + } +} diff --git a/crates/swc_ecma_react_compiler/src/fast_check.rs b/crates/swc_ecma_react_compiler/src/fast_check.rs index 31b59d25a9cf..a14215908472 100644 --- a/crates/swc_ecma_react_compiler/src/fast_check.rs +++ b/crates/swc_ecma_react_compiler/src/fast_check.rs @@ -1,125 +1,16 @@ -use swc_ecma_ast::{ - Callee, ExportDefaultDecl, ExportDefaultExpr, Expr, FnDecl, FnExpr, Pat, Program, Stmt, - VarDeclarator, -}; -use swc_ecma_visit::{Visit, VisitWith}; -pub fn is_required(program: &Program) -> bool { - let mut finder = Finder::default(); - finder.visit_program(program); - finder.found -} - -#[derive(Default)] -struct Finder { - found: bool, - - /// We are in a function that starts with a capital letter or it's a - /// function that starts with `use` - is_interested: bool, -} - -impl Visit for Finder { - fn visit_callee(&mut self, node: &Callee) { - if self.is_interested { - if let Callee::Expr(e) = node { - if let Expr::Ident(c) = &**e { - if c.sym.starts_with("use") { - self.found = true; - return; - } - } - } - } - - node.visit_children_with(self); - } - - fn visit_export_default_decl(&mut self, node: &ExportDefaultDecl) { - let old = self.is_interested; - - self.is_interested = true; +use swc_ecma_ast::{Module, ModuleItem, Program}; - node.visit_children_with(self); - - self.is_interested = old; - } - - fn visit_export_default_expr(&mut self, node: &ExportDefaultExpr) { - let old = self.is_interested; - - self.is_interested = true; - - node.visit_children_with(self); - - self.is_interested = old; - } - - fn visit_expr(&mut self, node: &Expr) { - if self.found { - return; - } - if self.is_interested - && matches!( - node, - Expr::JSXMember(..) - | Expr::JSXNamespacedName(..) - | Expr::JSXEmpty(..) - | Expr::JSXElement(..) - | Expr::JSXFragment(..) 
- ) - { - self.found = true; - return; - } - - node.visit_children_with(self); - } - - fn visit_fn_decl(&mut self, node: &FnDecl) { - let old = self.is_interested; - - self.is_interested = node.ident.sym.starts_with("use") - || node.ident.sym.starts_with(|c: char| c.is_ascii_uppercase()); - - node.visit_children_with(self); - - self.is_interested = old; - } - - fn visit_fn_expr(&mut self, node: &FnExpr) { - let old = self.is_interested; - - self.is_interested |= node.ident.as_ref().is_some_and(|ident| { - ident.sym.starts_with("use") || ident.sym.starts_with(|c: char| c.is_ascii_uppercase()) - }); - - node.visit_children_with(self); - - self.is_interested = old; - } - - fn visit_stmt(&mut self, node: &Stmt) { - if self.found { - return; - } - node.visit_children_with(self); - } - - fn visit_var_declarator(&mut self, node: &VarDeclarator) { - let old = self.is_interested; - - if matches!(node.init.as_deref(), Some(Expr::Fn(..) | Expr::Arrow(..))) { - if let Pat::Ident(ident) = &node.name { - self.is_interested = ident.sym.starts_with("use") - || ident.sym.starts_with(|c: char| c.is_ascii_uppercase()); - } else { - self.is_interested = false; - } +pub fn is_required(program: &Program) -> bool { + match program { + Program::Module(module) => crate::prefilter::has_react_like_functions(module), + Program::Script(script) => { + let module = Module { + span: script.span, + body: script.body.iter().cloned().map(ModuleItem::Stmt).collect(), + shebang: script.shebang.clone(), + }; + crate::prefilter::has_react_like_functions(&module) } - - node.visit_children_with(self); - - self.is_interested = old; } } diff --git a/crates/swc_ecma_react_compiler/src/lib.rs b/crates/swc_ecma_react_compiler/src/lib.rs index 5875c5ced97b..863dc831b5da 100644 --- a/crates/swc_ecma_react_compiler/src/lib.rs +++ b/crates/swc_ecma_react_compiler/src/lib.rs @@ -1,3 +1,1314 @@ #![deny(clippy::all)] +// Copyright (c) Meta Platforms, Inc. and affiliates. 
+// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. +//! Vendored SWC adapter for the Rust React Compiler from +//! `facebook/react#36173` at commit `72adadf3097cd60e1f320d609e53f19eaf3e81cc`. + +pub mod convert_ast; +pub mod convert_ast_reverse; +pub mod convert_scope; +pub mod diagnostics; pub mod fast_check; +pub mod prefilter; + +use std::cell::RefCell; + +use convert_ast::convert_module_with_source_type; +use convert_ast_reverse::convert_program_to_swc_with_source; +use convert_scope::build_scope_info; +use diagnostics::{compile_result_to_diagnostics, DiagnosticMessage}; +use prefilter::has_react_like_functions; +use react_compiler::entrypoint::{compile_result::LoggerEvent, plugin_options::PluginOptions}; +use swc_common::comments::Comments; + +/// Describes where a blank line should be inserted relative to a body item. +#[derive(Clone, Debug)] +pub enum BlankLinePosition { + /// Insert blank line before the item (including its leading comments). + /// The `first_code_line` is the item's first code line (without comments) + /// used as a search anchor in the output. + BeforeItem { first_code_line: String }, + /// Insert blank line between the item's leading comments and its code. + /// The `first_code_line` is used to find where the code starts. + BeforeCode { first_code_line: String }, +} + +thread_local! { + /// Thread-local storage for comments from the last compilation. + /// Used by `emit` to include comments without API changes. + static LAST_COMMENTS: RefCell> = const { RefCell::new(None) }; + + /// Thread-local storage for blank line positions. + /// Contains information about where to insert blank lines during emit. + static BLANK_LINE_POSITIONS: RefCell> = const { RefCell::new(Vec::new()) }; +} + +/// Result of compiling a program via the SWC frontend. +pub struct TransformResult { + /// The compiled program as an SWC Module (None if no changes needed). 
+ pub module: Option, + /// Comments extracted from the compiled AST (for use with + /// `emit_with_comments`). + pub comments: Option, + pub diagnostics: Vec, + pub events: Vec, +} + +/// Source syntax accepted by the high-level source APIs. +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] +pub enum SourceSyntax { + EcmaScript, + #[default] + TypeScript, +} + +/// Parser configuration for the high-level source APIs. +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct SourceParser { + pub syntax: SourceSyntax, + pub jsx: bool, + pub tsx: bool, + pub decorators: bool, +} + +impl Default for SourceParser { + fn default() -> Self { + Self { + syntax: SourceSyntax::TypeScript, + jsx: false, + tsx: true, + decorators: true, + } + } +} + +/// High-level emitted output used by the node binding and package wrapper. +pub struct SourceTransformOutput { + pub code: String, + pub diagnostics: Vec, + pub map: Option, +} + +/// Result of linting a program via the SWC frontend. +pub struct LintResult { + pub diagnostics: Vec, +} + +fn parse_module_from_source( + source_text: &str, + parser: SourceParser, +) -> Result { + let cm = swc_common::sync::Lrc::new(swc_common::SourceMap::default()); + let fm = cm.new_source_file( + swc_common::sync::Lrc::new(swc_common::FileName::Anon), + source_text.to_string(), + ); + + let syntax = match parser.syntax { + SourceSyntax::EcmaScript => swc_ecma_parser::Syntax::Es(swc_ecma_parser::EsSyntax { + jsx: parser.jsx, + decorators: parser.decorators, + ..Default::default() + }), + SourceSyntax::TypeScript => { + swc_ecma_parser::Syntax::Typescript(swc_ecma_parser::TsSyntax { + tsx: parser.tsx, + decorators: parser.decorators, + ..Default::default() + }) + } + }; + + let mut errors = vec![]; + let module = swc_ecma_parser::parse_file_as_module( + &fm, + syntax, + swc_ecma_ast::EsVersion::latest(), + None, + &mut errors, + ) + .map_err(|_| String::from("Syntax Error"))?; + + if errors.is_empty() { + Ok(module) + } else { + 
Err(String::from("Syntax Error")) + } +} + +/// Parses source text with an explicit parser configuration, then runs the +/// React-compiler prefilter. +pub fn is_required_source_with_parser( + source_text: &str, + parser: SourceParser, +) -> Result { + let module = parse_module_from_source(source_text, parser)?; + Ok(prefilter::has_react_like_functions(&module)) +} + +/// Primary transform API — accepts pre-parsed SWC Module. +pub fn transform( + module: &swc_ecma_ast::Module, + source_text: &str, + options: PluginOptions, +) -> TransformResult { + if options.compilation_mode != "all" && !has_react_like_functions(module) { + return TransformResult { + module: None, + comments: None, + diagnostics: vec![], + events: vec![], + }; + } + + // Detect source type from pragma. The @script pragma indicates + // CommonJS (script) mode, which affects how imports are emitted. + let source_type = if source_text + .lines() + .next() + .is_some_and(|line| line.contains("@script")) + { + react_compiler_ast::SourceType::Script + } else { + react_compiler_ast::SourceType::Module + }; + let file = convert_module_with_source_type(module, source_text, source_type); + let scope_info = build_scope_info(module); + let result = react_compiler::entrypoint::program::compile_program(file, scope_info, options); + + let diagnostics = compile_result_to_diagnostics(&result); + let (program_json, events) = match result { + react_compiler::entrypoint::compile_result::CompileResult::Success { + ast, events, .. + } => (ast, events), + react_compiler::entrypoint::compile_result::CompileResult::Error { events, .. 
} => { + (None, events) + } + }; + + let conversion_result = program_json.and_then(|raw_json| { + // First parse to serde_json::Value which deduplicates "type" fields + // (the compiler output can produce duplicate "type" keys due to + // BaseNode.node_type + #[serde(tag = "type")] enum tagging) + let value: serde_json::Value = serde_json::from_str(raw_json.get()).ok()?; + let file: react_compiler_ast::File = serde_json::from_value(value).ok()?; + let result = convert_program_to_swc_with_source(&file, Some(source_text)); + Some(result) + }); + + let (mut swc_module, mut comments) = match conversion_result { + Some(result) => (Some(result.module), Some(result.comments)), + None => (None, None), + }; + + // If we have a compiled module, extract comments from the original source + // and merge them into the comment map. The Rust compiler does not preserve + // comments in its output, so we re-extract them from the source text. + if let Some(ref mut swc_mod) = swc_module { + use swc_common::Spanned; + + // Compute blank line positions BEFORE span fixup, while spans still + // reflect original source positions. Babel's generator adds blank + // lines between consecutive items when the original source had blank + // lines between them (i.e., endLine(prev) + 1 < startLine(next)). + let blank_line_positions = compute_blank_line_positions(&swc_mod.body, source_text); + + // Fix up dummy spans on compiler-generated items: SWC codegen skips + // comments at BytePos(0) (DUMMY), so we give generated items a real + // span derived from the original module's first item. + let has_source_items = !module.body.is_empty(); + if has_source_items { + // Use a synthetic span at position 1 (minimal non-dummy position) + // This ensures comments can be attached to the first item. 
+ let synthetic_span = + swc_common::Span::new(swc_common::BytePos(1), swc_common::BytePos(1)); + for item in &mut swc_mod.body { + if item.span().lo.is_dummy() { + match item { + swc_ecma_ast::ModuleItem::ModuleDecl(swc_ecma_ast::ModuleDecl::Import( + import, + )) => { + import.span = synthetic_span; + } + swc_ecma_ast::ModuleItem::Stmt(swc_ecma_ast::Stmt::Decl( + swc_ecma_ast::Decl::Var(var), + )) => { + var.span = synthetic_span; + } + _ => {} + } + } + } + } + + let source_comments = extract_source_comments(source_text); + if !source_comments.is_empty() { + let merged = comments.unwrap_or_default(); + + for (orig_pos, comment_list) in source_comments { + // Keep comments at their original positions. Comments + // attached to the first source statement will appear before + // the corresponding statement in the compiled output + // (which preserves the original import's span). + merged.add_leading_comments(orig_pos, comment_list); + } + comments = Some(merged); + } + + // Store blank line positions in thread-local for `emit` to use + BLANK_LINE_POSITIONS.with(|cell| { + *cell.borrow_mut() = blank_line_positions; + }); + } + + // Store comments in thread-local for `emit` to use + LAST_COMMENTS.with(|cell| { + *cell.borrow_mut() = comments.clone(); + }); + + TransformResult { + module: swc_module, + comments, + diagnostics, + events, + } +} + +/// Convenience wrapper — parses source text, then transforms. +pub fn transform_source(source_text: &str, options: PluginOptions) -> TransformResult { + match try_transform_source_with_parser(source_text, options, SourceParser::default()) { + Ok(result) => result, + Err(_) => TransformResult { + module: None, + comments: None, + diagnostics: vec![], + events: vec![], + }, + } +} + +/// Parses source text with an explicit parser configuration, then transforms. 
+pub fn try_transform_source_with_parser( + source_text: &str, + options: PluginOptions, + parser: SourceParser, +) -> Result { + let module = parse_module_from_source(source_text, parser)?; + Ok(transform(&module, source_text, options)) +} + +/// Parses source text, transforms it, and emits code. +pub fn try_transform_source_to_code_with_parser( + source_text: &str, + options: PluginOptions, + parser: SourceParser, +) -> Result { + let result = try_transform_source_with_parser(source_text, options, parser)?; + let code = match result.module { + Some(ref module) => emit(module), + None => normalize_source(source_text), + }; + + Ok(SourceTransformOutput { + code, + diagnostics: result + .diagnostics + .into_iter() + .map(|diagnostic| diagnostic.message) + .collect(), + map: None, + }) +} + +/// Lint API — same as transform but only collects diagnostics, no AST output. +pub fn lint( + module: &swc_ecma_ast::Module, + source_text: &str, + options: PluginOptions, +) -> LintResult { + let mut opts = options; + opts.no_emit = true; + + let result = transform(module, source_text, opts); + LintResult { + diagnostics: result.diagnostics, + } +} + +/// Emit an SWC Module to a string via swc_ecma_codegen. +/// If `transform` was called on the same thread, any comments from the +/// compiled AST are automatically included. +pub fn emit(module: &swc_ecma_ast::Module) -> String { + LAST_COMMENTS.with(|cell| { + let borrowed = cell.borrow(); + let positions = BLANK_LINE_POSITIONS.with(|bl| bl.borrow().clone()); + emit_with_comments(module, borrowed.as_ref(), &positions) + }) +} + +/// Emit an SWC Module to a string, optionally including comments. +/// `blank_line_positions` describes where blank lines should be inserted +/// to match Babel's blank line behavior. 
+pub fn emit_with_comments( + module: &swc_ecma_ast::Module, + comments: Option<&swc_common::comments::SingleThreadedComments>, + blank_line_positions: &[BlankLinePosition], +) -> String { + // Standard emit path + let code = emit_module_to_string(module, comments); + let code = fix_block_comment_newlines(&code); + + // Add blank lines after directives to match Babel's codegen behavior. + // Babel always emits a blank line after the last directive in a + // program/function body. + let code = add_blank_lines_after_directives(&code); + + // Reposition blank lines that SWC places before comment blocks: + // SWC emits blank lines before leading comments, but Babel places + // them after the comments (between comments and the declaration). + // Move blank lines from before comment blocks to after them when + // the comment block is followed by a top-level declaration. + let code = reposition_comment_blank_lines(&code); + + // Expand single-line object literals to multi-line format in + // FIXTURE_ENTRYPOINT-style structures. SWC codegen emits small objects + // on single lines while Babel puts them on multiple lines. Prettier + // preserves this choice, causing formatting differences. + let code = expand_fixture_entrypoint_objects(&code); + + if blank_line_positions.is_empty() || module.body.is_empty() { + return code; + } + + // Insert blank lines between top-level declarations to match Babel's + // output. Babel's generator preserves blank lines from the original + // source between consecutive top-level items. + insert_blank_lines_in_output(&code, blank_line_positions) +} + +/// Emit a full module to a string. 
+fn emit_module_to_string( + module: &swc_ecma_ast::Module, + comments: Option<&swc_common::comments::SingleThreadedComments>, +) -> String { + let cm = swc_common::sync::Lrc::new(swc_common::SourceMap::default()); + let mut buf = vec![]; + { + let wr = swc_ecma_codegen::text_writer::JsWriter::new(cm.clone(), "\n", &mut buf, None); + let mut emitter = swc_ecma_codegen::Emitter { + cfg: swc_ecma_codegen::Config::default().with_minify(false), + cm, + comments: comments.map(|c| c as &dyn swc_common::comments::Comments), + wr: Box::new(wr), + }; + swc_ecma_codegen::Node::emit_with(module, &mut emitter).unwrap(); + } + String::from_utf8(buf).unwrap() +} + +/// Insert blank lines into the emitted output at positions specified by +/// `blank_line_positions`. Each position includes a `first_code_line` that +/// identifies the item's first line of code (without comments), used as +/// a search anchor in the output. +fn insert_blank_lines_in_output(code: &str, positions: &[BlankLinePosition]) -> String { + if positions.is_empty() { + return code.to_string(); + } + + let lines: Vec<&str> = code.lines().collect(); + + // Phase 1: Find which output line indices need a blank line inserted + // BEFORE them. We do this by finding each target's first_code_line in + // the output, then computing the actual insert line. + let mut insert_before: Vec = Vec::new(); + let mut used_lines: Vec = vec![false; lines.len()]; + + for pos in positions { + let (first_code_line, before_comments) = match pos { + BlankLinePosition::BeforeItem { first_code_line } => (first_code_line.as_str(), true), + BlankLinePosition::BeforeCode { first_code_line } => (first_code_line.as_str(), false), + }; + + // Find this code line in the output (first unused match). + // For BeforeCode positions, also allow matching already-used lines + // since BeforeItem and BeforeCode may target the same code line. 
+ let mut found_idx = None; + for (i, &line) in lines.iter().enumerate() { + if line == first_code_line && (!used_lines[i] || !before_comments) { + found_idx = Some(i); + if !used_lines[i] { + used_lines[i] = true; + } + break; + } + } + + let code_line_idx = match found_idx { + Some(idx) => idx, + None => continue, + }; + + let insert_line = if before_comments { + // BeforeItem: insert before the comment block that precedes + // this code line + find_comment_block_start(&lines, code_line_idx) + } else { + // BeforeCode: insert right before the code line itself + code_line_idx + }; + + // Only insert if the previous line is not already blank + if insert_line > 0 && !lines[insert_line - 1].trim().is_empty() { + insert_before.push(insert_line); + } + } + + if insert_before.is_empty() { + return code.to_string(); + } + + insert_before.sort_unstable(); + insert_before.dedup(); + + // Phase 2: Build the result with blank lines inserted + let mut result = String::with_capacity(code.len() + insert_before.len() * 2); + let mut insert_idx = 0; + + for (line_idx, &line) in lines.iter().enumerate() { + // Check if we need to insert a blank line before this line + if insert_idx < insert_before.len() && insert_before[insert_idx] == line_idx { + result.push('\n'); + insert_idx += 1; + } + + result.push_str(line); + if line_idx < lines.len() - 1 || code.ends_with('\n') { + result.push('\n'); + } + } + + result +} + +/// Find the start of a comment block that precedes the line at `code_line_idx`. +/// Walks backwards from `code_line_idx - 1` as long as lines are comment +/// lines (starting with `//`, `/*`, ` *`, `*/`, or `/**`). 
+fn find_comment_block_start(lines: &[&str], code_line_idx: usize) -> usize { + let mut start = code_line_idx; + let mut i = code_line_idx; + while i > 0 { + i -= 1; + let trimmed = lines[i].trim(); + if trimmed.is_empty() { + break; // blank line, stop + } + if trimmed.starts_with("//") + || trimmed.starts_with("/*") + || trimmed.starts_with("* ") + || trimmed.starts_with("*/") + || trimmed == "*" + { + start = i; + } else { + break; + } + } + start +} + +/// Add blank lines after directive sequences in function/program bodies. +/// +/// Babel's codegen emits a blank line after the last directive in a body +/// (e.g., after `"use strict";` or `"use no memo";`). SWC's codegen +/// does not. This function adds those blank lines to match Babel's output. +fn add_blank_lines_after_directives(code: &str) -> String { + let lines: Vec<&str> = code.lines().collect(); + if lines.is_empty() { + return code.to_string(); + } + + let mut result: Vec<&str> = Vec::with_capacity(lines.len() + 8); + let mut i = 0; + + while i < lines.len() { + result.push(lines[i]); + + // Check if this line is a directive (string literal expression statement) + if is_directive_line(lines[i]) { + // Check if the next line is NOT a directive and NOT blank + if i + 1 < lines.len() + && !is_directive_line(lines[i + 1]) + && !lines[i + 1].trim().is_empty() + { + result.push(""); + } + } + + i += 1; + } + + // Rejoin, preserving trailing newline if present + let mut output = result.join("\n"); + if code.ends_with('\n') && !output.ends_with('\n') { + output.push('\n'); + } + output +} + +/// Check if a line is a directive (a string literal expression statement). +/// Directives look like: `"use strict";` or `'use no memo';` possibly with +/// leading whitespace (indentation for function body directives). 
+fn is_directive_line(line: &str) -> bool { + let trimmed = line.trim(); + // Must start with a quote and end with the matching quote + semicolon + if let Some(rest) = trimmed.strip_prefix('"') { + rest.ends_with("\";") + } else if let Some(rest) = trimmed.strip_prefix('\'') { + rest.ends_with("';") + } else { + false + } +} + +/// Insert newlines after `*/` when followed by code on the same line. +/// Only applies to multiline block comments (JSDoc-style), not inline ones. +fn fix_block_comment_newlines(code: &str) -> String { + let mut result = String::with_capacity(code.len()); + let mut chars = code.char_indices().peekable(); + let bytes = code.as_bytes(); + let mut in_block_comment = false; + let mut block_comment_multiline = false; + + while let Some((i, c)) = chars.next() { + // Track block comment state + if !in_block_comment && c == '/' && bytes.get(i + 1) == Some(&b'*') { + in_block_comment = true; + block_comment_multiline = false; + result.push(c); + continue; + } + + if in_block_comment { + if c == '\n' { + block_comment_multiline = true; + } + result.push(c); + + // Check for end of block comment + if c == '*' && bytes.get(i + 1) == Some(&b'/') { + chars.next(); + result.push('/'); + in_block_comment = false; + + if block_comment_multiline { + // Skip spaces after `*/` + let mut spaces = String::new(); + while let Some(&(_, next_c)) = chars.peek() { + if next_c == ' ' || next_c == '\t' { + spaces.push(next_c); + chars.next(); + } else { + break; + } + } + + // If followed by code on the same line, insert newline + if let Some(&(_, next_c)) = chars.peek() { + if next_c != '\n' && next_c != '\r' { + result.push('\n'); + } else { + result.push_str(&spaces); + } + } else { + result.push_str(&spaces); + } + } + } + continue; + } + + result.push(c); + } + result +} + +/// Reposition blank lines from before comment blocks to after them. 
+/// +/// SWC's codegen sometimes places blank lines before leading comment blocks, +/// but Babel's generator places them after the comments (between the comment +/// block and the declaration). This function detects the pattern: +/// +/// +/// +/// +/// +/// +/// And transforms it to: +/// +/// +/// +/// +/// +/// +/// This only applies to top-level (non-indented) comment blocks. +fn reposition_comment_blank_lines(code: &str) -> String { + let lines: Vec<&str> = code.lines().collect(); + if lines.len() < 3 { + return code.to_string(); + } + + let mut result: Vec<&str> = Vec::with_capacity(lines.len()); + let mut i = 0; + + while i < lines.len() { + // Look for pattern: blank line followed by comment block followed by + // declaration + if lines[i].trim().is_empty() && i + 1 < lines.len() { + let comment_start = i + 1; + let first_comment = lines[comment_start].trim(); + + // Check if the next line is a top-level comment (not indented) + let is_top_level_comment = (first_comment.starts_with("//") + || first_comment.starts_with("/*") + || first_comment.starts_with("/**")) + && !lines[comment_start].starts_with(' ') + && !lines[comment_start].starts_with('\t'); + + if is_top_level_comment { + // Find the end of the comment block + let mut comment_end = comment_start; + while comment_end < lines.len() { + let trimmed = lines[comment_end].trim(); + if trimmed.starts_with("//") + || trimmed.starts_with("/*") + || trimmed.starts_with("* ") + || trimmed.starts_with("*/") + || trimmed == "*" + || trimmed.starts_with("/**") + { + comment_end += 1; + } else { + break; + } + } + + // Check if the line after the comment block is a top-level + // declaration (function, class, export, const, let, var). + // This is specifically for Babel's codegen which places blank + // lines after comment blocks before declarations, not before. 
+ if comment_end < lines.len() && comment_end > comment_start { + let after_comment = lines[comment_end].trim(); + let is_declaration = after_comment.starts_with("function ") + || after_comment.starts_with("export ") + || after_comment.starts_with("class ") + || after_comment.starts_with("const ") + || after_comment.starts_with("let ") + || after_comment.starts_with("var ") + || after_comment.starts_with("import ") + || after_comment.starts_with("async function ") + || after_comment.starts_with("async function*"); + + if is_declaration { + // Also check that the line before the blank line is + // non-empty (end of import or end of function) + let prev_non_empty = i > 0 && !lines[i - 1].trim().is_empty(); + + if prev_non_empty { + // Move the blank line: emit comment block first, + // then blank line, then continue + for line in lines.iter().take(comment_end).skip(comment_start) { + result.push(*line); + } + result.push(""); // blank line after comments + i = comment_end; + continue; + } + } + } + } + } + + result.push(lines[i]); + i += 1; + } + + // Rejoin, preserving trailing newline if present + let mut output = result.join("\n"); + if code.ends_with('\n') && !output.ends_with('\n') { + output.push('\n'); + } + output +} + +/// Compute where blank lines should be inserted in the emitted output. +/// +/// This replicates Babel's `@babel/generator` behavior: when consecutive +/// top-level items had blank lines between them in the original source, +/// the generator preserves those blank lines. +/// +/// We check the item spans (byte positions into the original source) and +/// determine if there was a blank line gap between consecutive items. +/// We also determine WHERE the blank line should go: before the item's +/// leading comments (BeforeItem) or between the comments and code (BeforeCode). 
+fn compute_blank_line_positions( + body: &[swc_ecma_ast::ModuleItem], + source_text: &str, +) -> Vec { + use swc_common::Spanned; + + let mut result = Vec::new(); + + // Check for blank lines between leading comments and the first + // non-DUMMY item. This handles the case where comments from the + // source (e.g., pragma comments) are attached as leading comments + // to an import, with a blank line gap in the original source. + for item in body { + let lo = item.span().lo; + if lo.is_dummy() { + continue; + } + let lo_u = (lo.0 as usize).saturating_sub(1); + if lo_u > source_text.len() || lo_u == 0 { + break; + } + // Check the source text before this item for comments followed by blank lines + let before = &source_text[..lo_u]; + if has_blank_line(before) && (before.contains("//") || before.contains("/*")) { + // There are comments and blank lines before this item. + // Check if the blank line is between the comments and this item + // (i.e., "BeforeCode" pattern) + if !is_blank_line_before_comments(before) { + let first_code_line = get_first_code_line(item); + result.push(BlankLinePosition::BeforeCode { first_code_line }); + } + } + break; // Only check the first non-DUMMY item + } + + for i in 1..body.len() { + let prev = &body[i - 1]; + let curr = &body[i]; + + let prev_hi = prev.span().hi; + let curr_lo = curr.span().lo; + + // Skip items with dummy/synthetic spans (BytePos(0)) + if prev_hi.is_dummy() || curr_lo.is_dummy() { + continue; + } + + // SWC BytePos is 1-based (BytePos(0) is DUMMY/reserved). Convert + // to 0-based source text indices by subtracting 1. + let prev_hi_u = (prev_hi.0 as usize).saturating_sub(1); + let curr_lo_u = (curr_lo.0 as usize).saturating_sub(1); + + if prev_hi_u >= curr_lo_u || prev_hi_u > source_text.len() || curr_lo_u > source_text.len() + { + continue; + } + + // Check the text between the two items for blank lines. 
+ // Babel's generator preserves blank lines from the original source + // between consecutive top-level items. + let between = &source_text[prev_hi_u..curr_lo_u]; + if !has_blank_line(between) { + continue; + } + + // Only preserve blank lines when there are comments between the + // items. This matches Babel's behavior: the TS compiler's + // replaceWith() creates fresh nodes without position info, so + // Babel's generator only sees position gaps when comments with + // original positions are present between items. Without comments, + // the generated code and the next item end up close together, + // so Babel sees no gap and doesn't insert a blank line. + if !between.contains("//") && !between.contains("/*") { + continue; + } + + // Determine the first code line of the current item (emitted + // without comments) for use as a search anchor. + let first_code_line = get_first_code_line(curr); + + // Determine whether blank lines exist before and/or after comments. + let (blank_before, blank_after) = blank_line_positions_around_comments(between); + + if blank_before && blank_after { + // Both: add blank lines before AND after comments + result.push(BlankLinePosition::BeforeItem { + first_code_line: first_code_line.clone(), + }); + result.push(BlankLinePosition::BeforeCode { first_code_line }); + } else if blank_after { + result.push(BlankLinePosition::BeforeCode { first_code_line }); + } else { + // blank_before only, or no specific position → default to BeforeItem + result.push(BlankLinePosition::BeforeItem { first_code_line }); + } + } + + result +} + +/// Check if a string contains a blank line (two consecutive newlines +/// with only whitespace between them). 
+fn has_blank_line(s: &str) -> bool { + let mut prev_newline = false; + for c in s.chars() { + if c == '\n' { + if prev_newline { + return true; + } + prev_newline = true; + } else if c == ' ' || c == '\t' || c == '\r' { + // whitespace between newlines is ok + } else { + prev_newline = false; + } + } + false +} + +/// Determine where blank lines exist relative to comments in the between-text. +/// +/// Returns (blank_before_comments, blank_after_comments): +/// - blank_before: there's a blank line before any comment content +/// - blank_after: there's a blank line after comment content +fn blank_line_positions_around_comments(between: &str) -> (bool, bool) { + let mut found_comment = false; + let mut prev_newline = false; + let mut blank_before = false; + let mut blank_after = false; + + for (i, c) in between.char_indices() { + if c == '\n' { + if prev_newline { + if found_comment { + blank_after = true; + } else { + blank_before = true; + } + } + prev_newline = true; + } else if c == ' ' || c == '\t' || c == '\r' { + // whitespace between newlines is ok + } else { + prev_newline = false; + if c == '/' { + let next = between.as_bytes().get(i + 1); + if next == Some(&b'*') || next == Some(&b'/') { + found_comment = true; + } + } + } + } + + (blank_before, blank_after) +} + +/// Check if the blank line in the between-text should be placed before +/// comments. Used for the first-item leading comment check. +fn is_blank_line_before_comments(between: &str) -> bool { + let (blank_before, blank_after) = blank_line_positions_around_comments(between); + // If blank lines exist after comments, prefer BeforeCode (return false) + if blank_after { + return false; + } + blank_before +} + +/// Get the first non-empty line of a ModuleItem when emitted without comments. 
+fn get_first_code_line(item: &swc_ecma_ast::ModuleItem) -> String { + let single_module = swc_ecma_ast::Module { + span: swc_common::DUMMY_SP, + body: vec![item.clone()], + shebang: None, + }; + + let cm = swc_common::sync::Lrc::new(swc_common::SourceMap::default()); + let mut buf = vec![]; + { + let wr = swc_ecma_codegen::text_writer::JsWriter::new(cm.clone(), "\n", &mut buf, None); + let mut emitter = swc_ecma_codegen::Emitter { + cfg: swc_ecma_codegen::Config::default().with_minify(false), + cm, + comments: None, + wr: Box::new(wr), + }; + swc_ecma_codegen::Node::emit_with(&single_module, &mut emitter).unwrap(); + } + let code = String::from_utf8(buf).unwrap(); + code.lines() + .find(|l| !l.trim().is_empty()) + .unwrap_or("") + .to_string() +} + +/// Extract comments from source text using SWC's parser. +/// Returns a list of (BytePos, Vec) pairs where the BytePos is the +/// position of the token following the comment(s). +fn extract_source_comments( + source_text: &str, +) -> Vec<(swc_common::BytePos, Vec)> { + let cm = swc_common::sync::Lrc::new(swc_common::SourceMap::default()); + let fm = cm.new_source_file( + swc_common::sync::Lrc::new(swc_common::FileName::Anon), + source_text.to_string(), + ); + + let comments = swc_common::comments::SingleThreadedComments::default(); + let mut errors = vec![]; + // Try parsing as JSX+TS to handle maximum syntax variety + let _ = swc_ecma_parser::parse_file_as_module( + &fm, + swc_ecma_parser::Syntax::Typescript(swc_ecma_parser::TsSyntax { + tsx: true, + ..Default::default() + }), + swc_ecma_ast::EsVersion::latest(), + Some(&comments), + &mut errors, + ); + + // Collect all leading comments + let mut result = Vec::new(); + let (leading, _trailing) = comments.borrow_all(); + for (pos, cmts) in leading.iter() { + if !cmts.is_empty() { + result.push((*pos, cmts.clone())); + } + } + + result +} + +/// Normalize source code formatting to match Babel's codegen behavior. 
+/// Applied to source text that was not modified by the compiler. +/// Currently adds blank lines after directive sequences, matching +/// Babel's generator which always emits a blank line after the last +/// directive in a function/program body. +pub fn normalize_source(source: &str) -> String { + let code = add_blank_lines_after_directives(source); + let code = remove_blank_lines_after_last_import(&code); + let code = remove_blank_lines_before_fixture_entrypoint(&code); + expand_fixture_entrypoint_objects(&code) +} + +/// Remove blank lines immediately before `export const FIXTURE_ENTRYPOINT`. +/// Babel's codegen doesn't preserve blank lines between function declarations +/// and the FIXTURE_ENTRYPOINT export. +fn remove_blank_lines_before_fixture_entrypoint(code: &str) -> String { + let lines: Vec<&str> = code.lines().collect(); + if lines.is_empty() { + return code.to_string(); + } + + // Find the FIXTURE_ENTRYPOINT line + let mut entrypoint_idx: Option = None; + for (i, &line) in lines.iter().enumerate() { + if line.trim().starts_with("export const FIXTURE_ENTRYPOINT") + || line.trim().starts_with("export const FIXTURE_ENTRYPOINT") + { + entrypoint_idx = Some(i); + break; + } + } + + let entrypoint_idx = match entrypoint_idx { + Some(idx) if idx > 0 => idx, + _ => return code.to_string(), + }; + + // Check if the line before FIXTURE_ENTRYPOINT is blank + if !lines[entrypoint_idx - 1].trim().is_empty() { + return code.to_string(); + } + + // Remove the blank line + let mut result: Vec<&str> = Vec::with_capacity(lines.len()); + for (i, &line) in lines.iter().enumerate() { + if i == entrypoint_idx - 1 { + continue; + } + result.push(line); + } + + let mut output = result.join("\n"); + if code.ends_with('\n') && !output.ends_with('\n') { + output.push('\n'); + } + output +} + +/// Remove blank lines between the last import declaration and the first +/// non-import statement. Babel's codegen doesn't preserve these blank lines. 
+/// +/// Only removes blank lines that immediately follow the LAST import line +/// (not blank lines between comments or between import groups). +fn remove_blank_lines_after_last_import(code: &str) -> String { + let lines: Vec<&str> = code.lines().collect(); + if lines.is_empty() { + return code.to_string(); + } + + // Find the index of the last import statement + let mut last_import_idx: Option = None; + for (i, &line) in lines.iter().enumerate() { + let trimmed = line.trim(); + if trimmed.starts_with("import ") || trimmed.starts_with("import{") { + last_import_idx = Some(i); + } + } + + let last_import_idx = match last_import_idx { + Some(idx) => idx, + None => return code.to_string(), + }; + + // Check if there's a blank line immediately after the last import + let blank_idx = last_import_idx + 1; + if blank_idx >= lines.len() || !lines[blank_idx].trim().is_empty() { + return code.to_string(); + } + + // Remove this blank line + let mut result: Vec<&str> = Vec::with_capacity(lines.len()); + for (i, &line) in lines.iter().enumerate() { + if i == blank_idx { + continue; // skip the blank line + } + result.push(line); + } + + let mut output = result.join("\n"); + if code.ends_with('\n') && !output.ends_with('\n') { + output.push('\n'); + } + output +} + +/// Expand single-line object literals to multi-line format within +/// FIXTURE_ENTRYPOINT structures only. +/// +/// SWC's codegen emits small objects on a single line (e.g., +/// `params: [{ value: "test" }]`), while Babel's codegen puts them on +/// multiple lines. Since prettier preserves the single-line vs multi-line +/// choice, we need to expand them before prettier runs. +/// +/// This function ONLY operates within FIXTURE_ENTRYPOINT blocks to avoid +/// affecting compiled code. 
+fn expand_fixture_entrypoint_objects(code: &str) -> String { + // Find the start of FIXTURE_ENTRYPOINT block + let entrypoint_marker = "FIXTURE_ENTRYPOINT"; + if !code.contains(entrypoint_marker) { + return code.to_string(); + } + + // Find the byte position of FIXTURE_ENTRYPOINT + let entrypoint_pos = match code.find(entrypoint_marker) { + Some(pos) => pos, + None => return code.to_string(), + }; + + // Only process lines after FIXTURE_ENTRYPOINT + let (before, after) = code.split_at(entrypoint_pos); + let expanded = expand_single_line_objects_in_block(after); + format!("{before}{expanded}") +} + +fn expand_single_line_objects_in_block(code: &str) -> String { + let mut result = String::with_capacity(code.len() + 256); + let lines: Vec<&str> = code.lines().collect(); + + for (idx, &line) in lines.iter().enumerate() { + if let Some(expanded) = try_expand_object_line(line) { + result.push_str(&expanded); + } else { + result.push_str(line); + } + if idx < lines.len() - 1 || code.ends_with('\n') { + result.push('\n'); + } + } + + result +} + +/// Try to expand a single-line object literal to multi-line. +/// Returns Some(expanded) if the line contains an expandable object, None +/// otherwise. +fn try_expand_object_line(line: &str) -> Option { + let trimmed = line.trim(); + + // Calculate indentation + let indent = &line[..line.len() - line.trim_start().len()]; + + // Pattern 1: `key: [{ prop: val, prop2: val2 }],` or `key: [{ ... }, { ... }],` + // Pattern 2: `[{ prop: val }, { prop: val }]` (array of objects) + // We need to find `[` containing `{...}` entries + + // Check if this line has a [ ... ] with { ... 
} objects inside + if !trimmed.contains("[{") && !trimmed.contains("{ ") { + return None; + } + + // Find the bracket-enclosed array content + let bracket_start = trimmed.find('[')?; + let bracket_end = trimmed.rfind(']')?; + if bracket_start >= bracket_end { + return None; + } + + let array_content = &trimmed[bracket_start + 1..bracket_end]; + let inner_trimmed = array_content.trim(); + + // Check if this contains objects: at least one `{ ... }` + if !inner_trimmed.starts_with('{') || !inner_trimmed.contains(':') { + return None; + } + + // We need at least one property with a colon to expand + if !inner_trimmed.contains(':') { + return None; + } + + // Split the array content into individual elements + let prefix = &trimmed[..bracket_start + 1]; + let suffix = &trimmed[bracket_end..]; + + // Parse the objects - split at `}, {` boundaries + let elements = split_array_elements(inner_trimmed); + + let inner_indent = format!("{indent} "); + let prop_indent = format!("{indent} "); + + let mut result = String::new(); + result.push_str(indent); + result.push_str(prefix); + result.push('\n'); + + for (i, elem) in elements.iter().enumerate() { + let elem = elem.trim(); + if elem.starts_with('{') && elem.ends_with('}') { + // Expand this object + let obj_content = &elem[1..elem.len() - 1].trim(); + let props = split_object_properties(obj_content); + + result.push_str(&inner_indent); + result.push_str("{\n"); + for prop in props { + result.push_str(&prop_indent); + result.push_str(prop.trim()); + result.push_str(",\n"); + } + result.push_str(&inner_indent); + result.push('}'); + } else { + result.push_str(&inner_indent); + result.push_str(elem); + } + if i < elements.len() - 1 { + result.push(','); + } + result.push('\n'); + } + + result.push_str(indent); + result.push_str(suffix); + + Some(result) +} + +/// Split array content into individual elements, respecting nested +/// braces/brackets. 
+fn split_array_elements(s: &str) -> Vec { + let mut elements = Vec::new(); + let mut current = String::new(); + let mut depth = 0; + + for ch in s.chars() { + match ch { + '{' | '[' | '(' => { + depth += 1; + current.push(ch); + } + '}' | ']' | ')' => { + depth -= 1; + current.push(ch); + } + ',' if depth == 0 => { + let trimmed = current.trim().to_string(); + if !trimmed.is_empty() { + elements.push(trimmed); + } + current.clear(); + } + _ => { + current.push(ch); + } + } + } + let trimmed = current.trim().to_string(); + if !trimmed.is_empty() { + elements.push(trimmed); + } + elements +} + +/// Split object properties, respecting nested structures. +fn split_object_properties(s: &str) -> Vec { + let mut props = Vec::new(); + let mut current = String::new(); + let mut depth = 0; + + for ch in s.chars() { + match ch { + '{' | '[' | '(' => { + depth += 1; + current.push(ch); + } + '}' | ']' | ')' => { + depth -= 1; + current.push(ch); + } + ',' if depth == 0 => { + let trimmed = current.trim().to_string(); + if !trimmed.is_empty() { + props.push(trimmed); + } + current.clear(); + } + _ => { + current.push(ch); + } + } + } + let trimmed = current.trim().to_string(); + if !trimmed.is_empty() { + props.push(trimmed); + } + props +} + +/// Convenience wrapper — parses source text, then lints. +pub fn lint_source(source_text: &str, options: PluginOptions) -> LintResult { + match try_lint_source_with_parser(source_text, options, SourceParser::default()) { + Ok(result) => result, + Err(_) => LintResult { + diagnostics: vec![], + }, + } +} + +/// Parses source text with an explicit parser configuration, then lints. 
+pub fn try_lint_source_with_parser( + source_text: &str, + options: PluginOptions, + parser: SourceParser, +) -> Result { + let module = parse_module_from_source(source_text, parser)?; + Ok(lint(&module, source_text, options)) +} diff --git a/crates/swc_ecma_react_compiler/src/prefilter.rs b/crates/swc_ecma_react_compiler/src/prefilter.rs new file mode 100644 index 000000000000..7d5483ebef96 --- /dev/null +++ b/crates/swc_ecma_react_compiler/src/prefilter.rs @@ -0,0 +1,283 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use react_compiler_hir::environment::is_react_like_name; +use swc_ecma_ast::{ + ArrowExpr, AssignExpr, AssignTarget, CallExpr, Callee, Class, ExportDefaultDecl, + ExportDefaultExpr, Expr, FnDecl, FnExpr, MemberProp, Module, Pat, SimpleAssignTarget, Stmt, + VarDeclarator, +}; +use swc_ecma_visit::{Visit, VisitWith}; + +/// Checks if a module contains React-like functions (components or hooks). 
+/// +/// A React-like function is one whose name: +/// - Starts with an uppercase letter (component convention) +/// - Matches the pattern `use[A-Z0-9]` (hook convention) +pub fn has_react_like_functions(module: &Module) -> bool { + let mut visitor = ReactLikeVisitor::default(); + visitor.visit_module(module); + visitor.found +} + +fn is_hook_like_name(name: &str) -> bool { + name.starts_with("use") && is_react_like_name(name) +} + +#[derive(Default)] +struct ReactLikeVisitor { + found: bool, + current_name: Option, + is_interested: bool, +} + +impl Visit for ReactLikeVisitor { + fn visit_callee(&mut self, callee: &Callee) { + if self.is_interested { + if let Callee::Expr(expr) = callee { + if let Expr::Ident(ident) = &**expr { + if ident.sym.starts_with("use") { + self.found = true; + return; + } + } + } + } + + callee.visit_children_with(self); + } + + fn visit_var_declarator(&mut self, decl: &VarDeclarator) { + if self.found { + return; + } + + let name = match &decl.name { + Pat::Ident(binding_ident) => Some(binding_ident.id.sym.to_string()), + _ => None, + }; + + let prev_name = self.current_name.take(); + let prev_interested = self.is_interested; + self.current_name = name; + if matches!(decl.init.as_deref(), Some(Expr::Fn(..) | Expr::Arrow(..))) { + self.is_interested = + prev_interested || self.current_name.as_deref().is_some_and(is_react_like_name); + } + + if let Some(init) = &decl.init { + self.visit_expr(init); + } + + self.current_name = prev_name; + self.is_interested = prev_interested; + } + + fn visit_assign_expr(&mut self, expr: &AssignExpr) { + if self.found { + return; + } + + let name = match &expr.left { + AssignTarget::Simple(SimpleAssignTarget::Ident(binding_ident)) => { + Some(binding_ident.id.sym.to_string()) + } + _ => None, + }; + + let prev_name = self.current_name.take(); + let prev_interested = self.is_interested; + self.current_name = name; + if matches!(&*expr.right, Expr::Fn(..) 
| Expr::Arrow(..)) { + self.is_interested = + prev_interested || self.current_name.as_deref().is_some_and(is_react_like_name); + } + + self.visit_expr(&expr.right); + + self.current_name = prev_name; + self.is_interested = prev_interested; + } + + fn visit_fn_decl(&mut self, decl: &FnDecl) { + if self.found { + return; + } + + if is_hook_like_name(&decl.ident.sym) { + self.found = true; + return; + } + + let prev_interested = self.is_interested; + self.is_interested = is_react_like_name(&decl.ident.sym); + + decl.visit_children_with(self); + + self.is_interested = prev_interested; + } + + fn visit_fn_expr(&mut self, expr: &FnExpr) { + if self.found { + return; + } + + if expr + .ident + .as_ref() + .is_some_and(|ident| is_hook_like_name(&ident.sym)) + || expr.ident.is_none() && self.current_name.as_deref().is_some_and(is_hook_like_name) + { + self.found = true; + return; + } + + let prev_interested = self.is_interested; + self.is_interested = self.is_interested + || expr + .ident + .as_ref() + .is_some_and(|ident| is_react_like_name(&ident.sym)) + || expr.ident.is_none() && self.current_name.as_deref().is_some_and(is_react_like_name); + + expr.visit_children_with(self); + + self.is_interested = prev_interested; + } + + fn visit_arrow_expr(&mut self, expr: &ArrowExpr) { + if self.found { + return; + } + + if self.current_name.as_deref().is_some_and(is_hook_like_name) { + self.found = true; + return; + } + + let prev_interested = self.is_interested; + self.is_interested = + self.is_interested || self.current_name.as_deref().is_some_and(is_react_like_name); + + expr.visit_children_with(self); + + self.is_interested = prev_interested; + } + + fn visit_call_expr(&mut self, call: &CallExpr) { + if self.found { + return; + } + + if is_memo_or_forward_ref_call(call) + && (self.is_interested || self.current_name.as_deref().is_some_and(is_react_like_name)) + { + if let Some(first_arg) = call.args.first() { + if matches!(&*first_arg.expr, Expr::Fn(_) | Expr::Arrow(_)) { + 
self.found = true; + return; + } + } + } + + call.visit_children_with(self); + } + + fn visit_class(&mut self, _class: &Class) { + // Skip class bodies entirely. + } + + fn visit_export_default_decl(&mut self, export: &ExportDefaultDecl) { + let prev_interested = self.is_interested; + self.is_interested = true; + + export.visit_children_with(self); + + self.is_interested = prev_interested; + } + + fn visit_export_default_expr(&mut self, export: &ExportDefaultExpr) { + let prev_interested = self.is_interested; + self.is_interested = true; + + export.visit_children_with(self); + + self.is_interested = prev_interested; + } + + fn visit_expr(&mut self, expr: &Expr) { + if self.found { + return; + } + + if self.is_interested + && matches!( + expr, + Expr::JSXMember(..) + | Expr::JSXNamespacedName(..) + | Expr::JSXEmpty(..) + | Expr::JSXElement(..) + | Expr::JSXFragment(..) + ) + { + self.found = true; + return; + } + + expr.visit_children_with(self); + } + + fn visit_stmt(&mut self, stmt: &Stmt) { + if self.found { + return; + } + + stmt.visit_children_with(self); + } +} + +fn is_memo_or_forward_ref_call(call: &CallExpr) -> bool { + match &call.callee { + Callee::Expr(expr) => match &**expr { + // Direct calls: memo(...) or forwardRef(...) + Expr::Ident(ident) => ident.sym == "memo" || ident.sym == "forwardRef", + // Member expression: React.memo(...) or React.forwardRef(...) 
+ Expr::Member(member) => { + if let Expr::Ident(obj) = &*member.obj { + if obj.sym == "React" { + if let MemberProp::Ident(prop) = &member.prop { + return prop.sym == "memo" || prop.sym == "forwardRef"; + } + } + } + false + } + _ => false, + }, + _ => false, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_react_like_name() { + assert!(is_react_like_name("Component")); + assert!(is_react_like_name("MyComponent")); + assert!(is_react_like_name("A")); + assert!(is_react_like_name("useState")); + assert!(is_react_like_name("useEffect")); + assert!(is_react_like_name("use0")); + + assert!(!is_react_like_name("component")); + assert!(!is_react_like_name("myFunction")); + assert!(!is_react_like_name("use")); + assert!(!is_react_like_name("user")); + assert!(!is_react_like_name("useful")); + assert!(!is_react_like_name("")); + } +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture.rs b/crates/swc_ecma_react_compiler/tests/fixture.rs new file mode 100644 index 000000000000..c008f9b296ae --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture.rs @@ -0,0 +1,161 @@ +use std::{ + fs, + path::{Path, PathBuf}, +}; + +use react_compiler::entrypoint::plugin_options::{CompilerTarget, GatingConfig, PluginOptions}; +use serde::Deserialize; +use swc_ecma_react_compiler::{ + try_transform_source_to_code_with_parser, SourceParser, SourceSyntax, +}; +use testing::NormalizedOutput; + +#[derive(Clone, Copy, Debug, Deserialize)] +enum ParserSyntax { + #[serde(rename = "ecmascript")] + EcmaScript, + #[serde(rename = "typescript")] + TypeScript, +} + +#[derive(Clone, Debug, Default, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ParserOptions { + syntax: Option, + jsx: Option, + tsx: Option, + decorators: Option, +} + +impl ParserOptions { + fn into_source_parser(self, default: SourceParser) -> SourceParser { + let mut parser = default; + + if let Some(syntax) = self.syntax { + parser.syntax = match syntax { + ParserSyntax::EcmaScript 
=> SourceSyntax::EcmaScript, + ParserSyntax::TypeScript => SourceSyntax::TypeScript, + }; + } + if let Some(jsx) = self.jsx { + parser.jsx = jsx; + } + if let Some(tsx) = self.tsx { + parser.tsx = tsx; + } + if let Some(decorators) = self.decorators { + parser.decorators = decorators; + } + + parser + } +} + +#[derive(Clone, Debug, Default, Deserialize)] +#[serde(rename_all = "camelCase")] +struct FixtureOptions { + parser: Option, + filename: Option, + is_dev: Option, + compilation_mode: Option, + panic_threshold: Option, + target: Option, + gating: Option, + enable_reanimated: Option, +} + +fn default_source_parser(input: &Path) -> SourceParser { + match input.extension().and_then(|ext| ext.to_str()) { + Some("js") => SourceParser { + syntax: SourceSyntax::EcmaScript, + jsx: false, + tsx: false, + decorators: true, + }, + Some("jsx") => SourceParser { + syntax: SourceSyntax::EcmaScript, + jsx: true, + tsx: false, + decorators: true, + }, + Some("ts") => SourceParser { + syntax: SourceSyntax::TypeScript, + jsx: false, + tsx: false, + decorators: true, + }, + Some("tsx") => SourceParser { + syntax: SourceSyntax::TypeScript, + jsx: false, + tsx: true, + decorators: true, + }, + _ => SourceParser::default(), + } +} + +fn load_fixture_options(input: &Path) -> FixtureOptions { + let options_path = input.with_file_name("options.json"); + if !options_path.exists() { + return FixtureOptions::default(); + } + + serde_json::from_str( + &fs::read_to_string(&options_path).expect("failed to read fixture options"), + ) + .expect("failed to parse fixture options") +} + +fn to_plugin_options(input: &Path, source_code: &str, options: FixtureOptions) -> PluginOptions { + PluginOptions { + should_compile: true, + enable_reanimated: options.enable_reanimated.unwrap_or(false), + is_dev: options.is_dev.unwrap_or(false), + filename: options + .filename + .or_else(|| Some(input.display().to_string())), + compilation_mode: options + .compilation_mode + .unwrap_or_else(|| 
String::from("infer")), + panic_threshold: options + .panic_threshold + .unwrap_or_else(|| String::from("none")), + target: options + .target + .unwrap_or_else(|| CompilerTarget::Version(String::from("19"))), + gating: options.gating, + dynamic_gating: None, + no_emit: false, + output_mode: None, + eslint_suppression_rules: None, + flow_suppressions: true, + ignore_use_no_forget: false, + custom_opt_out_directives: None, + environment: Default::default(), + source_code: Some(source_code.to_string()), + profiling: false, + debug: false, + } +} + +#[testing::fixture("tests/fixture/**/input.jsx")] +#[testing::fixture("tests/fixture/**/input.ts")] +#[testing::fixture("tests/fixture/**/input.tsx")] +fn transform_fixture(input: PathBuf) { + let source = fs::read_to_string(&input).expect("failed to read input fixture"); + let output = input.with_file_name("output.js"); + let options = load_fixture_options(&input); + let parser = options + .parser + .clone() + .map(|parser| parser.into_source_parser(default_source_parser(&input))) + .unwrap_or_else(|| default_source_parser(&input)); + let plugin_options = to_plugin_options(&input, &source, options); + + let result = try_transform_source_to_code_with_parser(&source, plugin_options, parser) + .expect("transform should succeed"); + + NormalizedOutput::from(result.code) + .compare_to_file(&output) + .unwrap(); +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/directive-comments/input.jsx b/crates/swc_ecma_react_compiler/tests/fixture/directive-comments/input.jsx new file mode 100644 index 000000000000..105321741a46 --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/directive-comments/input.jsx @@ -0,0 +1,6 @@ +'use client'; + +/* leading comment */ +export function Greeting({ name }) { + return
{name}
; +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/directive-comments/output.js b/crates/swc_ecma_react_compiler/tests/fixture/directive-comments/output.js new file mode 100644 index 000000000000..45d2fd0b1ca8 --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/directive-comments/output.js @@ -0,0 +1,16 @@ +"use client"; + +import { c as _c } from "react/compiler-runtime"; +/* leading comment */ export function Greeting(t0) { + const $ = _c(2); + const { name } = t0; + let t1; + if ($[0] !== name) { + (t1 =
{name}
); + ($[0] = name); + ($[1] = t1); + } else { + (t1 = $[1]); + } + return t1; +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/gating/input.jsx b/crates/swc_ecma_react_compiler/tests/fixture/gating/input.jsx new file mode 100644 index 000000000000..9c9e75be5154 --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/gating/input.jsx @@ -0,0 +1,6 @@ +import { useState } from "react"; + +export function Counter() { + const [count] = useState(0); + return {count}; +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/gating/options.json b/crates/swc_ecma_react_compiler/tests/fixture/gating/options.json new file mode 100644 index 000000000000..35202b88c4d0 --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/gating/options.json @@ -0,0 +1,6 @@ +{ + "gating": { + "source": "react-compiler-runtime", + "importSpecifierName": "isForgetEnabled" + } +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/gating/output.js b/crates/swc_ecma_react_compiler/tests/fixture/gating/output.js new file mode 100644 index 000000000000..d4b81ac6f827 --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/gating/output.js @@ -0,0 +1,19 @@ +import { isForgetEnabled } from "react-compiler-runtime"; +import { c as _c } from "react/compiler-runtime"; +import { useState } from "react"; +export const Counter = (isForgetEnabled() ? 
function Counter() { + const $ = _c(2); + const [count] = useState(0); + let t0; + if ($[0] !== count) { + (t0 = {count}); + ($[0] = count); + ($[1] = t0); + } else { + (t0 = $[1]); + } + return t0; +} : function Counter() { + const [count] = useState(0); + return {count}; +}); diff --git a/crates/swc_ecma_react_compiler/tests/fixture/hook/input.ts b/crates/swc_ecma_react_compiler/tests/fixture/hook/input.ts new file mode 100644 index 000000000000..0b91a291b6bc --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/hook/input.ts @@ -0,0 +1,6 @@ +import { useState } from "react"; + +export function useCounter() { + const [count, setCount] = useState(0); + return { count, setCount }; +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/hook/output.js b/crates/swc_ecma_react_compiler/tests/fixture/hook/output.js new file mode 100644 index 000000000000..526dcfcfa97c --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/hook/output.js @@ -0,0 +1,18 @@ +import { c as _c } from "react/compiler-runtime"; +import { useState } from "react"; +export function useCounter() { + const $ = _c(2); + const [count, setCount] = useState(0); + let t0; + if ($[0] !== count) { + (t0 = { + count, + setCount + }); + ($[0] = count); + ($[1] = t0); + } else { + (t0 = $[1]); + } + return t0; +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/non-react-skip/input.ts b/crates/swc_ecma_react_compiler/tests/fixture/non-react-skip/input.ts new file mode 100644 index 000000000000..3b399665dc81 --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/non-react-skip/input.ts @@ -0,0 +1,3 @@ +export function add(a: number, b: number) { + return a + b; +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/non-react-skip/output.js b/crates/swc_ecma_react_compiler/tests/fixture/non-react-skip/output.js new file mode 100644 index 000000000000..3b399665dc81 --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/non-react-skip/output.js @@ -0,0 
+1,3 @@ +export function add(a: number, b: number) { + return a + b; +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/simple-component/input.tsx b/crates/swc_ecma_react_compiler/tests/fixture/simple-component/input.tsx new file mode 100644 index 000000000000..11237a252eb4 --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/simple-component/input.tsx @@ -0,0 +1,6 @@ +import { useState } from "react"; + +export function Counter() { + const [count] = useState(0); + return
{count}
; +} diff --git a/crates/swc_ecma_react_compiler/tests/fixture/simple-component/output.js b/crates/swc_ecma_react_compiler/tests/fixture/simple-component/output.js new file mode 100644 index 000000000000..5bed53949f7f --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/fixture/simple-component/output.js @@ -0,0 +1,15 @@ +import { c as _c } from "react/compiler-runtime"; +import { useState } from "react"; +export function Counter() { + const $ = _c(2); + const [count] = useState(0); + let t0; + if ($[0] !== count) { + (t0 =
{count}
); + ($[0] = count); + ($[1] = t0); + } else { + (t0 = $[1]); + } + return t0; +} diff --git a/crates/swc_ecma_react_compiler/tests/integration.rs b/crates/swc_ecma_react_compiler/tests/integration.rs new file mode 100644 index 000000000000..c0fc2fdb5b24 --- /dev/null +++ b/crates/swc_ecma_react_compiler/tests/integration.rs @@ -0,0 +1,605 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// +// This source code is licensed under the MIT license found in the +// LICENSE file in the root directory of this source tree. + +use react_compiler::entrypoint::plugin_options::{CompilerTarget, PluginOptions}; +use react_compiler_ast::{ + scope::{BindingKind, ScopeKind}, + statements::Statement, +}; +use swc_common::{sync::Lrc, FileName, SourceMap}; +use swc_ecma_ast::EsVersion; +use swc_ecma_parser::{parse_file_as_module, EsSyntax, Syntax}; +use swc_ecma_react_compiler::{ + convert_ast::convert_module, convert_ast_reverse::convert_program_to_swc, + convert_scope::build_scope_info, lint_source, prefilter::has_react_like_functions, + transform_source, +}; + +fn parse_module(source: &str) -> swc_ecma_ast::Module { + let cm = Lrc::new(SourceMap::default()); + let fm = cm.new_source_file(Lrc::new(FileName::Anon), source.to_string()); + let mut errors = vec![]; + parse_file_as_module( + &fm, + Syntax::Es(EsSyntax { + jsx: true, + ..Default::default() + }), + EsVersion::latest(), + None, + &mut errors, + ) + .expect("Failed to parse") +} + +fn default_options() -> PluginOptions { + PluginOptions { + should_compile: true, + enable_reanimated: false, + is_dev: false, + filename: None, + compilation_mode: "infer".to_string(), + panic_threshold: "none".to_string(), + target: CompilerTarget::Version("19".to_string()), + gating: None, + dynamic_gating: None, + no_emit: false, + output_mode: None, + eslint_suppression_rules: None, + flow_suppressions: true, + ignore_use_no_forget: false, + custom_opt_out_directives: None, + environment: Default::default(), + source_code: None, + 
profiling: false, + debug: false, + } +} + +// ── Prefilter tests ───────────────────────────────────────────────────────── + +#[test] +fn prefilter_detects_function_component() { + let module = parse_module("function MyComponent() { return
; }"); + assert!(has_react_like_functions(&module)); +} + +#[test] +fn prefilter_detects_arrow_component() { + let module = parse_module("const MyComponent = () =>
;"); + assert!(has_react_like_functions(&module)); +} + +#[test] +fn prefilter_detects_hook() { + let module = parse_module("function useMyHook() { return 42; }"); + assert!(has_react_like_functions(&module)); +} + +#[test] +fn prefilter_detects_hook_assigned_to_variable() { + let module = parse_module("const useMyHook = function() { return 42; };"); + assert!(has_react_like_functions(&module)); +} + +#[test] +fn prefilter_rejects_non_react_module() { + let module = parse_module( + r#" + const x = 1; + function helper() { return x + 2; } + export { helper }; + "#, + ); + assert!(!has_react_like_functions(&module)); +} + +#[test] +fn prefilter_rejects_lowercase_function() { + let module = parse_module("function myFunction() { return 42; }"); + assert!(!has_react_like_functions(&module)); +} + +#[test] +fn prefilter_rejects_use_prefix_without_uppercase() { + let module = parse_module("function useful() { return true; }"); + assert!(!has_react_like_functions(&module)); +} + +// ── AST round-trip tests ──────────────────────────────────────────────────── + +#[test] +fn convert_variable_declaration() { + let source = "const x = 1;"; + let module = parse_module(source); + let file = convert_module(&module, source); + assert_eq!(file.program.body.len(), 1); + assert!(matches!( + &file.program.body[0], + Statement::VariableDeclaration(_) + )); +} + +#[test] +fn convert_function_declaration() { + let source = "function foo() { return 42; }"; + let module = parse_module(source); + let file = convert_module(&module, source); + assert_eq!(file.program.body.len(), 1); + assert!(matches!( + &file.program.body[0], + Statement::FunctionDeclaration(_) + )); +} + +#[test] +fn convert_arrow_function_expression() { + let source = "const f = (x) => x + 1;"; + let module = parse_module(source); + let file = convert_module(&module, source); + assert_eq!(file.program.body.len(), 1); + assert!(matches!( + &file.program.body[0], + Statement::VariableDeclaration(_) + )); +} + +#[test] +fn 
convert_jsx_element() { + let source = "const el =
hello
;"; + let module = parse_module(source); + let file = convert_module(&module, source); + assert_eq!(file.program.body.len(), 1); + assert!(matches!( + &file.program.body[0], + Statement::VariableDeclaration(_) + )); +} + +#[test] +fn convert_import_declaration() { + let source = "import { useState } from 'react';"; + let module = parse_module(source); + let file = convert_module(&module, source); + assert_eq!(file.program.body.len(), 1); + assert!(matches!( + &file.program.body[0], + Statement::ImportDeclaration(_) + )); +} + +#[test] +fn convert_export_named_declaration() { + let source = "export const x = 1;"; + let module = parse_module(source); + let file = convert_module(&module, source); + assert_eq!(file.program.body.len(), 1); + assert!(matches!( + &file.program.body[0], + Statement::ExportNamedDeclaration(_) + )); +} + +#[test] +fn convert_export_default_declaration() { + let source = "export default function App() { return
; }"; + let module = parse_module(source); + let file = convert_module(&module, source); + assert_eq!(file.program.body.len(), 1); + assert!(matches!( + &file.program.body[0], + Statement::ExportDefaultDeclaration(_) + )); +} + +#[test] +fn convert_multiple_statements() { + let source = r#" + import React from 'react'; + const x = 1; + function App() { return
{x}
; } + export default App; + "#; + let module = parse_module(source); + let file = convert_module(&module, source); + assert_eq!(file.program.body.len(), 4); + assert!(matches!( + &file.program.body[0], + Statement::ImportDeclaration(_) + )); + assert!(matches!( + &file.program.body[1], + Statement::VariableDeclaration(_) + )); + assert!(matches!( + &file.program.body[2], + Statement::FunctionDeclaration(_) + )); + assert!(matches!( + &file.program.body[3], + Statement::ExportDefaultDeclaration(_) + )); +} + +#[test] +fn convert_directive() { + let source = "'use strict';\nconst x = 1;"; + let module = parse_module(source); + let file = convert_module(&module, source); + assert_eq!(file.program.directives.len(), 1); + assert_eq!(file.program.body.len(), 1); +} + +// ── Scope analysis tests ──────────────────────────────────────────────────── + +#[test] +fn scope_program_scope_created() { + let source = "const x = 1;"; + let module = parse_module(source); + let info = build_scope_info(&module); + assert!(!info.scopes.is_empty()); + assert!(matches!(info.scopes[0].kind, ScopeKind::Program)); + assert!(info.scopes[0].parent.is_none()); +} + +#[test] +fn scope_var_hoists_to_function() { + let source = r#" + function foo() { + { + var x = 1; + } + } + "#; + let module = parse_module(source); + let info = build_scope_info(&module); + + // Find the binding for x + let x_binding = info + .bindings + .iter() + .find(|b| b.name == "x") + .expect("should find binding x"); + assert!(matches!(x_binding.kind, BindingKind::Var)); + + // x should be in a Function scope, not the Block scope + let scope = &info.scopes[x_binding.scope.0 as usize]; + assert!(matches!(scope.kind, ScopeKind::Function)); +} + +#[test] +fn scope_let_const_block_scoped() { + let source = r#" + function foo() { + { + let x = 1; + const y = 2; + } + } + "#; + let module = parse_module(source); + let info = build_scope_info(&module); + + let x_binding = info + .bindings + .iter() + .find(|b| b.name == "x") + 
.expect("should find binding x"); + assert!(matches!(x_binding.kind, BindingKind::Let)); + let x_scope = &info.scopes[x_binding.scope.0 as usize]; + assert!(matches!(x_scope.kind, ScopeKind::Block)); + + let y_binding = info + .bindings + .iter() + .find(|b| b.name == "y") + .expect("should find binding y"); + assert!(matches!(y_binding.kind, BindingKind::Const)); + let y_scope = &info.scopes[y_binding.scope.0 as usize]; + assert!(matches!(y_scope.kind, ScopeKind::Block)); +} + +#[test] +fn scope_function_declaration_hoists() { + let source = r#" + function outer() { + { + function inner() {} + } + } + "#; + let module = parse_module(source); + let info = build_scope_info(&module); + + let inner_binding = info + .bindings + .iter() + .find(|b| b.name == "inner") + .expect("should find binding inner"); + assert!(matches!(inner_binding.kind, BindingKind::Hoisted)); + // inner should be hoisted to the enclosing function scope (outer), not the + // block + let scope = &info.scopes[inner_binding.scope.0 as usize]; + assert!(matches!(scope.kind, ScopeKind::Function)); +} + +#[test] +fn scope_import_bindings() { + let source = r#" + import React from 'react'; + import { useState, useEffect } from 'react'; + import * as Utils from './utils'; + "#; + let module = parse_module(source); + let info = build_scope_info(&module); + + let react_binding = info + .bindings + .iter() + .find(|b| b.name == "React") + .expect("should find binding React"); + assert!(matches!(react_binding.kind, BindingKind::Module)); + assert!(react_binding.import.is_some()); + let import_data = react_binding.import.as_ref().unwrap(); + assert_eq!(import_data.source, "react"); + + let use_state_binding = info + .bindings + .iter() + .find(|b| b.name == "useState") + .expect("should find binding useState"); + assert!(matches!(use_state_binding.kind, BindingKind::Module)); + + let utils_binding = info + .bindings + .iter() + .find(|b| b.name == "Utils") + .expect("should find binding Utils"); + 
assert!(matches!(utils_binding.kind, BindingKind::Module)); +} + +#[test] +fn scope_nested_functions_create_scopes() { + let source = r#" + function outer(a) { + function inner(b) { + return a + b; + } + } + "#; + let module = parse_module(source); + let info = build_scope_info(&module); + + let a_binding = info + .bindings + .iter() + .find(|b| b.name == "a") + .expect("should find binding a"); + assert!(matches!(a_binding.kind, BindingKind::Param)); + + let b_binding = info + .bindings + .iter() + .find(|b| b.name == "b") + .expect("should find binding b"); + assert!(matches!(b_binding.kind, BindingKind::Param)); + + // a and b should be in different function scopes + assert!(a_binding.scope.0 != b_binding.scope.0); +} + +#[test] +fn scope_catch_clause_creates_scope() { + let source = r#" + try { + throw new Error(); + } catch (e) { + console.log(e); + } + "#; + let module = parse_module(source); + let info = build_scope_info(&module); + + let e_binding = info + .bindings + .iter() + .find(|b| b.name == "e") + .expect("should find binding e"); + assert!(matches!(e_binding.kind, BindingKind::Let)); + let scope = &info.scopes[e_binding.scope.0 as usize]; + assert!(matches!(scope.kind, ScopeKind::Catch)); +} + +#[test] +fn scope_arrow_function_params() { + let source = "const f = (x, y) => x + y;"; + let module = parse_module(source); + let info = build_scope_info(&module); + + let x_binding = info + .bindings + .iter() + .find(|b| b.name == "x") + .expect("should find binding x"); + assert!(matches!(x_binding.kind, BindingKind::Param)); + let scope = &info.scopes[x_binding.scope.0 as usize]; + assert!(matches!(scope.kind, ScopeKind::Function)); +} + +#[test] +fn scope_for_loop_creates_scope() { + let source = "for (let i = 0; i < 10; i++) { console.log(i); }"; + let module = parse_module(source); + let info = build_scope_info(&module); + + let i_binding = info + .bindings + .iter() + .find(|b| b.name == "i") + .expect("should find binding i"); + 
assert!(matches!(i_binding.kind, BindingKind::Let)); + let scope = &info.scopes[i_binding.scope.0 as usize]; + assert!(matches!(scope.kind, ScopeKind::For)); +} + +// ── Full transform pipeline tests ─────────────────────────────────────────── + +#[test] +fn transform_simple_component_does_not_panic() { + let source = r#" + function App() { + return
Hello
; + } + "#; + let result = transform_source(source, default_options()); + // The transform should complete without panicking. + // It may or may not produce output depending on compiler completeness. + let _ = result.module; + let _ = result.diagnostics; +} + +#[test] +fn transform_component_with_hook_does_not_panic() { + let source = r#" + import { useState } from 'react'; + function Counter() { + const [count, setCount] = useState(0); + return
{count}
; + } + "#; + let result = transform_source(source, default_options()); + let _ = result.module; + let _ = result.diagnostics; +} + +#[test] +fn transform_non_react_code_returns_none() { + let source = "const x = 1 + 2;"; + let result = transform_source(source, default_options()); + // Non-React code with compilation_mode "infer" should be skipped (prefilter) + assert!(result.module.is_none()); + assert!(result.diagnostics.is_empty()); +} + +#[test] +fn transform_compilation_mode_all_does_not_skip() { + let source = "const x = 1 + 2;"; + let mut options = default_options(); + options.compilation_mode = "all".to_string(); + let result = transform_source(source, options); + // With "all" mode, even non-React code should go through the compiler. + // It may not produce output, but it should not be skipped by prefilter. + let _ = result.module; +} + +#[test] +fn lint_simple_component_does_not_panic() { + let source = r#" + function App() { + return
Hello
; + } + "#; + let result = lint_source(source, default_options()); + let _ = result.diagnostics; +} + +#[test] +fn lint_non_react_code_returns_empty() { + let source = "const x = 1;"; + let result = lint_source(source, default_options()); + assert!(result.diagnostics.is_empty()); +} + +// ── Reverse AST conversion tests ──────────────────────────────────────────── + +#[test] +fn reverse_convert_variable_declaration() { + let source = "const x = 1;"; + let module = parse_module(source); + let file = convert_module(&module, source); + + let swc_module = convert_program_to_swc(&file); + assert_eq!(swc_module.module.body.len(), 1); + assert!(matches!( + &swc_module.module.body[0], + swc_ecma_ast::ModuleItem::Stmt(swc_ecma_ast::Stmt::Decl(swc_ecma_ast::Decl::Var(_))) + )); +} + +#[test] +fn reverse_convert_function_declaration() { + let source = "function foo() { return 42; }"; + let module = parse_module(source); + let file = convert_module(&module, source); + + let swc_module = convert_program_to_swc(&file); + assert_eq!(swc_module.module.body.len(), 1); + assert!(matches!( + &swc_module.module.body[0], + swc_ecma_ast::ModuleItem::Stmt(swc_ecma_ast::Stmt::Decl(swc_ecma_ast::Decl::Fn(_))) + )); +} + +#[test] +fn reverse_convert_import_export() { + let source = r#" + import { useState } from 'react'; + export const x = 1; + "#; + let module = parse_module(source); + let file = convert_module(&module, source); + + let swc_module = convert_program_to_swc(&file); + assert_eq!(swc_module.module.body.len(), 2); +} + +#[test] +fn reverse_convert_roundtrip_via_json() { + let source = r#" + const x = 1; + function foo(a, b) { return a + b; } + "#; + let module = parse_module(source); + let file = convert_module(&module, source); + + // Serialize to JSON and deserialize back + let json = serde_json::to_value(&file).expect("serialize to JSON"); + let deserialized: react_compiler_ast::File = + serde_json::from_value(json).expect("deserialize from JSON"); + + // Convert the 
deserialized AST back to SWC + let swc_module = convert_program_to_swc(&deserialized); + assert_eq!(swc_module.module.body.len(), 2); +} + +#[test] +fn reverse_convert_jsx_roundtrip() { + let source = r#"const el =
hello
;"#; + let module = parse_module(source); + let file = convert_module(&module, source); + + let json = serde_json::to_value(&file).expect("serialize to JSON"); + let deserialized: react_compiler_ast::File = + serde_json::from_value(json).expect("deserialize from JSON"); + + let swc_module = convert_program_to_swc(&deserialized); + assert_eq!(swc_module.module.body.len(), 1); +} + +#[test] +fn reverse_convert_multiple_statement_types() { + let source = r#" + import React from 'react'; + const x = 1; + let y = 'hello'; + function App() { return
{x}{y}
; } + export default App; + "#; + let module = parse_module(source); + let file = convert_module(&module, source); + + let swc_module = convert_program_to_swc(&file); + assert_eq!(swc_module.module.body.len(), 5); +} diff --git a/packages/react-compiler/package.json b/packages/react-compiler/package.json index bd89f16ec2c1..4b2c2d7e296c 100644 --- a/packages/react-compiler/package.json +++ b/packages/react-compiler/package.json @@ -52,7 +52,7 @@ "build:ts": "tsc -d", "build": "tsc -d && napi build --platform --js ./src/binding.js --dts ./src/binding.d.ts --manifest-path ../../Cargo.toml -p binding_react_compiler_node --output-dir . --release", "build:dev": "tsc -d && napi build --platform --js ./src/binding.js --dts ./src/binding.d.ts --manifest-path ../../Cargo.toml -p binding_react_compiler_node --output-dir .", - "test": "cross-env NODE_OPTIONS='--experimental-vm-modules' echo 'no test'", + "test": "node ./scripts/smoke-test.js", "version": "napi version --npm-dir scripts/npm" }, "funding": { diff --git a/packages/react-compiler/src/binding.d.ts b/packages/react-compiler/src/binding.d.ts index cc34bc26a6b2..9249cac5dc70 100644 --- a/packages/react-compiler/src/binding.d.ts +++ b/packages/react-compiler/src/binding.d.ts @@ -4,10 +4,13 @@ export declare function isReactCompilerRequired(code: Buffer, signal?: AbortSign export declare function isReactCompilerRequiredSync(code: Buffer): boolean +export declare function transform(code: Buffer, options: Buffer, signal?: AbortSignal | undefined | null): Promise + +/** Output returned by the native React Compiler binding. 
*/ export interface TransformOutput { code: string map?: string - output?: string diagnostics: Array } +export declare function transformSync(code: Buffer, options: Buffer): TransformOutput diff --git a/packages/react-compiler/src/binding.js b/packages/react-compiler/src/binding.js index 5a39404a1b14..eee34b54a9c1 100644 --- a/packages/react-compiler/src/binding.js +++ b/packages/react-compiler/src/binding.js @@ -1,9 +1,12 @@ // prettier-ignore /* eslint-disable */ +// @ts-nocheck /* auto-generated by NAPI-RS */ -const { readFileSync } = require('fs') +const { createRequire } = require('node:module') +require = createRequire(__filename) +const { readFileSync } = require('node:fs') let nativeBinding = null const loadErrors = [] @@ -32,7 +35,11 @@ const isMuslFromFilesystem = () => { } const isMuslFromReport = () => { - const report = typeof process.report.getReport === 'function' ? process.report.getReport() : null + let report = null + if (typeof process.report?.getReport === 'function') { + process.report.excludeNetwork = true + report = process.report.getReport() + } if (!report) { return null } @@ -57,7 +64,13 @@ const isMuslFromChildProcess = () => { } function requireNative() { - if (process.platform === 'android') { + if (process.env.NAPI_RS_NATIVE_LIBRARY_PATH) { + try { + return require(process.env.NAPI_RS_NATIVE_LIBRARY_PATH); + } catch (err) { + loadErrors.push(err) + } + } else if (process.platform === 'android') { if (process.arch === 'arm64') { try { return require('./react-compiler.android-arm64.node') @@ -65,11 +78,15 @@ function requireNative() { loadErrors.push(e) } try { - return require('@swc/react-compiler-android-arm64') + const binding = require('@swc/react-compiler-android-arm64') + const bindingPackageVersion = require('@swc/react-compiler-android-arm64/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native 
binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else if (process.arch === 'arm') { try { return require('./react-compiler.android-arm-eabi.node') @@ -77,11 +94,15 @@ function requireNative() { loadErrors.push(e) } try { - return require('@swc/react-compiler-android-arm-eabi') + const binding = require('@swc/react-compiler-android-arm-eabi') + const bindingPackageVersion = require('@swc/react-compiler-android-arm-eabi/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else { loadErrors.push(new Error(`Unsupported architecture on Android ${process.arch}`)) } @@ -93,11 +114,15 @@ function requireNative() { loadErrors.push(e) } try { - return require('@swc/react-compiler-win32-x64-msvc') + const binding = require('@swc/react-compiler-win32-x64-msvc') + const bindingPackageVersion = require('@swc/react-compiler-win32-x64-msvc/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else if (process.arch === 'ia32') { try { return require('./react-compiler.win32-ia32-msvc.node') @@ -105,11 +130,15 @@ function requireNative() { loadErrors.push(e) } try { - return require('@swc/react-compiler-win32-ia32-msvc') + const binding = require('@swc/react-compiler-win32-ia32-msvc') + const bindingPackageVersion = require('@swc/react-compiler-win32-ia32-msvc/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else if (process.arch === 'arm64') { try { return require('./react-compiler.win32-arm64-msvc.node') @@ -117,26 +146,34 @@ function requireNative() { loadErrors.push(e) } try { - return require('@swc/react-compiler-win32-arm64-msvc') + const binding = require('@swc/react-compiler-win32-arm64-msvc') + const bindingPackageVersion = require('@swc/react-compiler-win32-arm64-msvc/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else { loadErrors.push(new Error(`Unsupported architecture on Windows: ${process.arch}`)) } } else if (process.platform === 'darwin') { try { - return require('./react-compiler.darwin-universal.node') - } catch (e) { - loadErrors.push(e) - } - try { - return require('@swc/react-compiler-darwin-universal') - } catch (e) { - loadErrors.push(e) - } - + return require('./react-compiler.darwin-universal.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-darwin-universal') + const bindingPackageVersion = require('@swc/react-compiler-darwin-universal/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } if (process.arch === 'x64') { try { return require('./react-compiler.darwin-x64.node') @@ -144,11 +181,15 @@ function requireNative() { loadErrors.push(e) } try { - return require('@swc/react-compiler-darwin-x64') + const binding = require('@swc/react-compiler-darwin-x64') + const bindingPackageVersion = require('@swc/react-compiler-darwin-x64/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else if (process.arch === 'arm64') { try { return require('./react-compiler.darwin-arm64.node') @@ -156,11 +197,15 @@ function requireNative() { loadErrors.push(e) } try { - return require('@swc/react-compiler-darwin-arm64') + const binding = require('@swc/react-compiler-darwin-arm64') + const bindingPackageVersion = require('@swc/react-compiler-darwin-arm64/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else { loadErrors.push(new Error(`Unsupported architecture on macOS: ${process.arch}`)) } @@ -172,11 +217,15 @@ function requireNative() { loadErrors.push(e) } try { - return require('@swc/react-compiler-freebsd-x64') + const binding = require('@swc/react-compiler-freebsd-x64') + const bindingPackageVersion = require('@swc/react-compiler-freebsd-x64/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else if (process.arch === 'arm64') { try { return require('./react-compiler.freebsd-arm64.node') @@ -184,11 +233,15 @@ function requireNative() { loadErrors.push(e) } try { - return require('@swc/react-compiler-freebsd-arm64') + const binding = require('@swc/react-compiler-freebsd-arm64') + const bindingPackageVersion = require('@swc/react-compiler-freebsd-arm64/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else { loadErrors.push(new Error(`Unsupported architecture on FreeBSD: ${process.arch}`)) } @@ -196,133 +249,259 @@ function requireNative() { if (process.arch === 'x64') { if (isMusl()) { try { - return require('./react-compiler.linux-x64-musl.node') - } catch (e) { - loadErrors.push(e) - } - try { - return require('@swc/react-compiler-linux-x64-musl') - } catch (e) { - loadErrors.push(e) - } - + return require('./react-compiler.linux-x64-musl.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-x64-musl') + const bindingPackageVersion = require('@swc/react-compiler-linux-x64-musl/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } } else { try { - return require('./react-compiler.linux-x64-gnu.node') - } catch (e) { - loadErrors.push(e) - } - try { - return require('@swc/react-compiler-linux-x64-gnu') - } catch (e) { - loadErrors.push(e) - } - + return require('./react-compiler.linux-x64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-x64-gnu') + const bindingPackageVersion = require('@swc/react-compiler-linux-x64-gnu/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } } } else if (process.arch === 'arm64') { if (isMusl()) { try { - return require('./react-compiler.linux-arm64-musl.node') - } catch (e) { - loadErrors.push(e) - } - try { - return require('@swc/react-compiler-linux-arm64-musl') - } catch (e) { - loadErrors.push(e) - } - + return require('./react-compiler.linux-arm64-musl.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-arm64-musl') + const bindingPackageVersion = require('@swc/react-compiler-linux-arm64-musl/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } } else { try { - return require('./react-compiler.linux-arm64-gnu.node') - } catch (e) { - loadErrors.push(e) - } - try { - return require('@swc/react-compiler-linux-arm64-gnu') - } catch (e) { - loadErrors.push(e) - } - + return require('./react-compiler.linux-arm64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-arm64-gnu') + const bindingPackageVersion = require('@swc/react-compiler-linux-arm64-gnu/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } } } else if (process.arch === 'arm') { if (isMusl()) { try { - return require('./react-compiler.linux-arm-musleabihf.node') - } catch (e) { - loadErrors.push(e) - } - try { - return require('@swc/react-compiler-linux-arm-musleabihf') - } catch (e) { - loadErrors.push(e) + return require('./react-compiler.linux-arm-musleabihf.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-arm-musleabihf') + const bindingPackageVersion = require('@swc/react-compiler-linux-arm-musleabihf/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + try { + return require('./react-compiler.linux-arm-gnueabihf.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-arm-gnueabihf') + const bindingPackageVersion = require('@swc/react-compiler-linux-arm-gnueabihf/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } + } else if (process.arch === 'loong64') { + if (isMusl()) { + try { + return require('./react-compiler.linux-loong64-musl.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-loong64-musl') + const bindingPackageVersion = require('@swc/react-compiler-linux-loong64-musl/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } else { + try { + return require('./react-compiler.linux-loong64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-loong64-gnu') + const bindingPackageVersion = require('@swc/react-compiler-linux-loong64-gnu/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } } - + } else if (process.arch === 'riscv64') { + if (isMusl()) { + try { + return require('./react-compiler.linux-riscv64-musl.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-riscv64-musl') + const bindingPackageVersion = require('@swc/react-compiler-linux-riscv64-musl/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } } else { try { - return require('./react-compiler.linux-arm-gnueabihf.node') + return require('./react-compiler.linux-riscv64-gnu.node') + } catch (e) { + loadErrors.push(e) + } + try { + const binding = require('@swc/react-compiler-linux-riscv64-gnu') + const bindingPackageVersion = require('@swc/react-compiler-linux-riscv64-gnu/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding + } catch (e) { + loadErrors.push(e) + } + } + } else if (process.arch === 'ppc64') { + try { + return require('./react-compiler.linux-ppc64-gnu.node') } catch (e) { loadErrors.push(e) } try { - return require('@swc/react-compiler-linux-arm-gnueabihf') + const binding = require('@swc/react-compiler-linux-ppc64-gnu') + const bindingPackageVersion = require('@swc/react-compiler-linux-ppc64-gnu/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - - } - } else if (process.arch === 'riscv64') { - if (isMusl()) { - try { - return require('./react-compiler.linux-riscv64-musl.node') + } else if (process.arch === 's390x') { + try { + return require('./react-compiler.linux-s390x-gnu.node') } catch (e) { loadErrors.push(e) } try { - return require('@swc/react-compiler-linux-riscv64-musl') + const binding = require('@swc/react-compiler-linux-s390x-gnu') + const bindingPackageVersion = require('@swc/react-compiler-linux-s390x-gnu/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - - } else { - try { - return require('./react-compiler.linux-riscv64-gnu.node') + } else { + loadErrors.push(new Error(`Unsupported architecture on Linux: ${process.arch}`)) + } + } else if (process.platform === 'openharmony') { + if (process.arch === 'arm64') { + try { + return require('./react-compiler.openharmony-arm64.node') } catch (e) { loadErrors.push(e) } try { - return require('@swc/react-compiler-linux-riscv64-gnu') + const binding = require('@swc/react-compiler-openharmony-arm64') + const bindingPackageVersion = require('@swc/react-compiler-openharmony-arm64/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - - } - } else if (process.arch === 'ppc64') { + } else if (process.arch === 'x64') { try { - return require('./react-compiler.linux-ppc64-gnu.node') + return require('./react-compiler.openharmony-x64.node') } catch (e) { loadErrors.push(e) } try { - return require('@swc/react-compiler-linux-ppc64-gnu') + const binding = require('@swc/react-compiler-openharmony-x64') + const bindingPackageVersion = require('@swc/react-compiler-openharmony-x64/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - - } else if (process.arch === 's390x') { + } else if (process.arch === 'arm') { try { - return require('./react-compiler.linux-s390x-gnu.node') + return require('./react-compiler.openharmony-arm.node') } catch (e) { loadErrors.push(e) } try { - return require('@swc/react-compiler-linux-s390x-gnu') + const binding = require('@swc/react-compiler-openharmony-arm') + const bindingPackageVersion = require('@swc/react-compiler-openharmony-arm/package.json').version + if (bindingPackageVersion !== '1.15.24' && process.env.NAPI_RS_ENFORCE_VERSION_CHECK && process.env.NAPI_RS_ENFORCE_VERSION_CHECK !== '0') { + throw new Error(`Native binding package version mismatch, expected 1.15.24 but got ${bindingPackageVersion}. 
You can reinstall dependencies to fix this issue.`) + } + return binding } catch (e) { loadErrors.push(e) } - } else { - loadErrors.push(new Error(`Unsupported architecture on Linux: ${process.arch}`)) + loadErrors.push(new Error(`Unsupported architecture on OpenHarmony: ${process.arch}`)) } } else { loadErrors.push(new Error(`Unsupported OS: ${process.platform}, architecture: ${process.arch}`)) @@ -332,34 +511,53 @@ function requireNative() { nativeBinding = requireNative() if (!nativeBinding || process.env.NAPI_RS_FORCE_WASI) { + let wasiBinding = null + let wasiBindingError = null try { - nativeBinding = require('./react-compiler.wasi.cjs') + wasiBinding = require('./react-compiler.wasi.cjs') + nativeBinding = wasiBinding } catch (err) { if (process.env.NAPI_RS_FORCE_WASI) { - console.error(err) + wasiBindingError = err } } if (!nativeBinding) { try { - nativeBinding = require('@swc/react-compiler-wasm32-wasi') + wasiBinding = require('@swc/react-compiler-wasm32-wasi') + nativeBinding = wasiBinding } catch (err) { if (process.env.NAPI_RS_FORCE_WASI) { - console.error(err) + wasiBindingError.cause = err + loadErrors.push(err) } } } + if (process.env.NAPI_RS_FORCE_WASI === 'error' && !wasiBinding) { + const error = new Error('WASI binding not found and NAPI_RS_FORCE_WASI is set to error') + error.cause = wasiBindingError + throw error + } } if (!nativeBinding) { if (loadErrors.length > 0) { - // TODO Link to documentation with potential fixes - // - The package owner could build/publish bindings for this arch - // - The user may need to bundle the correct files - // - The user may need to re-install node_modules to get new packages - throw new Error('Failed to load native binding', { cause: loadErrors }) + throw new Error( + `Cannot find native binding. ` + + `npm has a bug related to optional dependencies (https://github.com/npm/cli/issues/4828). 
` + + 'Please try `npm i` again after removing both package-lock.json and node_modules directory.', + { + cause: loadErrors.reduce((err, cur) => { + cur.cause = err + return cur + }), + }, + ) } throw new Error(`Failed to load native binding`) } +module.exports = nativeBinding module.exports.isReactCompilerRequired = nativeBinding.isReactCompilerRequired module.exports.isReactCompilerRequiredSync = nativeBinding.isReactCompilerRequiredSync +module.exports.transform = nativeBinding.transform +module.exports.transformSync = nativeBinding.transformSync diff --git a/packages/react-compiler/src/index.ts b/packages/react-compiler/src/index.ts index 94e081ac7595..b1095832a292 100644 --- a/packages/react-compiler/src/index.ts +++ b/packages/react-compiler/src/index.ts @@ -1,18 +1,69 @@ - import * as binding from './binding' -/** - * TODO - */ -export async function isReactCompilerRequired(code: Buffer) { - return await binding.isReactCompilerRequired(code) +export type CompilerTarget = + | string + | { + kind: 'donotuse_meta_internal' + runtimeModule: string + } + +export interface GatingConfig { + source: string + importSpecifierName: string } +export type ParserSyntax = 'ecmascript' | 'typescript' -/** - * TODO - */ -export function isReactCompilerRequiredSync(code: Buffer): boolean { - return binding.isReactCompilerRequiredSync(code) +export interface ParserOptions { + syntax?: ParserSyntax + jsx?: boolean + tsx?: boolean + decorators?: boolean } +export interface TransformOptions { + parser?: ParserOptions + filename?: string + isDev?: boolean + compilationMode?: string + panicThreshold?: string + target?: CompilerTarget + gating?: GatingConfig + enableReanimated?: boolean +} + +export interface TransformOutput { + code: string + diagnostics: string[] + map?: string +} + +function normalizeInput(code: string | Buffer): Buffer { + return Buffer.isBuffer(code) ? 
code : Buffer.from(code) +} + +function serializeOptions(options: TransformOptions): Buffer { + return Buffer.from(JSON.stringify(options)) +} + +export async function transform( + code: string | Buffer, + options: TransformOptions = {} +): Promise { + return await binding.transform(normalizeInput(code), serializeOptions(options)) +} + +export function transformSync( + code: string | Buffer, + options: TransformOptions = {} +): TransformOutput { + return binding.transformSync(normalizeInput(code), serializeOptions(options)) +} + +export async function isReactCompilerRequired(code: string | Buffer): Promise { + return await binding.isReactCompilerRequired(normalizeInput(code)) +} + +export function isReactCompilerRequiredSync(code: string | Buffer): boolean { + return binding.isReactCompilerRequiredSync(normalizeInput(code)) +}