diff --git a/crates/compiler/src/compiler.rs b/crates/compiler/src/compiler.rs index aab50b428b..d4324aee49 100644 --- a/crates/compiler/src/compiler.rs +++ b/crates/compiler/src/compiler.rs @@ -456,6 +456,8 @@ impl Compiler { self.do_pass::(())?; + self.do_pass::(())?; + self.do_pass::(())?; Ok(abis) diff --git a/crates/passes/src/lib.rs b/crates/passes/src/lib.rs index e9fd51b085..ba14619458 100644 --- a/crates/passes/src/lib.rs +++ b/crates/passes/src/lib.rs @@ -94,6 +94,9 @@ pub use static_single_assignment::*; mod ssa_const_propagation; pub use ssa_const_propagation::*; +mod storage_read_forwarding; +pub use storage_read_forwarding::*; + mod storage_lowering; pub use storage_lowering::*; diff --git a/crates/passes/src/storage_read_forwarding/ast.rs b/crates/passes/src/storage_read_forwarding/ast.rs new file mode 100644 index 0000000000..ebb2002149 --- /dev/null +++ b/crates/passes/src/storage_read_forwarding/ast.rs @@ -0,0 +1,677 @@ +// Copyright (C) 2019-2026 Provable Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +use super::StorageReadForwardingVisitor; + +use leo_ast::*; + +impl AstReconstructor for StorageReadForwardingVisitor<'_> { + type AdditionalInput = (); + type AdditionalOutput = (); + + fn reconstruct_path(&mut self, input: Path, _additional: &()) -> (Expression, Self::AdditionalOutput) { + if let Some(alias) = input.try_local_symbol().and_then(|name| self.local_alias(name)) { + let ty = self.state.type_table.get(&input.id()); + let path = Path::from(Identifier::new(alias, self.state.node_builder.next_id())).to_local(); + if let Some(ty) = ty { + self.state.type_table.insert(path.id(), ty); + } + (path.into(), ()) + } else { + (input.into(), ()) + } + } + + fn reconstruct_intrinsic( + &mut self, + mut input: IntrinsicExpression, + _additional: &(), + ) -> (Expression, Self::AdditionalOutput) { + input.arguments = input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg, &()).0).collect(); + if Self::is_effect_boundary(&input) { + self.clear_reads(); + } + (input.into(), ()) + } + + fn reconstruct_call( + &mut self, + mut input: CallExpression, + _additional: &(), + ) -> (Expression, Self::AdditionalOutput) { + input.arguments = input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg, &()).0).collect(); + self.clear_reads(); + (input.into(), ()) + } + + fn reconstruct_dynamic_op( + &mut self, + mut input: DynamicOpExpression, + _additional: &(), + ) -> (Expression, Self::AdditionalOutput) { + input.interface = self.reconstruct_type(input.interface).0; + input.target_program = self.reconstruct_expression(input.target_program, &()).0; + input.network = input.network.map(|network| self.reconstruct_expression(network, &()).0); + match &mut input.kind { + DynamicOpKind::Call { arguments, .. } | DynamicOpKind::Op { arguments, .. } => { + *arguments = + std::mem::take(arguments).into_iter().map(|arg| self.reconstruct_expression(arg, &()).0).collect(); + } + DynamicOpKind::Read { .. } => {} + } + self.clear_reads(); + (input.into(), ()) + } + + fn reconstruct_async( + &mut self, + mut input: AsyncExpression, + _additional: &(), + ) -> (Expression, Self::AdditionalOutput) { + let previous_context = self.in_finalize_context; + self.in_finalize_context = true; + self.clear_function_state(); + input.block = self.reconstruct_block(input.block).0; + self.clear_function_state(); + self.in_finalize_context = previous_context; + (input.into(), ()) + } + + fn reconstruct_conditional(&mut self, mut input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) { + input.condition = self.reconstruct_expression(input.condition, &()).0; + let join_condition = input.condition.clone(); + + let aliases = self.aliases.clone(); + self.clear_reads(); + self.clear_join_aliases(); + self.aliases = aliases.clone(); + input.then = self.reconstruct_block(input.then).0; + let then_aliases = self.aliases.clone(); + + self.clear_reads(); + self.clear_join_aliases(); + self.aliases = aliases.clone(); + input.otherwise = input.otherwise.map(|statement| Box::new(self.reconstruct_statement(*statement).0)); + let otherwise_aliases = self.aliases.clone(); + + self.clear_reads(); + self.aliases = aliases; + + self.then_join_aliases = then_aliases + .into_iter() + .filter(|(alias, target)| self.aliases.get(alias).copied() != Some(*target)) + .collect(); + self.otherwise_join_aliases = otherwise_aliases + .into_iter() + .filter(|(alias, target)| self.aliases.get(alias).copied() != Some(*target)) + .collect(); + self.join_condition = + (!self.then_join_aliases.is_empty() || !self.otherwise_join_aliases.is_empty()).then_some(join_condition); + + (input.into(), ()) + } + + fn reconstruct_ternary( + &mut self, + mut input: TernaryExpression, + _additional: &(), + ) -> (Expression, Self::AdditionalOutput) { + if !self.same_join_condition(&input.condition) { + self.clear_join_aliases(); + input.condition = self.reconstruct_expression(input.condition, &()).0; + input.if_true = self.reconstruct_expression(input.if_true, &()).0; + input.if_false = self.reconstruct_expression(input.if_false, &()).0; + return (input.into(), ()); + } + + let aliases = self.aliases.clone(); + input.condition = self.reconstruct_expression(input.condition, &()).0; + + self.aliases.extend(self.then_join_aliases.clone()); + input.if_true = self.reconstruct_expression(input.if_true, &()).0; + self.aliases = aliases.clone(); + + self.aliases.extend(self.otherwise_join_aliases.clone()); + input.if_false = self.reconstruct_expression(input.if_false, &()).0; + self.aliases = aliases; + + (input.into(), ()) + } + + fn reconstruct_assert(&mut self, mut input: AssertStatement) -> (Statement, Self::AdditionalOutput) { + self.clear_join_aliases(); + input.variant = match input.variant { + AssertVariant::Assert(expr) => AssertVariant::Assert(self.reconstruct_expression(expr, &()).0), + AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq( + self.reconstruct_expression(left, &()).0, + self.reconstruct_expression(right, &()).0, + ), + AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq( + self.reconstruct_expression(left, &()).0, + self.reconstruct_expression(right, &()).0, + ), + }; + + self.clear_reads(); + (input.into(), ()) + } + + fn reconstruct_const(&mut self, mut input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) { + self.clear_join_aliases(); + input.type_ = self.reconstruct_type(input.type_).0; + input.value = self.reconstruct_expression(input.value, &()).0; + (input.into(), ()) + } + + fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) { + if !self.is_matching_join_ternary(&input.value) { + self.clear_join_aliases(); + } + + input.value = self.reconstruct_expression(input.value, &()).0; + + let DefinitionPlace::Single(place) = &input.place else { + return (input.into(), ()); + }; + + if self.in_finalize_context + && let Expression::Path(path) = &input.value + && let Some(target) = path.try_local_symbol() + { + self.insert_alias(place.name, target); + } + + if self.in_finalize_context + && let Expression::Intrinsic(intrinsic) = &input.value + && let Some(read) = self.storage_read(intrinsic) + { + if let Some(existing) = self.reads.get(&read).copied() { + let existing = self.canonical_local(existing); + self.insert_alias(place.name, existing); + input.value = self.local_expression_like(existing, &input.value); + return (input.into(), ()); + } + self.reads.insert(read, place.name); + } + + (input.into(), ()) + } + + fn reconstruct_expression_statement( + &mut self, + mut input: ExpressionStatement, + ) -> (Statement, Self::AdditionalOutput) { + self.clear_join_aliases(); + input.expression = self.reconstruct_expression(input.expression, &()).0; + (input.into(), ()) + } + + fn reconstruct_return(&mut self, mut input: ReturnStatement) -> (Statement, Self::AdditionalOutput) { + self.clear_join_aliases(); + input.expression = self.reconstruct_expression(input.expression, &()).0; + (input.into(), ()) + } + + fn reconstruct_assign(&mut self, _input: AssignStatement) -> (Statement, Self::AdditionalOutput) { + panic!("`AssignStatement`s should not exist in the AST at this phase of compilation."); + } + + fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) { + panic!("`IterationStatement`s should not exist in the AST at this phase of compilation."); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::CompilerState; + + use leo_span::{Symbol, create_session_if_not_set_then, sym}; + use std::rc::Rc; + + fn ident(state: &mut CompilerState, name: &str) -> Identifier { + Identifier::new(Symbol::intern(name), state.node_builder.next_id()) + } + + fn local(state: &mut CompilerState, name: &str) -> Expression { + Path::from(ident(state, name)).to_local().into() + } + + fn u8_lit(state: &mut CompilerState, value: &str) -> Expression { + Literal::integer(IntegerType::U8, value.into(), Default::default(), state.node_builder.next_id()).into() + } + + fn definition(state: &mut CompilerState, name: &str, value: Expression) -> Statement { + DefinitionStatement { + place: DefinitionPlace::Single(ident(state, name)), + type_: None, + value, + span: Default::default(), + id: state.node_builder.next_id(), + } + .into() + } + + fn storage_read(state: &mut CompilerState) -> Expression { + storage_read_with_key(state, "key") + } + + fn storage_read_with_key(state: &mut CompilerState, key: &str) -> Expression { + let arguments = vec![local(state, "data"), local(state, key), u8_lit(state, "0")]; + IntrinsicExpression { + name: sym::_mapping_get_or_use, + type_parameters: Vec::new(), + input_types: Vec::new(), + return_types: Vec::new(), + arguments, + span: Default::default(), + id: state.node_builder.next_id(), + } + .into() + } + + fn storage_contains(state: &mut CompilerState) -> Expression { + let arguments = vec![local(state, "data"), local(state, "key")]; + IntrinsicExpression { + name: sym::_mapping_contains, + type_parameters: Vec::new(), + input_types: Vec::new(), + return_types: Vec::new(), + arguments, + span: Default::default(), + id: state.node_builder.next_id(), + } + .into() + } + + #[test] + fn preserves_branch_alias_for_ssa_join_operand() { + create_session_if_not_set_then(|_| { + let mut state = CompilerState { node_builder: Rc::new(NodeBuilder::default()), ..Default::default() }; + + let condition = local(&mut state, "flag"); + let true_operand = local(&mut state, "x2"); + let false_operand = local(&mut state, "x0"); + let join = TernaryExpression { + condition: local(&mut state, "flag"), + if_true: true_operand, + if_false: false_operand, + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let initial_value = u8_lit(&mut state, "0"); + let first_read = storage_read(&mut state); + let second_read = storage_read(&mut state); + let block = Block { + statements: vec![ + definition(&mut state, "x0", initial_value), + ConditionalStatement { + condition, + then: Block { + statements: vec![ + definition(&mut state, "x1", first_read), + definition(&mut state, "x2", second_read), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }, + otherwise: None, + span: Default::default(), + id: state.node_builder.next_id(), + } + .into(), + definition(&mut state, "x3", join.into()), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let mut visitor = StorageReadForwardingVisitor { + state: &mut state, + reads: Default::default(), + aliases: Default::default(), + then_join_aliases: Default::default(), + otherwise_join_aliases: Default::default(), + join_condition: None, + in_finalize_context: true, + }; + + let output = visitor.reconstruct_block(block).0.to_string(); + + assert_eq!( + output.matches("_mapping_get_or_use").count(), + 1, + "expected the repeated branch read to be forwarded:\n{output}" + ); + assert!( + output.contains("let x3 = flag ? x1 : x0"), + "expected the SSA join to use the surviving branch definition:\n{output}" + ); + assert!( + output.contains("let x2 = x1"), + "expected the repeated read definition to remain as a copy:\n{output}" + ); + }); + } + + #[test] + fn does_not_apply_branch_alias_to_different_condition() { + create_session_if_not_set_then(|_| { + let mut state = CompilerState { node_builder: Rc::new(NodeBuilder::default()), ..Default::default() }; + + let condition = local(&mut state, "flag"); + let unrelated_join = TernaryExpression { + condition: local(&mut state, "other"), + if_true: local(&mut state, "x2"), + if_false: local(&mut state, "x0"), + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let initial_value = u8_lit(&mut state, "0"); + let first_read = storage_read(&mut state); + let second_read = storage_read(&mut state); + let block = Block { + statements: vec![ + definition(&mut state, "x0", initial_value), + ConditionalStatement { + condition, + then: Block { + statements: vec![ + definition(&mut state, "x1", first_read), + definition(&mut state, "x2", second_read), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }, + otherwise: None, + span: Default::default(), + id: state.node_builder.next_id(), + } + .into(), + definition(&mut state, "x3", unrelated_join.into()), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let mut visitor = StorageReadForwardingVisitor { + state: &mut state, + reads: Default::default(), + aliases: Default::default(), + then_join_aliases: Default::default(), + otherwise_join_aliases: Default::default(), + join_condition: None, + in_finalize_context: true, + }; + + let output = visitor.reconstruct_block(block).0.to_string(); + + assert!( + output.contains("let x3 = other ? x2 : x0"), + "branch-local alias was applied outside the matching SSA join:\n{output}" + ); + assert!( + output.contains("let x2 = x1"), + "non-matching join kept a reference to x2 without preserving its definition:\n{output}" + ); + }); + } + + #[test] + fn does_not_apply_branch_alias_after_intervening_definition() { + create_session_if_not_set_then(|_| { + let mut state = CompilerState { node_builder: Rc::new(NodeBuilder::default()), ..Default::default() }; + + let condition = local(&mut state, "flag"); + let join = TernaryExpression { + condition: local(&mut state, "flag"), + if_true: local(&mut state, "x2"), + if_false: local(&mut state, "x0"), + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let initial_value = u8_lit(&mut state, "0"); + let first_read = storage_read(&mut state); + let second_read = storage_read(&mut state); + let intervening_value = u8_lit(&mut state, "1"); + let block = Block { + statements: vec![ + definition(&mut state, "x0", initial_value), + ConditionalStatement { + condition, + then: Block { + statements: vec![ + definition(&mut state, "x1", first_read), + definition(&mut state, "x2", second_read), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }, + otherwise: None, + span: Default::default(), + id: state.node_builder.next_id(), + } + .into(), + definition(&mut state, "tmp", intervening_value), + definition(&mut state, "x3", join.into()), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let mut visitor = StorageReadForwardingVisitor { + state: &mut state, + reads: Default::default(), + aliases: Default::default(), + then_join_aliases: Default::default(), + otherwise_join_aliases: Default::default(), + join_condition: None, + in_finalize_context: true, + }; + + let output = visitor.reconstruct_block(block).0.to_string(); + + assert!( + output.contains("let x3 = flag ? x2 : x0"), + "branch-local alias was applied after a non-join definition:\n{output}" + ); + assert!( + output.contains("let x2 = x1"), + "later same-condition join kept a reference to x2 without preserving its definition:\n{output}" + ); + }); + } + + #[test] + fn does_not_apply_then_branch_alias_to_otherwise_join_operand() { + create_session_if_not_set_then(|_| { + let mut state = CompilerState { node_builder: Rc::new(NodeBuilder::default()), ..Default::default() }; + + let condition = local(&mut state, "flag"); + let join = TernaryExpression { + condition: local(&mut state, "flag"), + if_true: local(&mut state, "x0"), + if_false: local(&mut state, "x2"), + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let initial_value = u8_lit(&mut state, "0"); + let first_read = storage_read(&mut state); + let second_read = storage_read(&mut state); + let block = Block { + statements: vec![ + definition(&mut state, "x0", initial_value), + ConditionalStatement { + condition, + then: Block { + statements: vec![ + definition(&mut state, "x1", first_read), + definition(&mut state, "x2", second_read), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }, + otherwise: None, + span: Default::default(), + id: state.node_builder.next_id(), + } + .into(), + definition(&mut state, "x3", join.into()), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let mut visitor = StorageReadForwardingVisitor { + state: &mut state, + reads: Default::default(), + aliases: Default::default(), + then_join_aliases: Default::default(), + otherwise_join_aliases: Default::default(), + join_condition: None, + in_finalize_context: true, + }; + + let output = visitor.reconstruct_block(block).0.to_string(); + + assert!( + output.contains("let x3 = flag ? x0 : x2"), + "then-branch alias was applied to the opposite ternary arm:\n{output}" + ); + assert!( + output.contains("let x2 = x1"), + "opposite-arm join kept a reference to x2 without preserving its definition:\n{output}" + ); + }); + } + + #[test] + fn canonicalizes_aliased_storage_read_keys() { + create_session_if_not_set_then(|_| { + let mut state = CompilerState { node_builder: Rc::new(NodeBuilder::default()), ..Default::default() }; + + let key_alias = local(&mut state, "key"); + let first_read = storage_read_with_key(&mut state, "key"); + let second_read = storage_read_with_key(&mut state, "key2"); + let block = Block { + statements: vec![ + definition(&mut state, "key2", key_alias), + definition(&mut state, "x1", first_read), + definition(&mut state, "x2", second_read), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let mut visitor = StorageReadForwardingVisitor { + state: &mut state, + reads: Default::default(), + aliases: Default::default(), + then_join_aliases: Default::default(), + otherwise_join_aliases: Default::default(), + join_condition: None, + in_finalize_context: true, + }; + + let output = visitor.reconstruct_block(block).0.to_string(); + + assert_eq!( + output.matches("_mapping_get_or_use").count(), + 1, + "expected aliased storage-read keys to share the same read fact:\n{output}" + ); + assert!( + output.contains("let x2 = x1"), + "expected the aliased-key duplicate read to remain as a copy:\n{output}" + ); + }); + } + + #[test] + fn canonicalizes_aliased_join_conditions() { + create_session_if_not_set_then(|_| { + let mut state = CompilerState { node_builder: Rc::new(NodeBuilder::default()), ..Default::default() }; + + let join = TernaryExpression { + condition: local(&mut state, "cond2"), + if_true: local(&mut state, "x2"), + if_false: local(&mut state, "x0"), + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let initial_value = u8_lit(&mut state, "0"); + let first_condition = storage_contains(&mut state); + let second_condition = storage_contains(&mut state); + let first_read = storage_read(&mut state); + let second_read = storage_read(&mut state); + let block = Block { + statements: vec![ + definition(&mut state, "x0", initial_value), + definition(&mut state, "cond1", first_condition), + definition(&mut state, "cond2", second_condition), + ConditionalStatement { + condition: local(&mut state, "cond2"), + then: Block { + statements: vec![ + definition(&mut state, "x1", first_read), + definition(&mut state, "x2", second_read), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }, + otherwise: None, + span: Default::default(), + id: state.node_builder.next_id(), + } + .into(), + definition(&mut state, "x3", join.into()), + ], + span: Default::default(), + id: state.node_builder.next_id(), + }; + + let mut visitor = StorageReadForwardingVisitor { + state: &mut state, + reads: Default::default(), + aliases: Default::default(), + then_join_aliases: Default::default(), + otherwise_join_aliases: Default::default(), + join_condition: None, + in_finalize_context: true, + }; + + let output = visitor.reconstruct_block(block).0.to_string(); + + assert!( + output.contains("let cond2 = cond1"), + "expected the repeated branch condition read to remain as a copy:\n{output}" + ); + assert!( + output.contains("let x3 = cond1 ? x1 : x0"), + "expected same-condition SSA join matching to use canonical condition aliases:\n{output}" + ); + assert!( + output.contains("let x2 = x1"), + "expected the branch-local repeated read to remain as a copy:\n{output}" + ); + }); + } +} diff --git a/crates/passes/src/storage_read_forwarding/mod.rs b/crates/passes/src/storage_read_forwarding/mod.rs new file mode 100644 index 0000000000..cc7f211f97 --- /dev/null +++ b/crates/passes/src/storage_read_forwarding/mod.rs @@ -0,0 +1,74 @@ +// Copyright (C) 2019-2026 Provable Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +//! Forward repeated local finalize storage reads until a conservative effect boundary. +//! +//! This pass only handles lowered static mapping intrinsics. Dynamic storage +//! reads carry target program and network operands and are left to a separate +//! optimization so their invalidation model can be reviewed independently. +//! +//! Read facts are always cleared at branch joins. Aliases created inside a +//! branch are only exposed to the matching arm of same-condition SSA join +//! ternaries emitted immediately after the branch, so branch-local definitions +//! are not treated as globally available after the join. Pending branch aliases +//! are discarded as soon as a non-join statement is encountered. +//! +//! Duplicate storage reads are rewritten to explicit local copies instead of +//! removed in this pass. This keeps the transformed AST well-formed even when a +//! later path-sensitive use is intentionally not rewritten; the following DCE +//! pass is responsible for deleting copies whose uses were fully forwarded. + +use crate::{CompilerState, Pass}; + +use leo_ast::UnitReconstructor as _; +use leo_errors::Result; + +mod ast; +mod program; +mod visitor; + +use visitor::StorageReadForwardingVisitor; + +pub struct StorageReadForwarding; + +impl Pass for StorageReadForwarding { + type Input = (); + type Output = (); + + const NAME: &str = "StorageReadForwarding"; + + fn do_pass(_input: Self::Input, state: &mut CompilerState) -> Result { + let ast = std::mem::take(&mut state.ast); + let mut visitor = StorageReadForwardingVisitor { + state, + reads: Default::default(), + aliases: Default::default(), + then_join_aliases: Default::default(), + otherwise_join_aliases: Default::default(), + join_condition: None, + in_finalize_context: false, + }; + + let ast = ast.map( + |program| visitor.reconstruct_program(program), + |library| library, // no-op for libraries + ); + + visitor.state.handler.last_err()?; + visitor.state.ast = ast; + Ok(()) + } +} diff --git a/crates/passes/src/storage_read_forwarding/program.rs b/crates/passes/src/storage_read_forwarding/program.rs new file mode 100644 index 0000000000..2c21b1f996 --- /dev/null +++ b/crates/passes/src/storage_read_forwarding/program.rs @@ -0,0 +1,56 @@ +// Copyright (C) 2019-2026 Provable Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +use super::StorageReadForwardingVisitor; + +use leo_ast::{AstReconstructor, Constructor, Function, Library, Module, UnitReconstructor}; + +impl UnitReconstructor for StorageReadForwardingVisitor<'_> { + fn reconstruct_library(&mut self, input: Library) -> Library { + input + } + + fn reconstruct_program_scope(&mut self, mut input: leo_ast::ProgramScope) -> leo_ast::ProgramScope { + input.functions = input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(); + input.constructor = input.constructor.map(|c| self.reconstruct_constructor(c)); + input + } + + fn reconstruct_function(&mut self, mut input: Function) -> Function { + let previous_context = self.in_finalize_context; + self.in_finalize_context = input.variant.is_finalize_context(); + self.clear_function_state(); + input.block = self.reconstruct_block(input.block).0; + self.clear_function_state(); + self.in_finalize_context = previous_context; + input + } + + fn reconstruct_constructor(&mut self, mut input: Constructor) -> Constructor { + let previous_context = self.in_finalize_context; + self.in_finalize_context = true; + self.clear_function_state(); + input.block = self.reconstruct_block(input.block).0; + self.clear_function_state(); + self.in_finalize_context = previous_context; + input + } + + fn reconstruct_module(&mut self, mut input: Module) -> Module { + input.functions = input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(); + input + } +} diff --git a/crates/passes/src/storage_read_forwarding/visitor.rs b/crates/passes/src/storage_read_forwarding/visitor.rs new file mode 100644 index 0000000000..66b55917b8 --- /dev/null +++ b/crates/passes/src/storage_read_forwarding/visitor.rs @@ -0,0 +1,149 @@ +// Copyright (C) 2019-2026 Provable Inc. +// This file is part of the Leo library. + +// The Leo library is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// The Leo library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. + +// You should have received a copy of the GNU General Public License +// along with the Leo library. If not, see . + +use crate::CompilerState; + +use leo_ast::{Expression, Identifier, IntrinsicExpression, LiteralVariant, Location, Node as _, Path}; +use leo_span::{Symbol, sym}; + +use indexmap::IndexMap; + +#[derive(Clone, Eq, PartialEq, Hash)] +pub(super) enum Atom { + Local(Symbol), + Global(Location), + Literal(LiteralVariant), +} + +#[derive(Eq, PartialEq, Hash)] +pub(super) enum StorageRead { + Get { mapping: Atom, key: Atom }, + GetOrUse { mapping: Atom, key: Atom, default: Atom }, + Contains { mapping: Atom, key: Atom }, +} + +pub struct StorageReadForwardingVisitor<'a> { + pub state: &'a mut CompilerState, + pub(super) reads: IndexMap, + pub(super) aliases: IndexMap, + pub(super) then_join_aliases: IndexMap, + pub(super) otherwise_join_aliases: IndexMap, + pub(super) join_condition: Option, + pub(super) in_finalize_context: bool, +} + +impl StorageReadForwardingVisitor<'_> { + pub(super) fn clear_reads(&mut self) { + self.reads.clear(); + } + + pub(super) fn clear_function_state(&mut self) { + self.reads.clear(); + self.aliases.clear(); + self.clear_join_aliases(); + } + + pub(super) fn clear_join_aliases(&mut self) { + self.then_join_aliases.clear(); + self.otherwise_join_aliases.clear(); + self.join_condition = None; + } + + pub(super) fn local_alias(&self, name: Symbol) -> Option { + let mut current = name; + while let Some(next) = self.aliases.get(¤t).copied() { + if next == current { + return Some(current); + } + current = next; + } + (current != name).then_some(current) + } + + pub(super) fn canonical_local(&self, name: Symbol) -> Symbol { + self.local_alias(name).unwrap_or(name) + } + + pub(super) fn insert_alias(&mut self, alias: Symbol, target: Symbol) { + let target = self.canonical_local(target); + if alias != target { + self.aliases.insert(alias, target); + } + } + + pub(super) fn same_join_condition(&self, condition: &Expression) -> bool { + let Some(join_condition) = &self.join_condition else { + return false; + }; + + match (condition, join_condition) { + (Expression::Path(left), Expression::Path(right)) => { + let left = left.try_local_symbol().map(|name| self.canonical_local(name)); + let right = right.try_local_symbol().map(|name| self.canonical_local(name)); + left == right && left.is_some() + } + (Expression::Literal(left), Expression::Literal(right)) => left.variant == right.variant, + _ => false, + } + } + + pub(super) fn is_matching_join_ternary(&self, expression: &Expression) -> bool { + matches!(expression, Expression::Ternary(ternary) if self.same_join_condition(&ternary.condition)) + } + + pub(super) fn atom(&self, expr: &Expression) -> Option { + match expr { + Expression::Literal(lit) => Some(Atom::Literal(lit.variant.clone())), + Expression::Path(path) => path + .try_local_symbol() + .map(|name| Atom::Local(self.canonical_local(name))) + .or_else(|| path.try_global_location().cloned().map(Atom::Global)), + _ => None, + } + } + + pub(super) fn storage_read(&self, intrinsic: &IntrinsicExpression) -> Option { + match intrinsic.name { + sym::_mapping_get => Some(StorageRead::Get { + mapping: self.atom(intrinsic.arguments.first()?)?, + key: self.atom(intrinsic.arguments.get(1)?)?, + }), + sym::_mapping_get_or_use => Some(StorageRead::GetOrUse { + mapping: self.atom(intrinsic.arguments.first()?)?, + key: self.atom(intrinsic.arguments.get(1)?)?, + default: self.atom(intrinsic.arguments.get(2)?)?, + }), + sym::_mapping_contains => Some(StorageRead::Contains { + mapping: self.atom(intrinsic.arguments.first()?)?, + key: self.atom(intrinsic.arguments.get(1)?)?, + }), + _ => None, + } + } + + pub(super) fn local_expression_like(&mut self, symbol: Symbol, old_value: &Expression) -> Expression { + let ty = self.state.type_table.get(&old_value.id()); + let path = Path::from(Identifier::new(symbol, self.state.node_builder.next_id())).to_local(); + if let Some(ty) = ty { + self.state.type_table.insert(path.id(), ty); + } + path.into() + } + + pub(super) fn is_effect_boundary(intrinsic: &IntrinsicExpression) -> bool { + matches!(intrinsic.name, sym::_mapping_set | sym::_mapping_remove | sym::_final_run) + } +} diff --git a/crates/passes/src/test_passes.rs b/crates/passes/src/test_passes.rs index 640d86cb88..d4fa3212a8 100644 --- a/crates/passes/src/test_passes.rs +++ b/crates/passes/src/test_passes.rs @@ -211,6 +211,18 @@ macro_rules! compiler_passes { (Disambiguate, ()), (StorageLowering, (TypeCheckingInput::new(NetworkName::TestnetV0))) ]), + (storage_read_forwarding_runner, [ + (GlobalVarsCollection, ()), + (PathResolution, ()), + (GlobalItemsCollection, ()), + (TypeChecking, (TypeCheckingInput::new(NetworkName::TestnetV0))), + (Disambiguate, ()), + (StorageLowering, (TypeCheckingInput::new(NetworkName::TestnetV0))), + (SsaForming, (SsaFormingInput { rename_defs: true })), + (Flattening, ()), + (SsaForming, (SsaFormingInput { rename_defs: false })), + (StorageReadForwarding, ()) + ]), (write_transforming_runner, [ (GlobalVarsCollection, ()), (PathResolution, ()), diff --git a/tests/expectations/cli/test_ast_snapshots_program/STDOUT b/tests/expectations/cli/test_ast_snapshots_program/STDOUT index 9b67a5d9dd..0b555e57c0 100644 --- a/tests/expectations/cli/test_ast_snapshots_program/STDOUT +++ b/tests/expectations/cli/test_ast_snapshots_program/STDOUT @@ -56,6 +56,8 @@ StaticAnalyzing.ast StaticAnalyzing.json StorageLowering.ast StorageLowering.json +StorageReadForwarding.ast +StorageReadForwarding.json TypeChecking.ast TypeChecking.json WriteTransforming.ast @@ -108,6 +110,8 @@ StaticAnalyzing.ast StaticAnalyzing.json StorageLowering.ast StorageLowering.json +StorageReadForwarding.ast +StorageReadForwarding.json TypeChecking.ast TypeChecking.json WriteTransforming.ast diff --git a/tests/expectations/compiler/storage/external_same_named_mapping_reads.out b/tests/expectations/compiler/storage/external_same_named_mapping_reads.out new file mode 100644 index 0000000000..16a28bc74a --- /dev/null +++ b/tests/expectations/compiler/storage/external_same_named_mapping_reads.out @@ -0,0 +1,54 @@ +program left.aleo; + +mapping data: + key as u32.public; + value as u8.public; + +function noop: + async noop into r0; + output r0 as left.aleo/noop.future; + +finalize noop: + assert.eq true true; + +constructor: + assert.eq edition 0u16; +// --- Next Program --- // +program right.aleo; + +mapping data: + key as u32.public; + value as u8.public; + +function noop: + async noop into r0; + output r0 as right.aleo/noop.future; + +finalize noop: + assert.eq true true; + +constructor: + assert.eq edition 0u16; +// --- Next Program --- // +import left.aleo; +import right.aleo; +program caller.aleo; + +mapping out: + key as boolean.public; + value as u8.public; + +function read_both: + input r0 as u32.private; + async read_both r0 into r1; + output r1 as caller.aleo/read_both.future; + +finalize read_both: + input r0 as u32.public; + get.or_use left.aleo/data[r0] 1u8 into r1; + get.or_use right.aleo/data[r0] 2u8 into r2; + add r1 r2 into r3; + set r3 into out[false]; + +constructor: + assert.eq edition 0u16; diff --git a/tests/expectations/passes/storage_read_forwarding/assert_barrier.out b/tests/expectations/passes/storage_read_forwarding/assert_barrier.out new file mode 100644 index 0000000000..5fcdde411a --- /dev/null +++ b/tests/expectations/passes/storage_read_forwarding/assert_barrier.out @@ -0,0 +1,19 @@ +program test.aleo { + @noupgrade + async constructor() { + return; + } + mapping data: u32 => u8; + mapping out: bool => u8; + fn keep_after_assert(key$$0: u32) -> Final { + let $var$1 = async { + let a = _mapping_get_or_use(test.aleo::data, key, 0u8); + let b = a; + assert_eq(a, a); + let c = _mapping_get_or_use(test.aleo::data, key, 0u8); + _mapping_set(test.aleo::out, false, a + c); + }; + return $var$1; + } +} + diff --git a/tests/expectations/passes/storage_read_forwarding/conditional_branch_reads.out b/tests/expectations/passes/storage_read_forwarding/conditional_branch_reads.out new file mode 100644 index 0000000000..c844d864e5 --- /dev/null +++ b/tests/expectations/passes/storage_read_forwarding/conditional_branch_reads.out @@ -0,0 +1,22 @@ +program test.aleo { + @noupgrade + async constructor() { + return; + } + mapping data: u32 => u8; + mapping out: bool => u8; + fn branch_local(key$$0: u32, flag$$1: bool) -> Final { + let $var$2 = async { + let condition$3 = flag; + let a = _mapping_get_or_use(test.aleo::data, key, 0u8); + let b = a; + _mapping_set(test.aleo::out, false, a + a); + { + } + let c = _mapping_get_or_use(test.aleo::data, key, 0u8); + _mapping_set(test.aleo::out, true, c); + }; + return $var$2; + } +} + diff --git a/tests/expectations/passes/storage_read_forwarding/repeated_mapping_reads.out b/tests/expectations/passes/storage_read_forwarding/repeated_mapping_reads.out new file mode 100644 index 0000000000..1beed075c2 --- /dev/null +++ b/tests/expectations/passes/storage_read_forwarding/repeated_mapping_reads.out @@ -0,0 +1,37 @@ +program test.aleo { + @noupgrade + async constructor() { + return; + } + mapping data: u32 => u8; + mapping out: bool => u8; + fn reuse_before_write(key$$0: u32) -> Final { + let $var$1 = async { + let a = _mapping_get_or_use(test.aleo::data, key, 0u8); + let b = a; + _mapping_set(test.aleo::out, false, a + a); + }; + return $var$1; + } + fn keep_after_write(key$$2: u32) -> Final { + let $var$3 = async { + let a = _mapping_get_or_use(test.aleo::data, key, 0u8); + _mapping_set(test.aleo::data, key, 1u8); + let b = _mapping_get_or_use(test.aleo::data, key, 0u8); + _mapping_set(test.aleo::out, false, a + b); + }; + return $var$3; + } + fn reuse_contains(key$$4: u32) -> Final { + let $var$5 = async { + let a = _mapping_contains(test.aleo::data, key); + let b = a; + let one = 1u8; + let zero = 0u8; + let value = a ? one : zero; + _mapping_set(test.aleo::out, a, value); + }; + return $var$5; + } +} + diff --git a/tests/expectations/passes/storage_read_forwarding/repeated_vector_len.out b/tests/expectations/passes/storage_read_forwarding/repeated_vector_len.out new file mode 100644 index 0000000000..db98f1a43c --- /dev/null +++ b/tests/expectations/passes/storage_read_forwarding/repeated_vector_len.out @@ -0,0 +1,29 @@ +program test.aleo { + @noupgrade + async constructor() { + return; + } + mapping out: bool => u32; + mapping data__: u32 => u8; + mapping data__len__: bool => u32; + fn reuse_len() -> Final { + let $var$1 = async { + let a = _mapping_get_or_use(test.aleo::data__len__, false, 0u32); + let b = a; + _mapping_set(test.aleo::out, false, a + a); + }; + return $var$1; + } + fn keep_after_push() -> Final { + let $var$2 = async { + let a = _mapping_get_or_use(test.aleo::data__len__, false, 0u32); + let $len_var$0 = a; + _mapping_set(test.aleo::data__len__, false, a + 1u32); + _mapping_set(test.aleo::data__, a, 1u8); + let b = _mapping_get_or_use(test.aleo::data__len__, false, 0u32); + _mapping_set(test.aleo::out, false, a + b); + }; + return $var$2; + } +} + diff --git a/tests/tests/compiler/storage/external_same_named_mapping_reads.leo b/tests/tests/compiler/storage/external_same_named_mapping_reads.leo new file mode 100644 index 0000000000..22587a5d21 --- /dev/null +++ b/tests/tests/compiler/storage/external_same_named_mapping_reads.leo @@ -0,0 +1,41 @@ +program left.aleo { + mapping data: u32 => u8; + + fn noop() -> Final { + return final {}; + } + + @noupgrade + constructor() {} +} + +// --- Next Program --- // +program right.aleo { + mapping data: u32 => u8; + + fn noop() -> Final { + return final {}; + } + + @noupgrade + constructor() {} +} + +// --- Next Program --- // +import left.aleo; +import right.aleo; + +program caller.aleo { + mapping out: bool => u8; + + fn read_both(key: u32) -> Final { + return final { + let a: u8 = Mapping::get_or_use(left.aleo::data, key, 1u8); + let b: u8 = Mapping::get_or_use(right.aleo::data, key, 2u8); + Mapping::set(out, false, a + b); + }; + } + + @noupgrade + constructor() {} +} diff --git a/tests/tests/passes/storage_read_forwarding/assert_barrier.leo b/tests/tests/passes/storage_read_forwarding/assert_barrier.leo new file mode 100644 index 0000000000..48137c33a3 --- /dev/null +++ b/tests/tests/passes/storage_read_forwarding/assert_barrier.leo @@ -0,0 +1,17 @@ +program test.aleo { + mapping data: u32 => u8; + mapping out: bool => u8; + + fn keep_after_assert(key: u32) -> Final { + return final { + let a: u8 = Mapping::get_or_use(data, key, 0u8); + let b: u8 = Mapping::get_or_use(data, key, 0u8); + assert_eq(b, a); + let c: u8 = Mapping::get_or_use(data, key, 0u8); + Mapping::set(out, false, b + c); + }; + } + + @noupgrade + constructor() {} +} diff --git a/tests/tests/passes/storage_read_forwarding/conditional_branch_reads.leo b/tests/tests/passes/storage_read_forwarding/conditional_branch_reads.leo new file mode 100644 index 0000000000..655a733188 --- /dev/null +++ b/tests/tests/passes/storage_read_forwarding/conditional_branch_reads.leo @@ -0,0 +1,20 @@ +program test.aleo { + mapping data: u32 => u8; + mapping out: bool => u8; + + fn branch_local(key: u32, flag: bool) -> Final { + return final { + if flag { + let a: u8 = Mapping::get_or_use(data, key, 0u8); + let b: u8 = Mapping::get_or_use(data, key, 0u8); + Mapping::set(out, false, a + b); + } + + let c: u8 = Mapping::get_or_use(data, key, 0u8); + Mapping::set(out, true, c); + }; + } + + @noupgrade + constructor() {} +} diff --git a/tests/tests/passes/storage_read_forwarding/repeated_mapping_reads.leo b/tests/tests/passes/storage_read_forwarding/repeated_mapping_reads.leo new file mode 100644 index 0000000000..8c0d676c09 --- /dev/null +++ b/tests/tests/passes/storage_read_forwarding/repeated_mapping_reads.leo @@ -0,0 +1,35 @@ +program test.aleo { + mapping data: u32 => u8; + mapping out: bool => u8; + + fn reuse_before_write(key: u32) -> Final { + return final { + let a: u8 = Mapping::get_or_use(data, key, 0u8); + let b: u8 = Mapping::get_or_use(data, key, 0u8); + Mapping::set(out, false, a + b); + }; + } + + fn keep_after_write(key: u32) -> Final { + return final { + let a: u8 = Mapping::get_or_use(data, key, 0u8); + Mapping::set(data, key, 1u8); + let b: u8 = Mapping::get_or_use(data, key, 0u8); + Mapping::set(out, false, a + b); + }; + } + + fn reuse_contains(key: u32) -> Final { + return final { + let a: bool = Mapping::contains(data, key); + let b: bool = Mapping::contains(data, key); + let one: u8 = 1u8; + let zero: u8 = 0u8; + let value: u8 = a ? one : zero; + Mapping::set(out, b, value); + }; + } + + @noupgrade + constructor() {} +} diff --git a/tests/tests/passes/storage_read_forwarding/repeated_vector_len.leo b/tests/tests/passes/storage_read_forwarding/repeated_vector_len.leo new file mode 100644 index 0000000000..3b80a65819 --- /dev/null +++ b/tests/tests/passes/storage_read_forwarding/repeated_vector_len.leo @@ -0,0 +1,24 @@ +program test.aleo { + storage data: [u8]; + mapping out: bool => u32; + + fn reuse_len() -> Final { + return final { + let a: u32 = data.len(); + let b: u32 = data.len(); + Mapping::set(out, false, a + b); + }; + } + + fn keep_after_push() -> Final { + return final { + let a: u32 = data.len(); + data.push(1u8); + let b: u32 = data.len(); + Mapping::set(out, false, a + b); + }; + } + + @noupgrade + constructor() {} +}