diff --git a/execution_graph/src/graph.rs b/execution_graph/src/graph.rs index e963deb..ef6a991 100644 --- a/execution_graph/src/graph.rs +++ b/execution_graph/src/graph.rs @@ -1614,6 +1614,143 @@ mod tests { assert_eq!(g.node_run_count(n), Some(2)); } + #[test] + fn host_write_invalidates_prior_readers_of_same_key() { + #[derive(Clone)] + struct KvHost { + kv: Rc>>, + get_sig: SigHash, + set_sig: SigHash, + } + + impl Host for KvHost { + fn call( + &mut self, + symbol: &str, + sig_hash: SigHash, + args: &[ValueRef<'_>], + rets: &mut [Value], + mut ctx: HostContext<'_, '_>, + ) -> Result { + match symbol { + "kv.get" => { + if sig_hash != self.get_sig { + return Err(HostError::SignatureMismatch); + } + let [ValueRef::U64(key)] = args else { + return Err(HostError::Failed); + }; + ctx.record_read(ResourceKeyRef::HostState { + op: self.get_sig, + key: *key, + }); + let v = *self.kv.borrow().get(key).unwrap_or(&0); + rets[0] = Value::I64(v); + Ok(0) + } + "kv.set" => { + if sig_hash != self.set_sig { + return Err(HostError::SignatureMismatch); + } + let [ValueRef::U64(key), ValueRef::I64(value)] = args else { + return Err(HostError::Failed); + }; + self.kv.borrow_mut().insert(*key, *value); + // Use the reader's key namespace so this write invalidates prior reads. 
+ ctx.record_write(ResourceKeyRef::HostState { + op: self.get_sig, + key: *key, + }); + rets[0] = Value::Unit; + Ok(0) + } + _ => Err(HostError::UnknownSymbol), + } + } + } + + let get_sig = HostSig { + args: vec![ValueType::U64], + rets: vec![ValueType::I64], + }; + let set_sig = HostSig { + args: vec![ValueType::U64, ValueType::I64], + rets: vec![ValueType::Unit], + }; + let get_hash = sig_hash(&get_sig); + let set_hash = sig_hash(&set_sig); + + let mut get_builder = ProgramBuilder::new(); + let get_host = get_builder.host_sig_for("kv.get", get_sig); + let mut get_asm = Asm::new(); + get_asm.const_u64(1, 1); + get_asm.host_call(0, get_host, 0, &[1], &[2]); + get_asm.ret(0, &[2]); + let get_entry = get_builder + .push_function_checked( + get_asm, + FunctionSig { + arg_types: vec![], + ret_types: vec![ValueType::I64], + }, + ) + .unwrap(); + get_builder + .set_function_output_name(get_entry, 0, "value") + .unwrap(); + let get_prog = Arc::new(get_builder.build_verified().unwrap()); + + let mut set_builder = ProgramBuilder::new(); + let set_host = set_builder.host_sig_for("kv.set", set_sig); + let mut set_asm = Asm::new(); + set_asm.const_u64(1, 1); + set_asm.const_i64(2, 8); + set_asm.host_call(0, set_host, 0, &[1, 2], &[3]); + set_asm.ret(0, &[3]); + let set_entry = set_builder + .push_function_checked( + set_asm, + FunctionSig { + arg_types: vec![], + ret_types: vec![ValueType::Unit], + }, + ) + .unwrap(); + set_builder + .set_function_output_name(set_entry, 0, "done") + .unwrap(); + let set_prog = Arc::new(set_builder.build_verified().unwrap()); + + let kv = Rc::new(RefCell::new(BTreeMap::new())); + kv.borrow_mut().insert(1, 7); + let host = KvHost { + kv, + get_sig: get_hash, + set_sig: set_hash, + }; + + let mut g = ExecutionGraph::new(host, Limits::default()); + let reader = g.add_node(get_prog, get_entry, vec![]); + + g.run_all().unwrap(); + assert_eq!( + g.node_outputs(reader).unwrap().get("value"), + Some(&Value::I64(7)) + ); + 
assert_eq!(g.node_run_count(reader), Some(1)); + + let writer = g.add_node(set_prog, set_entry, vec![]); + g.run_node(writer).unwrap(); + assert_eq!(g.node_run_count(reader), Some(1)); + + g.run_all().unwrap(); + assert_eq!( + g.node_outputs(reader).unwrap().get("value"), + Some(&Value::I64(8)) + ); + assert_eq!(g.node_run_count(reader), Some(2)); + } + #[test] fn host_read_order_changes_do_not_change_last_read_ids() { #[derive(Clone)] diff --git a/execution_graph/src/tape_access.rs b/execution_graph/src/tape_access.rs index 41d427e..8c0c1dd 100644 --- a/execution_graph/src/tape_access.rs +++ b/execution_graph/src/tape_access.rs @@ -98,13 +98,13 @@ impl AccessSink for CollectingAccessSink<'_> { fn write(&mut self, key: ResourceKeyRef<'_>) { self.counter.set(self.counter.get().saturating_add(1)); - let key = match key { - ResourceKeyRef::Input(name) => ResourceKey::input(name), - ResourceKeyRef::HostState { op, key } => { - ResourceKey::host_state(HostOpId::new(op.0), key) - } - ResourceKeyRef::OpaqueHost { op } => ResourceKey::opaque_host(HostOpId::new(op.0)), - }; + let key = mark_tape_key_dirty( + self.dirty, + self.input_ids, + self.host_state_ids, + self.opaque_host_ids, + key, + ); self.log.push(Access::Write(key)); } } @@ -159,6 +159,35 @@ pub(crate) fn intern_opaque_host_key_id( id } +#[inline] +fn mark_tape_key_dirty( + dirty: &mut DirtyEngine, + input_ids: &mut BTreeMap, DirtyKey>, + host_state_ids: &mut HashMap<(HostOpId, u64), DirtyKey>, + opaque_host_ids: &mut HashMap, + key: ResourceKeyRef<'_>, +) -> ResourceKey { + match key { + ResourceKeyRef::Input(name) => { + let id = intern_input_key_id(dirty, input_ids, name); + dirty.mark_dirty(id); + ResourceKey::input(name) + } + ResourceKeyRef::HostState { op, key } => { + let op = HostOpId::new(op.0); + let id = intern_host_state_key_id(dirty, host_state_ids, op, key); + dirty.mark_dirty(id); + ResourceKey::host_state(op, key) + } + ResourceKeyRef::OpaqueHost { op } => { + let op = HostOpId::new(op.0); + 
let id = intern_opaque_host_key_id(dirty, opaque_host_ids, op); + dirty.mark_dirty(id); + ResourceKey::opaque_host(op) + } + } +} + /// Fast-path access sink used when per-node access log collection is disabled. /// /// It emits dependency read IDs directly into `read_ids`, avoiding intermediate `AccessLog` @@ -217,9 +246,16 @@ impl AccessSink for DepsOnlyAccessSink<'_> { } #[inline] - fn write(&mut self, _key: ResourceKeyRef<'_>) { + fn write(&mut self, key: ResourceKeyRef<'_>) { // Strict-deps mode requires host scopes to emit at least one access event. self.counter.set(self.counter.get().saturating_add(1)); + let _ = mark_tape_key_dirty( + self.dirty, + self.input_ids, + self.host_state_ids, + self.opaque_host_ids, + key, + ); } } diff --git a/execution_tape/src/host.rs b/execution_tape/src/host.rs index ae91e7e..cba50d8 100644 --- a/execution_tape/src/host.rs +++ b/execution_tape/src/host.rs @@ -134,8 +134,10 @@ use crate::value::Value; /// let [ValueRef::U64(key), ValueRef::I64(value)] = args else { /// return Err(HostError::Failed); /// }; +/// // Use the same `(op, key)` namespace as `kv.get` so this write invalidates +/// // prior reads of the key. /// ctx.record_write(ResourceKeyRef::HostState { -/// op: sig_hash, +/// op: self.get_sig, /// key: *key, /// }); /// self.kv.insert(*key, *value); @@ -364,10 +366,11 @@ impl<'vm, 'access> HostContext<'vm, 'access> { /// The string is an embedder-chosen stable name. /// /// - [`ResourceKeyRef::HostState`] is the main “precise” form for host-managed state. -/// It is explicitly namespaced by the host operation’s [`SigHash`], so different host ops can -/// reuse the same numeric `key` without colliding. The `key: u64` should identify *which* -/// piece of state was consulted/mutated for that operation (often a stable hash of a structured -/// key, or an intern id managed by the embedder). 
+/// It is explicitly namespaced by `op`, a stable [`SigHash`] chosen by the host, so unrelated +/// state domains can reuse the same numeric `key` without colliding. The `key: u64` should +/// identify *which* piece of state was consulted/mutated within that namespace (often a stable +/// hash of a structured key, or an intern id managed by the embedder). Writes that should +/// invalidate previous reads must use the same `(op, key)` pair those reads recorded. /// /// - [`ResourceKeyRef::OpaqueHost`] is a conservative escape hatch for operations that depend on /// (or mutate) host state but cannot (or choose not to) produce a more precise key.