diff --git a/README.md b/README.md index e766d910a..9f02f2b34 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ What Monty **can** do: - Completely block access to the host environment: filesystem, env variables and network access are all implemented via external function calls the developer can control - Call functions on the host - only functions you give it access to - Run typechecking - monty supports full modern python type hints and comes with [ty](https://docs.astral.sh/ty/) included in a single binary to run typechecking +- Run many `import numpy as np` numeric array workloads with built-in NumPy support for arrays, dtypes, broadcasting, ufunc-style math, reductions, indexing, shape manipulation and common numeric helpers - Be snapshotted to bytes at external function calls, meaning you can store the interpreter state in a file or database, and resume later - Startup extremely fast (<1μs to go from code to execution result), and has runtime performance that is similar to CPython (generally between 5x faster and 5x slower) - Be called from Rust, Python, or Javascript - because Monty has no dependencies on cpython, you can use it anywhere you can run Rust @@ -41,7 +42,7 @@ What Monty **can** do: What Monty **cannot** do: - Use the rest of the standard library -- Use third party libraries (like Pydantic), support for external python library is not a goal +- Load arbitrary third party libraries (like Pydantic). NumPy support is built into Monty; it is not general package import support - define classes (support should come soon) - use match statements (again, support should come soon) @@ -66,6 +67,30 @@ In very simple terms, the idea of all the above is that LLMs can work faster, ch Monty can be called from Python, JavaScript/TypeScript or Rust. +### Built-in NumPy support + +Monty includes a built-in `numpy` module for safe numeric array workloads. This is implemented inside the sandbox, so user code can write ordinary `import numpy as np` snippets without loading CPython's C-backed NumPy package or gaining host access. + +```python +import pydantic_monty + +code = """ +import numpy as np + +scores = np.array([[1, 2, 3], [4, 5, 6]]) +weights = np.array([10, 20, 30]) +weighted = scores * weights + +(weighted.tolist(), weighted.sum(), scores.mean(), str(scores.astype('float').dtype)) +""" + +m = pydantic_monty.Monty(code) +print(m.run()) +#> ([[10, 40, 90], [40, 100, 180]], 460, 3.5, 'float64') +``` + +The supported NumPy surface is the pure, sandbox-safe numeric core: `ndarray`, supported numeric dtypes, broadcasting, ufunc-style math and predicates, reductions, indexing, shape helpers, sorting/selection helpers, formatting helpers, and common construction/manipulation routines. Monty intentionally does not expose host-boundary or external-memory APIs such as file loaders/savers, `memmap`, DLPack, `ctypeslib`, or include-path discovery, and it does not aim to implement full submodule families such as `linalg`, `fft`, `random`, `ma`, `testing`, or `typing`. + ### Python To install: diff --git a/crates/monty-datatest/src/main.rs b/crates/monty-datatest/src/main.rs index e5e13f1c4..1ed4f6640 100644 --- a/crates/monty-datatest/src/main.rs +++ b/crates/monty-datatest/src/main.rs @@ -86,6 +86,9 @@ struct TestConfig { /// When true, wrap code in async context for CPython execution. /// Used for tests with top-level await which Monty supports but CPython doesn't. async_mode: bool, + /// When true, skip running this test on CPython. + /// Used for tests that require modules not available in the CPython test harness. + skip_cpython: bool, /// When true, create a temporary directory with a known structure and mount it. /// For Monty: mounted at `/mnt` with `OverlayMemory` mode. /// For CPython: passed as real path. `root` variable injected into both. @@ -172,14 +175,18 @@ fn parse_fixture(content: &str) -> (String, Expectation, TestConfig) { .map(|line| line.trim_start_matches('#').trim()) .collect::>(); - let mount_fs = comment_lines.iter().any(|line| line.starts_with("mount-fs")); + let has_marker = |marker| { + comment_lines + .iter() + .any(|line| line.split_whitespace().next() == Some(marker)) + }; + let mount_fs = has_marker("mount-fs"); let mut config = TestConfig { - iter_mode: comment_lines.iter().any(|line| line.starts_with("call-external")) || mount_fs, - async_mode: comment_lines.iter().any(|line| line.starts_with("run-async")), + iter_mode: has_marker("call-external") || mount_fs, + async_mode: has_marker("run-async"), + skip_cpython: has_marker("skip-cpython"), mount_fs, - skip_cpython_windows: comment_lines - .iter() - .any(|line| line.starts_with("skip-cpython-windows")), + skip_cpython_windows: has_marker("skip-cpython-windows"), ..Default::default() }; // Check for "xfail=" directive @@ -2372,6 +2379,10 @@ fn run_test_cases_cpython(path: &Path) -> Result<(), Box> { .display() .to_string(); + if config.skip_cpython { + return Ok(()); + } + // Skip CPython tests that rely on POSIX path semantics when running on Windows if cfg!(windows) && config.skip_cpython_windows { return Ok(()); diff --git a/crates/monty-js/README.md b/crates/monty-js/README.md index 5681c936b..1365ad72d 100644 --- a/crates/monty-js/README.md +++ b/crates/monty-js/README.md @@ -25,6 +25,26 @@ const m = new Monty('x + y', { inputs: ['x', 'y'] }) const result = m.run({ inputs: { x: 10, y: 20 } }) // returns 30 ``` +## Built-in NumPy Support + +Monty includes a built-in `numpy` module for pure numeric array workloads inside the sandbox. User code can import NumPy-style APIs without loading CPython's C-backed NumPy package or exposing host files, external memory, network access, or other host-boundary features. + +```ts +const m = new Monty(` +import numpy as np + +scores = np.array([[1, 2, 3], [4, 5, 6]]) +weights = np.array([10, 20, 30]) +weighted = scores * weights + +weighted.tolist() +`) + +const result = m.run() // returns [[10, 40, 90], [40, 100, 180]] +``` + +Supported NumPy behavior focuses on Monty's sandbox-safe numeric core: `ndarray`, supported numeric dtypes, broadcasting, ufunc-style math and predicates, reductions, indexing, shape helpers, sorting/selection helpers, formatting helpers, and common construction/manipulation routines. Host-boundary APIs, external-memory APIs, object/string/complex/datetime arrays, and full submodule families such as `linalg`, `fft`, `random`, `ma`, `testing`, and `typing` are intentionally outside this built-in subset. + ## External Functions For synchronous external functions, pass them directly to `run()`: diff --git a/crates/monty-python/README.md b/crates/monty-python/README.md index 8717c8890..8d2168f75 100644 --- a/crates/monty-python/README.md +++ b/crates/monty-python/README.md @@ -36,6 +36,30 @@ print(m.run(inputs={'x': 10, 'y': 5})) #> 50 ``` +### Built-in NumPy Support + +Monty includes a built-in `numpy` module for pure numeric array workloads inside the sandbox. This lets user code import NumPy-style APIs without loading CPython's C-backed NumPy package or exposing host files, external memory, network access, or other host-boundary features. + +```python +import pydantic_monty + +code = """ +import numpy as np + +scores = np.array([[1, 2, 3], [4, 5, 6]]) +weights = np.array([10, 20, 30]) +weighted = scores * weights + +(weighted.tolist(), weighted.sum(), scores.mean(), str(scores.astype('float').dtype)) +""" + +m = pydantic_monty.Monty(code) +print(m.run()) +#> ([[10, 40, 90], [40, 100, 180]], 460, 3.5, 'float64') +``` + +Supported NumPy behavior focuses on Monty's sandbox-safe numeric core: `ndarray`, supported numeric dtypes, broadcasting, ufunc-style math and predicates, reductions, indexing, shape helpers, sorting/selection helpers, formatting helpers, and common construction/manipulation routines. Host-boundary APIs, external-memory APIs, object/string/complex/datetime arrays, and full submodule families such as `linalg`, `fft`, `random`, `ma`, `testing`, and `typing` are intentionally outside this built-in subset. + ### Resource Limits ```python diff --git a/crates/monty/src/bytecode/vm/binary.rs b/crates/monty/src/bytecode/vm/binary.rs index 249d5c231..5dd8b4bc3 100644 --- a/crates/monty/src/bytecode/vm/binary.rs +++ b/crates/monty/src/bytecode/vm/binary.rs @@ -3,10 +3,10 @@ use super::VM; use crate::{ defer_drop, - exception_private::{ExcType, RunError}, - heap::{HeapData, HeapGuard, HeapReadOutput}, + exception_private::{ExcType, RunError, RunResult}, + heap::{Heap, HeapData, HeapGuard, HeapReadOutput}, resource::ResourceTracker, - types::{PyTrait, Set, dict_view::collect_iterable_to_set, set::SetBinaryOp}, + types::{NdArray, PyTrait, Set, dict_view::collect_iterable_to_set, ndarray::NdArrayDtype, set::SetBinaryOp}, value::{BitwiseOp, Value}, }; @@ -23,6 +23,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path: intercept before general dispatch + if let Some(result) = try_ndarray_binary(lhs, rhs, NdArrayBinaryOp::Add, this)? { + this.push(result); + return Ok(()); + } + match lhs.py_add(rhs, this) { Ok(Some(v)) => { this.push(v); @@ -51,6 +57,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path + if let Some(result) = try_ndarray_binary(lhs, rhs, NdArrayBinaryOp::Sub, this)? { + this.push(result); + return Ok(()); + } + if let Some(result) = this.binary_dict_view_op(lhs, rhs, DictViewBinaryOp::Sub)? { this.push(result); return Ok(()); @@ -86,6 +98,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path + if let Some(result) = try_ndarray_binary(lhs, rhs, NdArrayBinaryOp::Mul, this)? { + this.push(result); + return Ok(()); + } + match lhs.py_mult(rhs, this) { Ok(Some(v)) => { this.push(v); @@ -111,6 +129,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path + if let Some(result) = try_ndarray_binary(lhs, rhs, NdArrayBinaryOp::Div, this)? { + this.push(result); + return Ok(()); + } + match lhs.py_div(rhs, this) { Ok(Some(v)) => { this.push(v); @@ -136,6 +160,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path + if let Some(result) = try_ndarray_binary(lhs, rhs, NdArrayBinaryOp::FloorDiv, this)? { + this.push(result); + return Ok(()); + } + match lhs.py_floordiv(rhs, this) { Ok(Some(v)) => { this.push(v); @@ -161,6 +191,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path + if let Some(result) = try_ndarray_binary(lhs, rhs, NdArrayBinaryOp::Mod, this)? { + this.push(result); + return Ok(()); + } + match lhs.py_mod(rhs, this) { Ok(Some(v)) => { this.push(v); @@ -187,6 +223,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path + if let Some(result) = try_ndarray_binary(lhs, rhs, NdArrayBinaryOp::Pow, this)? { + this.push(result); + return Ok(()); + } + match lhs.py_pow(rhs, this) { Ok(Some(v)) => { this.push(v); @@ -248,6 +290,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path + if let Some(result) = try_ndarray_bitwise(lhs, rhs, NdArrayBitwiseOp::And, this)? { + this.push(result); + return Ok(()); + } + if let Some(result) = this.binary_dict_view_op(lhs, rhs, DictViewBinaryOp::And)? { this.push(result); return Ok(()); @@ -272,6 +320,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path + if let Some(result) = try_ndarray_bitwise(lhs, rhs, NdArrayBitwiseOp::Or, this)? { + this.push(result); + return Ok(()); + } + if let Some(result) = this.binary_dict_view_op(lhs, rhs, DictViewBinaryOp::Or)? { this.push(result); return Ok(()); @@ -296,6 +350,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path + if let Some(result) = try_ndarray_bitwise(lhs, rhs, NdArrayBitwiseOp::Xor, this)? { + this.push(result); + return Ok(()); + } + if let Some(result) = this.binary_dict_view_op(lhs, rhs, DictViewBinaryOp::Xor)? { this.push(result); return Ok(()); @@ -314,6 +374,7 @@ impl VM<'_, T> { /// In-place addition (uses py_iadd for mutable containers, falls back to py_add). /// /// For mutable types like lists, `py_iadd` mutates in place and returns true. + /// For ndarray, modifies data in place and returns the same array reference. /// For immutable types, we fall back to regular addition. /// /// Uses lazy type capture: only calls `py_type()` in error paths. @@ -321,6 +382,11 @@ impl VM<'_, T> { /// Note: Cannot use `defer_drop!` for `lhs` here because on successful in-place /// operation, we need to push `lhs` back onto the stack rather than drop it. pub(super) fn inplace_add(&mut self) -> Result<(), RunError> { + // NdArray in-place fast path — try before popping so we can fall through + if try_ndarray_inplace(self, NdArrayInplaceOp::Add)? { + return Ok(()); + } + let this = self; let rhs = this.pop(); @@ -350,14 +416,81 @@ impl VM<'_, T> { /// Binary matrix multiplication (`@` operator). /// - /// Currently not implemented - returns a `NotImplementedError`. - /// Matrix multiplication requires numpy-like array types which Monty doesn't support. + /// Dispatches to `NdArray::matmul` for ndarray operands: + /// - **1D @ 1D**: dot product (scalar result) + /// - **2D @ 2D**: matrix multiplication + /// - **2D @ 1D** / **1D @ 2D**: matrix-vector / vector-matrix product pub(super) fn binary_matmul(&mut self) -> Result<(), RunError> { - let rhs = self.pop(); - let lhs = self.pop(); - lhs.drop_with_heap(self); - rhs.drop_with_heap(self); - Err(ExcType::not_implemented("matrix multiplication (@) is not supported").into()) + let this = self; + + let rhs = this.pop(); + defer_drop!(rhs, this); + let lhs = this.pop(); + defer_drop!(lhs, this); + + // Both operands must be NdArray + let (Some(Value::Ref(lid)), Some(Value::Ref(rid))) = (Some(lhs), Some(rhs)) else { + return Err(ExcType::type_error("matmul requires ndarray operands")); + }; + + let HeapData::NdArray(l) = this.heap.get(*lid) else { + return Err(ExcType::type_error("matmul requires ndarray operands")); + }; + let HeapData::NdArray(r) = this.heap.get(*rid) else { + return Err(ExcType::type_error("matmul requires ndarray operands")); + }; + + let result = l.matmul(r, this.heap)?; + this.push(result); + Ok(()) + } + + /// In-place subtraction for ndarray. Falls back to binary subtraction for other types. + pub(super) fn inplace_sub(&mut self) -> Result<(), RunError> { + if try_ndarray_inplace(self, NdArrayInplaceOp::Sub)? { + return Ok(()); + } + self.binary_sub() + } + + /// In-place multiplication for ndarray. Falls back to binary multiplication for other types. + pub(super) fn inplace_mul(&mut self) -> Result<(), RunError> { + if try_ndarray_inplace(self, NdArrayInplaceOp::Mul)? { + return Ok(()); + } + self.binary_mult() + } + + /// In-place division for ndarray. Falls back to binary division for other types. + pub(super) fn inplace_div(&mut self) -> Result<(), RunError> { + if try_ndarray_inplace(self, NdArrayInplaceOp::Div)? { + return Ok(()); + } + self.binary_div() + } + + /// In-place floor division for ndarray. Falls back to binary floor division for other types. + pub(super) fn inplace_floordiv(&mut self) -> Result<(), RunError> { + if try_ndarray_inplace(self, NdArrayInplaceOp::FloorDiv)? { + return Ok(()); + } + self.binary_floordiv() + } + + /// In-place modulo for ndarray. Falls back to binary modulo for other types. + pub(super) fn inplace_mod(&mut self) -> Result<(), RunError> { + if try_ndarray_inplace(self, NdArrayInplaceOp::Mod)? { + return Ok(()); + } + self.binary_mod() + } + + /// In-place power for ndarray. Falls back to binary power for other types. + pub(super) fn inplace_pow(&mut self) -> Result<(), RunError> { + if try_ndarray_inplace(self, NdArrayInplaceOp::Pow)? { + return Ok(()); + } + self.binary_pow() } /// Implements dict-view set-like operators before falling back to other dispatch. @@ -479,3 +612,500 @@ fn apply_dict_view_binary_op( Ok(result) } + +/// Supported ndarray element-wise binary operations. +#[derive(Debug, Clone, Copy)] +enum NdArrayBinaryOp { + Add, + Sub, + Mul, + Div, + FloorDiv, + Mod, + Pow, +} + +/// Extracts a scalar f64 from a `Value`, if it is a numeric type. +/// +/// Returns `(f64_value, is_float)` — the `is_float` flag indicates whether the Python +/// value was a `float` (as opposed to `int` or `bool`), which is needed for correct +/// dtype promotion in ndarray operations. +fn value_to_f64(v: &Value) -> Option<(f64, bool)> { + match v { + Value::Int(i) => Some((*i as f64, false)), + Value::Float(f) => Some((*f, true)), + Value::Bool(b) => Some((if *b { 1.0 } else { 0.0 }, false)), + _ => None, + } +} + +/// Dispatches an element-wise binary operation between an `NdArray` and a scalar. +fn ndarray_scalar_op( + arr: &NdArray, + scalar: f64, + scalar_is_float: bool, + op: NdArrayBinaryOp, + scalar_on_left: bool, + heap: &Heap, +) -> RunResult { + match (op, scalar_on_left) { + // Commutative operations — direction doesn't matter + (NdArrayBinaryOp::Add, _) => arr.add_scalar(scalar, scalar_is_float, heap), + (NdArrayBinaryOp::Mul, _) => arr.mul_scalar(scalar, scalar_is_float, heap), + // Non-commutative: scalar on right (arr op scalar) + (NdArrayBinaryOp::Sub, false) => arr.sub_scalar(scalar, scalar_is_float, heap), + (NdArrayBinaryOp::Div, false) => arr.div_scalar(scalar, heap), + (NdArrayBinaryOp::FloorDiv, false) => arr.floordiv_scalar(scalar, scalar_is_float, heap), + (NdArrayBinaryOp::Mod, false) => arr.modulo_scalar(scalar, scalar_is_float, heap), + (NdArrayBinaryOp::Pow, false) => arr.pow_scalar(scalar, scalar_is_float, heap), + // Non-commutative: scalar on left (scalar op arr) + (NdArrayBinaryOp::Sub, true) => arr.rsub_scalar(scalar, scalar_is_float, heap), + (NdArrayBinaryOp::Div, true) => arr.rdiv_scalar(scalar, heap), + (NdArrayBinaryOp::FloorDiv, true) => arr.rfloordiv_scalar(scalar, scalar_is_float, heap), + (NdArrayBinaryOp::Mod, true) => arr.rmod_scalar(scalar, scalar_is_float, heap), + (NdArrayBinaryOp::Pow, true) => arr.rpow_scalar(scalar, scalar_is_float, heap), + } +} + +/// Dispatches an element-wise binary operation between two `NdArray`s. +fn ndarray_array_op( + lhs: &NdArray, + rhs: &NdArray, + op: NdArrayBinaryOp, + heap: &Heap, +) -> RunResult { + match op { + NdArrayBinaryOp::Add => lhs.add(rhs, heap), + NdArrayBinaryOp::Sub => lhs.sub(rhs, heap), + NdArrayBinaryOp::Mul => lhs.mul(rhs, heap), + NdArrayBinaryOp::Div => lhs.div(rhs, heap), + NdArrayBinaryOp::FloorDiv => lhs.floordiv(rhs, heap), + NdArrayBinaryOp::Mod => lhs.modulo(rhs, heap), + NdArrayBinaryOp::Pow => lhs.pow(rhs, heap), + } +} + +/// Tries to dispatch an ndarray binary operation. +/// +/// Returns `Ok(Some(value))` if either operand is an ndarray and the operation succeeded, +/// `Ok(None)` if neither operand is an ndarray (caller should fall through to normal dispatch), +/// or `Err` if the operation failed (e.g., shape mismatch). +fn try_ndarray_binary( + lhs: &Value, + rhs: &Value, + op: NdArrayBinaryOp, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult> { + // Both operands must involve at least one ndarray + let lhs_id = if let Value::Ref(id) = lhs { Some(*id) } else { None }; + let rhs_id = if let Value::Ref(id) = rhs { Some(*id) } else { None }; + + // Case 1: NdArray op NdArray + if let (Some(lid), Some(rid)) = (lhs_id, rhs_id) { + let lhs_is_ndarray = matches!(vm.heap.get(lid), HeapData::NdArray(_)); + let rhs_is_ndarray = matches!(vm.heap.get(rid), HeapData::NdArray(_)); + if lhs_is_ndarray && rhs_is_ndarray { + let HeapData::NdArray(l) = vm.heap.get(lid) else { + unreachable!() + }; + let HeapData::NdArray(r) = vm.heap.get(rid) else { + unreachable!() + }; + return ndarray_array_op(l, r, op, vm.heap).map(Some); + } + } + + // Case 2: NdArray op scalar + if let Some(lid) = lhs_id + && let HeapData::NdArray(arr) = vm.heap.get(lid) + && let Some((scalar, is_float)) = value_to_f64(rhs) + { + return ndarray_scalar_op(arr, scalar, is_float, op, false, vm.heap).map(Some); + } + + // Case 3: scalar op NdArray + if let Some(rid) = rhs_id + && let HeapData::NdArray(arr) = vm.heap.get(rid) + && let Some((scalar, is_float)) = value_to_f64(lhs) + { + return ndarray_scalar_op(arr, scalar, is_float, op, true, vm.heap).map(Some); + } + + Ok(None) +} + +/// Supported ndarray element-wise bitwise operations. +#[derive(Debug, Clone, Copy)] +enum NdArrayBitwiseOp { + And, + Or, + Xor, +} + +/// Extracts an integer value from a `Value` for bitwise ndarray operations. +/// +/// Returns `Some(i64)` for `Int`, `Bool` (True=1, False=0). Returns `None` for non-integer types. +fn value_to_i64(v: &Value) -> Option { + match v { + Value::Int(i) => Some(*i), + Value::Bool(b) => Some(i64::from(*b)), + _ => None, + } +} + +/// Tries to dispatch an ndarray bitwise operation. +/// +/// Returns `Ok(Some(value))` if either operand is an ndarray and the operation succeeded, +/// `Ok(None)` if neither operand is an ndarray (caller should fall through to normal dispatch), +/// or `Err` if the operation failed (e.g., float dtype, shape mismatch). +fn try_ndarray_bitwise( + lhs: &Value, + rhs: &Value, + op: NdArrayBitwiseOp, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult> { + let lhs_id = if let Value::Ref(id) = lhs { Some(*id) } else { None }; + let rhs_id = if let Value::Ref(id) = rhs { Some(*id) } else { None }; + + // Case 1: NdArray op NdArray + if let (Some(lid), Some(rid)) = (lhs_id, rhs_id) { + let lhs_is_ndarray = matches!(vm.heap.get(lid), HeapData::NdArray(_)); + let rhs_is_ndarray = matches!(vm.heap.get(rid), HeapData::NdArray(_)); + if lhs_is_ndarray && rhs_is_ndarray { + let HeapData::NdArray(l) = vm.heap.get(lid) else { + unreachable!() + }; + let HeapData::NdArray(r) = vm.heap.get(rid) else { + unreachable!() + }; + let result = match op { + NdArrayBitwiseOp::And => l.bitand(r, vm.heap), + NdArrayBitwiseOp::Or => l.bitor(r, vm.heap), + NdArrayBitwiseOp::Xor => l.bitxor(r, vm.heap), + }; + return result.map(Some); + } + } + + // Case 2: NdArray op scalar + if let Some(lid) = lhs_id + && let HeapData::NdArray(arr) = vm.heap.get(lid) + && let Some(scalar) = value_to_i64(rhs) + { + let result = match op { + NdArrayBitwiseOp::And => arr.bitand_scalar(scalar, vm.heap), + NdArrayBitwiseOp::Or => arr.bitor_scalar(scalar, vm.heap), + NdArrayBitwiseOp::Xor => arr.bitxor_scalar(scalar, vm.heap), + }; + return result.map(Some); + } + + // Case 3: scalar op NdArray (commutative) + if let Some(rid) = rhs_id + && let HeapData::NdArray(arr) = vm.heap.get(rid) + && let Some(scalar) = value_to_i64(lhs) + { + let result = match op { + NdArrayBitwiseOp::And => arr.bitand_scalar(scalar, vm.heap), + NdArrayBitwiseOp::Or => arr.bitor_scalar(scalar, vm.heap), + NdArrayBitwiseOp::Xor => arr.bitxor_scalar(scalar, vm.heap), + }; + return result.map(Some); + } + + Ok(None) +} + +/// Supported ndarray in-place operations. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum NdArrayInplaceOp { + Add, + Sub, + Mul, + Div, + FloorDiv, + Mod, + Pow, +} + +/// Tries to perform an in-place ndarray operation by peeking at the stack. +/// +/// If the second-from-top value is an ndarray and the top value is a compatible +/// operand (scalar or same-shape ndarray), modifies the array data in-place, +/// drops the rhs, and leaves the lhs on the stack. Returns `Ok(true)` on success. +/// +/// Returns `Ok(false)` if neither operand is an ndarray, leaving the stack unchanged +/// so the caller can fall through to the normal binary operation. +fn try_ndarray_inplace(vm: &mut VM<'_, impl ResourceTracker>, op: NdArrayInplaceOp) -> RunResult { + // Peek at the two top-of-stack values without popping + let lhs_ref = vm.peek_at(1); + + // Only proceed if lhs is an NdArray + let Value::Ref(lid) = lhs_ref else { + return Ok(false); + }; + if !matches!(vm.heap.get(*lid), HeapData::NdArray(_)) { + return Ok(false); + } + + let lid = *lid; + + // Pop rhs into a HeapGuard so we can reclaim it on the fallback path. + let mut rhs_guard = HeapGuard::new(vm.pop(), vm); + let (rhs, vm_ref) = rhs_guard.as_parts_mut(); + + // Try scalar rhs + if let Some((scalar, scalar_is_float)) = value_to_f64(rhs) { + let lhs_dtype = { + let HeapData::NdArray(l) = vm_ref.heap.get(lid) else { + unreachable!() + }; + l.dtype() + }; + let rhs_dtype = if scalar_is_float { + NdArrayDtype::Float64 + } else { + NdArrayDtype::Int64 + }; + validate_inplace_cast(lhs_dtype, rhs_dtype, &[scalar], op)?; + + // Consume the guard — scalar types (Int/Float/Bool) don't need heap cleanup + let (rhs, vm) = rhs_guard.into_parts(); + rhs.drop_with_heap(vm); + let lhs = vm.pop(); + // Use HeapRead API for mutation + let HeapReadOutput::NdArray(mut arr_read) = vm.heap.read(lid) else { + unreachable!() + }; + apply_inplace_scalar(arr_read.get_mut(vm.heap), scalar, scalar_is_float, op); + drop(arr_read); + vm.push(lhs); + return Ok(true); + } + + // Try ndarray rhs + if let Value::Ref(rid) = rhs { + let rid = *rid; + let shape_matches = { + let HeapData::NdArray(r) = vm_ref.heap.get(rid) else { + // rhs is a Ref but not an NdArray — put it back + let (rhs, vm) = rhs_guard.into_parts(); + vm.push(rhs); + return Ok(false); + }; + let HeapData::NdArray(l) = vm_ref.heap.get(lid) else { + unreachable!() + }; + r.shape() == l.shape() + }; + if shape_matches { + let (rhs_data, rhs_dtype): (Vec, NdArrayDtype) = { + let HeapData::NdArray(r) = vm_ref.heap.get(rid) else { + unreachable!() + }; + (r.data().to_vec(), r.dtype()) + }; + let lhs_dtype = { + let HeapData::NdArray(l) = vm_ref.heap.get(lid) else { + unreachable!() + }; + l.dtype() + }; + validate_inplace_cast(lhs_dtype, rhs_dtype, &rhs_data, op)?; + + let (rhs, vm) = rhs_guard.into_parts(); + rhs.drop_with_heap(vm); + let lhs = vm.pop(); + let HeapReadOutput::NdArray(mut arr_read) = vm.heap.read(lid) else { + unreachable!() + }; + apply_inplace_array(arr_read.get_mut(vm.heap), &rhs_data, rhs_dtype, op); + drop(arr_read); + vm.push(lhs); + return Ok(true); + } + } + + // Not a compatible ndarray operation — put rhs back on the stack + let (rhs, vm) = rhs_guard.into_parts(); + vm.push(rhs); + Ok(false) +} + +/// Applies an in-place scalar operation to an ndarray's data. +fn apply_inplace_scalar(arr: &mut NdArray, scalar: f64, scalar_is_float: bool, op: NdArrayInplaceOp) { + arr.dtype = inplace_scalar_result_dtype(arr.dtype, scalar_is_float, op); + match op { + NdArrayInplaceOp::Add => { + for v in &mut arr.data { + *v += scalar; + } + } + NdArrayInplaceOp::Sub => { + for v in &mut arr.data { + *v -= scalar; + } + } + NdArrayInplaceOp::Mul => { + for v in &mut arr.data { + *v *= scalar; + } + } + NdArrayInplaceOp::Div => { + for v in &mut arr.data { + *v /= scalar; + } + } + NdArrayInplaceOp::FloorDiv => { + for v in &mut arr.data { + *v = (*v / scalar).floor(); + } + } + NdArrayInplaceOp::Mod => { + for v in &mut arr.data { + *v %= scalar; + } + } + NdArrayInplaceOp::Pow => { + for v in &mut arr.data { + *v = v.powf(scalar); + } + } + } +} + +/// Applies an in-place array-to-array operation to an ndarray's data. +fn apply_inplace_array(arr: &mut NdArray, rhs_data: &[f64], rhs_dtype: NdArrayDtype, op: NdArrayInplaceOp) { + arr.dtype = inplace_array_result_dtype(arr.dtype, rhs_dtype, op); + match op { + NdArrayInplaceOp::Add => { + for (v, rv) in arr.data.iter_mut().zip(rhs_data.iter()) { + *v += rv; + } + } + NdArrayInplaceOp::Sub => { + for (v, rv) in arr.data.iter_mut().zip(rhs_data.iter()) { + *v -= rv; + } + } + NdArrayInplaceOp::Mul => { + for (v, rv) in arr.data.iter_mut().zip(rhs_data.iter()) { + *v *= rv; + } + } + NdArrayInplaceOp::Div => { + for (v, rv) in arr.data.iter_mut().zip(rhs_data.iter()) { + *v /= rv; + } + } + NdArrayInplaceOp::FloorDiv => { + for (v, rv) in arr.data.iter_mut().zip(rhs_data.iter()) { + *v = (*v / rv).floor(); + } + } + NdArrayInplaceOp::Mod => { + for (v, rv) in arr.data.iter_mut().zip(rhs_data.iter()) { + *v %= rv; + } + } + NdArrayInplaceOp::Pow => { + for (v, rv) in arr.data.iter_mut().zip(rhs_data.iter()) { + *v = v.powf(*rv); + } + } + } +} + +/// Validates NumPy-style casting rules before mutating an ndarray in place. +/// +/// Unlike regular binary ndarray operations, NumPy in-place ufuncs do not +/// silently promote the left-hand array dtype. If the ufunc output cannot be +/// cast back to the existing dtype using NumPy's same-kind rule, Monty raises +/// before changing either the data or dtype. +fn validate_inplace_cast( + lhs_dtype: NdArrayDtype, + rhs_dtype: NdArrayDtype, + rhs_data: &[f64], + op: NdArrayInplaceOp, +) -> RunResult<()> { + if lhs_dtype != NdArrayDtype::Float64 + && op == NdArrayInplaceOp::Pow + && rhs_dtype != NdArrayDtype::Float64 + && rhs_data.iter().any(|&value| value < 0.0) + { + return Err(ExcType::value_error( + "Integers to negative integer powers are not allowed.", + )); + } + + let output_dtype = inplace_array_result_dtype(lhs_dtype, rhs_dtype, op); + if output_dtype == lhs_dtype { + Ok(()) + } else { + Err(ExcType::type_error(format!( + "Cannot cast ufunc '{}' output from dtype('{}') to dtype('{}') with casting rule 'same_kind'", + op.ufunc_name(), + output_dtype, + lhs_dtype + ))) + } +} + +/// Determines the dtype after a mutating scalar ndarray operation. +/// +/// Monty's ndarray storage is f64-backed, so in-place operations that can produce +/// fractional values must update the dtype instead of leaving fractional storage +/// behind an integer or boolean display contract. +fn inplace_scalar_result_dtype(lhs_dtype: NdArrayDtype, scalar_is_float: bool, op: NdArrayInplaceOp) -> NdArrayDtype { + match op { + NdArrayInplaceOp::Div => NdArrayDtype::Float64, + NdArrayInplaceOp::Add + | NdArrayInplaceOp::Sub + | NdArrayInplaceOp::Mul + | NdArrayInplaceOp::FloorDiv + | NdArrayInplaceOp::Mod + | NdArrayInplaceOp::Pow => { + if lhs_dtype == NdArrayDtype::Float64 || scalar_is_float { + NdArrayDtype::Float64 + } else if lhs_dtype == NdArrayDtype::Bool { + NdArrayDtype::Int64 + } else { + lhs_dtype + } + } + } +} + +/// Determines the dtype after a mutating array-to-array ndarray operation. +fn inplace_array_result_dtype(lhs_dtype: NdArrayDtype, rhs_dtype: NdArrayDtype, op: NdArrayInplaceOp) -> NdArrayDtype { + match op { + NdArrayInplaceOp::Div => NdArrayDtype::Float64, + NdArrayInplaceOp::Add + | NdArrayInplaceOp::Sub + | NdArrayInplaceOp::Mul + | NdArrayInplaceOp::FloorDiv + | NdArrayInplaceOp::Mod + | NdArrayInplaceOp::Pow => { + if lhs_dtype == NdArrayDtype::Float64 || rhs_dtype == NdArrayDtype::Float64 { + NdArrayDtype::Float64 + } else { + NdArrayDtype::Int64 + } + } + } +} + +impl NdArrayInplaceOp { + /// Returns the NumPy ufunc name used in casting errors for this in-place operation. + fn ufunc_name(self) -> &'static str { + match self { + Self::Add => "add", + Self::Sub => "subtract", + Self::Mul => "multiply", + Self::Div => "divide", + Self::FloorDiv => "floor_divide", + Self::Mod => "remainder", + Self::Pow => "power", + } + } +} diff --git a/crates/monty/src/bytecode/vm/compare.rs b/crates/monty/src/bytecode/vm/compare.rs index 4d8dabb41..2dfb609b0 100644 --- a/crates/monty/src/bytecode/vm/compare.rs +++ b/crates/monty/src/bytecode/vm/compare.rs @@ -5,9 +5,10 @@ use std::cmp::Ordering; use super::VM; use crate::{ defer_drop, - exception_private::{ExcType, RunError}, + exception_private::{ExcType, RunError, RunResult}, + heap::{Heap, HeapData}, resource::ResourceTracker, - types::{LongInt, PyTrait}, + types::{LongInt, NdArray, PyTrait}, value::Value, }; @@ -21,6 +22,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path: element-wise comparison returning boolean array + if let Some(result) = try_ndarray_cmp(lhs, rhs, NdArrayCmpOp::Eq, this)? { + this.push(result); + return Ok(()); + } + let result = lhs.py_eq(rhs, this)?; this.push(Value::Bool(result)); Ok(()) @@ -35,6 +42,12 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path: element-wise comparison returning boolean array + if let Some(result) = try_ndarray_cmp(lhs, rhs, NdArrayCmpOp::Ne, this)? { + this.push(result); + return Ok(()); + } + let result = !lhs.py_eq(rhs, this)?; this.push(Value::Bool(result)); Ok(()) @@ -43,7 +56,7 @@ impl VM<'_, T> { /// Ordering comparison with a predicate. pub(super) fn compare_ord(&mut self, check: F) -> Result<(), RunError> where - F: FnOnce(Ordering) -> bool, + F: Fn(Ordering) -> bool, { let this = self; @@ -52,6 +65,15 @@ impl VM<'_, T> { let lhs = this.pop(); defer_drop!(lhs, this); + // NdArray fast path: detect ordering comparisons involving ndarrays. + // We need to determine the specific comparison from the check predicate by testing it. + if let Some(ndarray_op) = ndarray_cmp_from_ord_check(&check) + && let Some(result) = try_ndarray_cmp(lhs, rhs, ndarray_op, this)? + { + this.push(result); + return Ok(()); + } + let result = lhs.py_cmp(rhs, this)?.is_some_and(check); this.push(Value::Bool(result)); Ok(()) @@ -147,3 +169,133 @@ impl VM<'_, T> { } } } + +/// Supported ndarray element-wise comparison operations. +#[derive(Debug, Clone, Copy)] +enum NdArrayCmpOp { + Eq, + Ne, + Gt, + Lt, + Gte, + Lte, +} + +/// Extracts a scalar f64 from a `Value`, if it is a numeric type. +/// +/// Comparisons always return `Bool` dtype so the float flag is unused, but we match +/// the return type from `binary.rs` for consistency. +fn value_to_f64(v: &Value) -> Option<(f64, bool)> { + match v { + Value::Int(i) => Some((*i as f64, false)), + Value::Float(f) => Some((*f, true)), + Value::Bool(b) => Some((if *b { 1.0 } else { 0.0 }, false)), + _ => None, + } +} + +/// Determines the ndarray comparison op from a `compare_ord` predicate. +/// +/// Tests the predicate with `Less`, `Equal`, and `Greater` to infer which +/// comparison is being performed (e.g., `<`, `<=`, `>`, `>=`). +fn ndarray_cmp_from_ord_check(check: &impl Fn(Ordering) -> bool) -> Option { + let lt = check(Ordering::Less); + let eq = check(Ordering::Equal); + let gt = check(Ordering::Greater); + match (lt, eq, gt) { + (true, false, false) => Some(NdArrayCmpOp::Lt), + (true, true, false) => Some(NdArrayCmpOp::Lte), + (false, false, true) => Some(NdArrayCmpOp::Gt), + (false, true, true) => Some(NdArrayCmpOp::Gte), + _ => None, + } +} + +/// Dispatches an ndarray comparison between an ndarray and a scalar. +fn ndarray_scalar_cmp( + arr: &NdArray, + scalar: f64, + op: NdArrayCmpOp, + heap: &Heap, +) -> RunResult { + match op { + NdArrayCmpOp::Gt => arr.gt_scalar(scalar, heap), + NdArrayCmpOp::Lt => arr.lt_scalar(scalar, heap), + NdArrayCmpOp::Eq => arr.eq_scalar(scalar, heap), + NdArrayCmpOp::Gte => arr.gte_scalar(scalar, heap), + NdArrayCmpOp::Lte => arr.lte_scalar(scalar, heap), + NdArrayCmpOp::Ne => arr.ne_scalar(scalar, heap), + } +} + +/// Dispatches an ndarray comparison between two ndarrays. +fn ndarray_array_cmp( + lhs: &NdArray, + rhs: &NdArray, + op: NdArrayCmpOp, + heap: &Heap, +) -> RunResult { + match op { + NdArrayCmpOp::Gt => lhs.gt(rhs, heap), + NdArrayCmpOp::Lt => lhs.lt(rhs, heap), + NdArrayCmpOp::Eq => lhs.eq_array(rhs, heap), + NdArrayCmpOp::Gte => lhs.gte(rhs, heap), + NdArrayCmpOp::Lte => lhs.lte(rhs, heap), + NdArrayCmpOp::Ne => lhs.ne_array(rhs, heap), + } +} + +/// Tries to dispatch an ndarray comparison operation. +/// +/// Returns `Ok(Some(value))` if either operand is an ndarray, `Ok(None)` if neither is. +fn try_ndarray_cmp( + lhs: &Value, + rhs: &Value, + op: NdArrayCmpOp, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult> { + let lhs_id = if let Value::Ref(id) = lhs { Some(*id) } else { None }; + let rhs_id = if let Value::Ref(id) = rhs { Some(*id) } else { None }; + + // Case 1: NdArray cmp NdArray + if let (Some(lid), Some(rid)) = (lhs_id, rhs_id) { + let lhs_is_ndarray = matches!(vm.heap.get(lid), HeapData::NdArray(_)); + let rhs_is_ndarray = matches!(vm.heap.get(rid), HeapData::NdArray(_)); + if lhs_is_ndarray && rhs_is_ndarray { + let HeapData::NdArray(l) = vm.heap.get(lid) else { + unreachable!() + }; + let HeapData::NdArray(r) = vm.heap.get(rid) else { + unreachable!() + }; + return ndarray_array_cmp(l, r, op, vm.heap).map(Some); + } + } + + // Case 2: NdArray cmp scalar + if let Some(lid) = lhs_id + && let HeapData::NdArray(arr) = vm.heap.get(lid) + && let Some((scalar, _)) = value_to_f64(rhs) + { + return ndarray_scalar_cmp(arr, scalar, op, vm.heap).map(Some); + } + + // Case 3: scalar cmp NdArray (reverse the comparison) + if let Some(rid) = rhs_id + && let HeapData::NdArray(arr) = vm.heap.get(rid) + && let Some((scalar, _)) = value_to_f64(lhs) + { + // Reverse: `5 > arr` becomes `arr < 5` + let reversed_op = match op { + NdArrayCmpOp::Gt => NdArrayCmpOp::Lt, + NdArrayCmpOp::Lt => NdArrayCmpOp::Gt, + NdArrayCmpOp::Gte => NdArrayCmpOp::Lte, + NdArrayCmpOp::Lte => NdArrayCmpOp::Gte, + NdArrayCmpOp::Eq => NdArrayCmpOp::Eq, + NdArrayCmpOp::Ne => NdArrayCmpOp::Ne, + }; + return ndarray_scalar_cmp(arr, scalar, reversed_op, vm.heap).map(Some); + } + + Ok(None) +} diff --git a/crates/monty/src/bytecode/vm/mod.rs b/crates/monty/src/bytecode/vm/mod.rs index 38c826143..5ac3f2b99 100644 --- a/crates/monty/src/bytecode/vm/mod.rs +++ b/crates/monty/src/bytecode/vm/mod.rs @@ -964,6 +964,16 @@ impl<'h, T: ResourceTracker> VM<'h, T> { Err(e) => catch_sync!(self, cached_frame, RunError::from(e)), } } + HeapData::NdArray(arr) => match arr.neg(self.heap) { + Ok(v) => { + value.drop_with_heap(self); + self.push(v); + } + Err(e) => { + value.drop_with_heap(self); + catch_sync!(self, cached_frame, e); + } + }, HeapData::TimeDelta(td) => { let negated = timedelta::from_total_microseconds(-timedelta::total_microseconds(td)); value.drop_with_heap(self); @@ -1018,18 +1028,31 @@ impl<'h, T: ResourceTracker> VM<'h, T> { Value::Int(n) => self.push(Value::Int(!n)), Value::Bool(b) => self.push(Value::Int(!i64::from(b))), Value::Ref(id) => { - if let HeapData::LongInt(li) = self.heap.get(id) { - // LongInt bitwise NOT: ~x = -(x + 1) - let inverted = -(li.inner() + 1i32); - value.drop_with_heap(self); - match LongInt::new(inverted).into_value(self.heap) { - Ok(v) => self.push(v), - Err(e) => catch_sync!(self, cached_frame, RunError::from(e)), + match self.heap.get(id) { + HeapData::LongInt(li) => { + // LongInt bitwise NOT: ~x = -(x + 1) + let inverted = -(li.inner() + 1i32); + value.drop_with_heap(self); + match LongInt::new(inverted).into_value(self.heap) { + Ok(v) => self.push(v), + Err(e) => catch_sync!(self, cached_frame, RunError::from(e)), + } + } + HeapData::NdArray(arr) => match arr.invert(self.heap) { + Ok(v) => { + value.drop_with_heap(self); + self.push(v); + } + Err(e) => { + value.drop_with_heap(self); + catch_sync!(self, cached_frame, e); + } + }, + _ => { + let value_type = value.py_type(self); + value.drop_with_heap(self); + catch_sync!(self, cached_frame, ExcType::unary_type_error("~", value_type)); } - } else { - let value_type = value.py_type(self); - value.drop_with_heap(self); - catch_sync!(self, cached_frame, ExcType::unary_type_error("~", value_type)); } } _ => { @@ -1041,13 +1064,13 @@ impl<'h, T: ResourceTracker> VM<'h, T> { } // In-place Operations - route through exception handling Opcode::InplaceAdd => try_catch_sync!(self, cached_frame, self.inplace_add()), - // Other in-place ops use the same logic as binary ops for now - Opcode::InplaceSub => try_catch_sync!(self, cached_frame, self.binary_sub()), - Opcode::InplaceMul => try_catch_sync!(self, cached_frame, self.binary_mult()), - Opcode::InplaceDiv => try_catch_sync!(self, cached_frame, self.binary_div()), - Opcode::InplaceFloorDiv => try_catch_sync!(self, cached_frame, self.binary_floordiv()), - Opcode::InplaceMod => try_catch_sync!(self, cached_frame, self.binary_mod()), - Opcode::InplacePow => try_catch_sync!(self, cached_frame, self.binary_pow()), + // In-place ops: ndarray modifies in-place, other types fall through to binary + Opcode::InplaceSub => try_catch_sync!(self, cached_frame, self.inplace_sub()), + Opcode::InplaceMul => try_catch_sync!(self, cached_frame, self.inplace_mul()), + Opcode::InplaceDiv => try_catch_sync!(self, cached_frame, self.inplace_div()), + Opcode::InplaceFloorDiv => try_catch_sync!(self, cached_frame, self.inplace_floordiv()), + Opcode::InplaceMod => try_catch_sync!(self, cached_frame, self.inplace_mod()), + Opcode::InplacePow => try_catch_sync!(self, cached_frame, self.inplace_pow()), Opcode::InplaceAnd => { try_catch_sync!(self, cached_frame, self.binary_bitwise(BitwiseOp::And)); } @@ -1608,6 +1631,17 @@ impl<'h, T: ResourceTracker> VM<'h, T> { self.stack.last().expect("stack underflow") } + /// Peeks at a value `offset` positions from the top of the stack. + /// + /// `peek_at(0)` is equivalent to `peek()` (top of stack). + /// `peek_at(1)` returns the second-from-top value. + #[inline] + pub(super) fn peek_at(&self, offset: usize) -> &Value { + let len = self.stack.len(); + assert!(offset < len, "stack underflow: peek_at({offset}) with stack size {len}"); + &self.stack[len - 1 - offset] + } + /// Pops n values from the stack in reverse order (first popped is last in vec). pub(super) fn pop_n(&mut self, n: usize) -> Vec { let start = self.stack.len() - n; diff --git a/crates/monty/src/heap.rs b/crates/monty/src/heap.rs index 22a44dfb2..4ecac99a4 100644 --- a/crates/monty/src/heap.rs +++ b/crates/monty/src/heap.rs @@ -23,8 +23,8 @@ use crate::{ resource::{ResourceError, ResourceTracker}, types::{ Bytes, Dataclass, Dict, DictItemsView, DictKeysView, DictValuesView, FrozenSet, List, LongInt, Module, - MontyIter, NamedTuple, Path, PyTrait, Range, ReMatch, RePattern, Set, Slice, Str, TimeZone, Tuple, date, - datetime, timedelta, timezone, + MontyIter, NamedTuple, NdArray, Path, PyTrait, Range, ReMatch, RePattern, Set, Slice, Str, TimeZone, Tuple, + date, datetime, timedelta, timezone, }, value::Value, }; @@ -277,6 +277,7 @@ impl<'a, T: ResourceTracker> HeapReader<'a, T> { HeapData::Path(path) => HeapReadOutput::Path(heap_read(base, path, readers)), HeapData::RePattern(re_pattern) => HeapReadOutput::RePattern(heap_read_boxed(re_pattern, readers)), HeapData::ReMatch(re_match) => HeapReadOutput::ReMatch(heap_read(base, re_match, readers)), + HeapData::NdArray(ndarray) => HeapReadOutput::NdArray(heap_read(base, ndarray, readers)), HeapData::Date(d) => HeapReadOutput::Date(heap_read(base, d, readers)), HeapData::DateTime(d) => HeapReadOutput::DateTime(heap_read(base, d, readers)), HeapData::TimeDelta(d) => HeapReadOutput::TimeDelta(heap_read(base, d, readers)), @@ -362,6 +363,7 @@ pub enum HeapReadOutput<'a> { Path(HeapRead<'a, Path>), RePattern(HeapRead<'a, RePattern>), ReMatch(HeapRead<'a, ReMatch>), + NdArray(HeapRead<'a, NdArray>), Date(HeapRead<'a, date::Date>), DateTime(HeapRead<'a, datetime::DateTime>), TimeDelta(HeapRead<'a, timedelta::TimeDelta>), diff --git a/crates/monty/src/heap_data.rs b/crates/monty/src/heap_data.rs index aed99d9c6..7422ccc35 100644 --- a/crates/monty/src/heap_data.rs +++ b/crates/monty/src/heap_data.rs @@ -21,8 +21,8 @@ use crate::{ intern::FunctionId, types::{ Bytes, Dataclass, Dict, DictItemsView, DictKeysView, DictValuesView, FrozenSet, List, LongInt, Module, - MontyIter, NamedTuple, Path, PyTrait, Range, ReMatch, RePattern, Set, Slice, Str, Tuple, Type, date, datetime, - timedelta, timezone, + MontyIter, NamedTuple, NdArray, Path, PyTrait, Range, ReMatch, RePattern, Set, Slice, Str, Tuple, Type, date, + datetime, timedelta, timezone, }, value::{EitherStr, Value}, }; @@ -116,6 +116,10 @@ pub(crate) enum HeapData { /// Contains the matched text, capture groups, positions, and input string. /// Leaf type: no heap references, not GC-tracked. ReMatch(ReMatch), + /// A numpy ndarray (multi-dimensional array of f64 values). + /// + /// Leaf type: stores only f64 data, no heap references, not GC-tracked. + NdArray(NdArray), /// Reference to an external function whose name was not found in the intern table. /// /// Created when the host resolves a `NameLookup` to a callable whose name does not @@ -203,6 +207,7 @@ impl HeapData { Self::Path(_) => Type::Path, Self::RePattern(_) => Type::RePattern, Self::ReMatch(_) => Type::ReMatch, + Self::NdArray(_) => Type::NdArray, Self::Date(_) => Type::Date, Self::DateTime(_) => Type::DateTime, Self::TimeDelta(_) => Type::TimeDelta, @@ -238,6 +243,7 @@ impl HeapData { Self::Path(p) => p.py_estimate_size(), Self::ReMatch(m) => m.py_estimate_size(), Self::RePattern(p) => p.py_estimate_size(), + Self::NdArray(a) => a.py_estimate_size(), Self::ExtFunction(s) => mem::size_of::() + s.len(), Self::Date(d) => d.py_estimate_size(), Self::DateTime(d) => d.py_estimate_size(), @@ -415,6 +421,7 @@ impl<'h> PyTrait<'h> for HeapReadOutput<'h> { Self::Path(p) => p.py_bool(vm), Self::ReMatch(m) => m.py_bool(vm), Self::RePattern(p) => p.py_bool(vm), + Self::NdArray(a) => a.py_bool(vm), Self::TimeDelta(td) => td.py_bool(vm), Self::Date(_) | Self::DateTime(_) | Self::TimeZone(_) => true, } @@ -443,6 +450,7 @@ impl<'h> PyTrait<'h> for HeapReadOutput<'h> { HeapReadOutput::Module(m) => Ok(m.py_call_attr(self_id, vm, attr, args)?), HeapReadOutput::ReMatch(m) => Ok(m.py_call_attr(self_id, vm, attr, args)?), HeapReadOutput::RePattern(p) => Ok(p.py_call_attr(self_id, vm, attr, args)?), + HeapReadOutput::NdArray(a) => Ok(a.py_call_attr(self_id, vm, attr, args)?), HeapReadOutput::TimeDelta(td) => Ok(td.py_call_attr(self_id, vm, attr, args)?), HeapReadOutput::Date(d) => Ok(d.py_call_attr(self_id, vm, attr, args)?), HeapReadOutput::DateTime(dt) => Ok(dt.py_call_attr(self_id, vm, attr, args)?), @@ -481,6 +489,7 @@ impl<'h> PyTrait<'h> for HeapReadOutput<'h> { Self::Path(p) => p.py_type(vm), Self::ReMatch(re) => re.py_type(vm), Self::RePattern(p) => p.py_type(vm), + Self::NdArray(a) => a.py_type(vm), Self::Date(d) => d.py_type(vm), Self::DateTime(d) => d.py_type(vm), Self::TimeDelta(d) => d.py_type(vm), @@ -506,6 +515,7 @@ impl<'h> PyTrait<'h> for HeapReadOutput<'h> { Self::Dataclass(dc) => dc.py_len(vm), Self::ReMatch(m) => m.py_len(vm), Self::RePattern(p) => p.py_len(vm), + Self::NdArray(a) => a.py_len(vm), // Types without length — return None _ => None, } @@ -585,7 +595,8 @@ impl<'h> PyTrait<'h> for HeapReadOutput<'h> { (HeapReadOutput::TimeDelta(a), HeapReadOutput::TimeDelta(b)) => a.py_eq(b, vm), (HeapReadOutput::TimeZone(a), HeapReadOutput::TimeZone(b)) => a.py_eq(b, vm), // Identity-only types (handled by HeapId comparison above) - (HeapReadOutput::ReMatch(_), HeapReadOutput::ReMatch(_)) + (HeapReadOutput::NdArray(_), HeapReadOutput::NdArray(_)) + | (HeapReadOutput::ReMatch(_), HeapReadOutput::ReMatch(_)) | (HeapReadOutput::Cell(_), HeapReadOutput::Cell(_)) | (HeapReadOutput::Exception(_), HeapReadOutput::Exception(_)) | (HeapReadOutput::Iter(_), HeapReadOutput::Iter(_)) @@ -698,6 +709,7 @@ impl<'h> PyTrait<'h> for HeapReadOutput<'h> { Self::Path(p) => p.py_repr_fmt(f, vm, heap_ids), Self::ReMatch(m) => m.py_repr_fmt(f, vm, heap_ids), Self::RePattern(p) => p.py_repr_fmt(f, vm, heap_ids), + Self::NdArray(a) => a.py_repr_fmt(f, vm, heap_ids), Self::ExtFunction(name) => Ok(write!(f, "", name.get(vm.heap))?), Self::Date(d) => d.py_repr_fmt(f, vm, heap_ids), Self::DateTime(d) => d.py_repr_fmt(f, vm, heap_ids), @@ -861,6 +873,7 @@ impl<'h> PyTrait<'h> for HeapReadOutput<'h> { Self::Dict(d) => d.py_getitem(key, vm), Self::Range(r) => r.py_getitem(key, vm), Self::ReMatch(m) => m.py_getitem(key, vm), + Self::NdArray(a) => a.py_getitem(key, vm), _ => Err(ExcType::type_error_not_sub(self.py_type(vm))), } } @@ -869,6 +882,7 @@ impl<'h> PyTrait<'h> for HeapReadOutput<'h> { match self { Self::List(l) => l.py_setitem(key, value, vm), Self::Dict(d) => d.py_setitem(key, value, vm), + Self::NdArray(a) => a.py_setitem(key, value, vm), _ => { key.drop_with_heap(vm); value.drop_with_heap(vm); @@ -895,6 +909,7 @@ impl<'h> PyTrait<'h> for HeapReadOutput<'h> { Self::Dataclass(dc) => dc.py_getattr(attr, vm), Self::ReMatch(m) => m.py_getattr(attr, vm), Self::RePattern(p) => p.py_getattr(attr, vm), + Self::NdArray(a) => a.py_getattr(attr, vm), Self::Module(m) => Ok(m.py_getattr(attr, vm)), Self::Exception(e) => e.py_getattr(attr, vm), Self::Path(p) => p.py_getattr(attr, vm), diff --git a/crates/monty/src/intern.rs b/crates/monty/src/intern.rs index de1b9880f..a4fa87575 100644 --- a/crates/monty/src/intern.rs +++ b/crates/monty/src/intern.rs @@ -381,6 +381,7 @@ pub enum StaticStrings { // math module strings Math, // Rounding + Round, Floor, Ceil, Trunc, @@ -600,6 +601,1127 @@ pub enum StaticStrings { /// `match.groupdict()` method Groupdict, + // ========================== + // numpy module strings + /// Module name for `import numpy`. + Numpy, + /// `numpy.array()` function + #[strum(serialize = "array")] + NpArray, + /// `numpy.array2string()` array display helper. + #[strum(serialize = "array2string")] + NpArray2string, + /// `numpy.array_repr()` array repr helper. + #[strum(serialize = "array_repr")] + NpArrayRepr, + /// `numpy.array_str()` array string helper. + #[strum(serialize = "array_str")] + NpArrayStr, + /// `numpy.fromfunction()` callable coordinate-array constructor. + #[strum(serialize = "fromfunction")] + NpFromfunction, + /// `numpy.fromiter()` iterable numeric-array constructor. + #[strum(serialize = "fromiter")] + NpFromiter, + /// `numpy.fromstring()` text-mode numeric-array constructor. + #[strum(serialize = "fromstring")] + NpFromstring, + /// `numpy.asanyarray()` alias for `numpy.asarray()` in Monty's ndarray-only subset. + #[strum(serialize = "asanyarray")] + NpAsanyarray, + /// `numpy.zeros()` function + #[strum(serialize = "zeros")] + NpZeros, + /// `numpy.ones()` function + #[strum(serialize = "ones")] + NpOnes, + /// `numpy.subtract()` function + #[strum(serialize = "subtract")] + NpSubtract, + /// `numpy.multiply()` function + #[strum(serialize = "multiply")] + NpMultiply, + /// `numpy.divide()` function + #[strum(serialize = "divide")] + NpDivide, + /// `numpy.true_divide()` function + #[strum(serialize = "true_divide")] + NpTrueDivide, + /// `numpy.floor_divide()` function + #[strum(serialize = "floor_divide")] + NpFloorDivide, + /// `numpy.mod()` function + #[strum(serialize = "mod")] + NpMod, + /// `numpy.equal()` function + #[strum(serialize = "equal")] + NpEqual, + /// `numpy.not_equal()` function + #[strum(serialize = "not_equal")] + NpNotEqual, + /// `numpy.greater()` function + #[strum(serialize = "greater")] + NpGreater, + /// `numpy.greater_equal()` function + #[strum(serialize = "greater_equal")] + NpGreaterEqual, + /// `numpy.less()` function + #[strum(serialize = "less")] + NpLess, + /// `numpy.less_equal()` function + #[strum(serialize = "less_equal")] + NpLessEqual, + /// `numpy.arange()` function + #[strum(serialize = "arange")] + NpArange, + /// `numpy.linspace()` function + #[strum(serialize = "linspace")] + NpLinspace, + /// `numpy.where()` function + #[strum(serialize = "where")] + NpWhere, + /// `numpy.maximum()` function + Maximum, + /// `numpy.minimum()` function + Minimum, + /// `numpy.unique()` function + Unique, + /// `numpy.unique_values()` function + #[strum(serialize = "unique_values")] + NpUniqueValues, + /// `numpy.unique_counts()` function + #[strum(serialize = "unique_counts")] + NpUniqueCounts, + /// `numpy.unique_inverse()` function + #[strum(serialize = "unique_inverse")] + NpUniqueInverse, + /// `numpy.unique_all()` function + #[strum(serialize = "unique_all")] + NpUniqueAll, + /// `numpy.concatenate()` function + Concatenate, + /// `numpy.concat()` alias for `numpy.concatenate()`. + #[strum(serialize = "concat")] + NpConcat, + /// Shared: `mean()` method/function + Mean, + /// Shared: `std()` method/function + Std, + /// Shared: `abs()` method/function — also used by math module + Abs, + /// `ndarray.flat` attribute — returns flattened 1D view of the array. + #[strum(serialize = "flat")] + NpFlat, + /// `numpy.flatiter` public flat-iterator type object. + #[strum(serialize = "flatiter")] + NpFlatiter, + /// `ndarray.flatten()` method + Flatten, + /// `ndarray.tolist()` method + Tolist, + /// `ndarray.reshape()` method + Reshape, + /// `ndarray.argmin()` method + Argmin, + /// `ndarray.argmax()` method + Argmax, + /// `ndarray.all()` method + #[strum(serialize = "all")] + NpAll, + /// `ndarray.any()` method + #[strum(serialize = "any")] + NpAny, + /// `ndarray.argsort()` method + #[strum(serialize = "argsort")] + NpArgsort, + /// `numpy.argpartition()` function + #[strum(serialize = "argpartition")] + NpArgpartition, + /// `ndarray.astype()` method + #[strum(serialize = "astype")] + NpAstype, + /// `numpy.ndarray` type object. + #[strum(serialize = "ndarray")] + NpNdarray, + /// `numpy.ufunc` public ufunc type object. + #[strum(serialize = "ufunc")] + NpUfunc, + /// `numpy.format_float_positional()` function. + #[strum(serialize = "format_float_positional")] + NpFormatFloatPositional, + /// `numpy.format_float_scientific()` function. + #[strum(serialize = "format_float_scientific")] + NpFormatFloatScientific, + /// `ndarray.transpose()` method + #[strum(serialize = "transpose")] + NpTranspose, + /// `ndarray.size` attribute + #[strum(serialize = "size")] + NpSize, + /// `ndarray.ndim` attribute + #[strum(serialize = "ndim")] + NpNdim, + /// `ndarray.T` attribute (transpose) + #[strum(serialize = "T")] + NpT, + /// `ndarray.dtype` attribute + Dtype, + /// `ndarray.shape` attribute / also used by pathlib `parts` + #[strum(serialize = "shape")] + NpShape, + /// `numpy.broadcast_shapes()`. + #[strum(serialize = "broadcast_shapes")] + NpBroadcastShapes, + /// `numpy.broadcast_to()`. + #[strum(serialize = "broadcast_to")] + NpBroadcastTo, + /// `numpy.broadcast_arrays()`. + #[strum(serialize = "broadcast_arrays")] + NpBroadcastArrays, + /// `numpy.broadcast()`. + #[strum(serialize = "broadcast")] + NpBroadcast, + /// `ndarray.min()` / `numpy.min()` — shared with builtins + #[strum(serialize = "min")] + NpMin, + /// `numpy.amin()` alias for `numpy.min()`. + #[strum(serialize = "amin")] + NpAmin, + /// `ndarray.max()` / `numpy.max()` — shared with builtins + #[strum(serialize = "max")] + NpMax, + /// `numpy.amax()` alias for `numpy.max()`. + #[strum(serialize = "amax")] + NpAmax, + /// `ndarray.sum()` / `numpy.sum()` — shared with builtins + #[strum(serialize = "sum")] + NpSum, + /// `numpy.dot()` function / `ndarray.dot()` method + #[strum(serialize = "dot")] + Dot, + /// `numpy.cumsum()` function / `ndarray.cumsum()` method + #[strum(serialize = "cumsum")] + Cumsum, + /// `numpy.cumulative_sum()` alias for `numpy.cumsum()`. + #[strum(serialize = "cumulative_sum")] + NpCumulativeSum, + /// `numpy.clip()` function / `ndarray.clip()` method + #[strum(serialize = "clip")] + Clip, + /// `numpy.prod()` function / `ndarray.prod()` method + #[strum(serialize = "prod")] + NpProd, + /// `numpy.var()` function / `ndarray.var()` method + #[strum(serialize = "var")] + NpVar, + /// `numpy.full()` function + #[strum(serialize = "full")] + NpFull, + /// `numpy.eye()` function + #[strum(serialize = "eye")] + NpEye, + /// `numpy.empty()` function + #[strum(serialize = "empty")] + NpEmpty, + /// `numpy.zeros_like()` function + #[strum(serialize = "zeros_like")] + NpZerosLike, + /// `numpy.ones_like()` function + #[strum(serialize = "ones_like")] + NpOnesLike, + // Note: numpy.isnan/isinf/isfinite reuse math module's Isnan/Isinf/Isfinite + /// `numpy.array_equal()` function + #[strum(serialize = "array_equal")] + NpArrayEqual, + /// `numpy.array_equiv()` shape-compatible equality helper. + #[strum(serialize = "array_equiv")] + NpArrayEquiv, + /// `numpy.count_nonzero()` function + #[strum(serialize = "count_nonzero")] + NpCountNonzero, + /// `numpy.median()` function + #[strum(serialize = "median")] + NpMedian, + /// `numpy.power()` function + #[strum(serialize = "power")] + NpPower, + /// `numpy.choose()` indexed choice helper. + #[strum(serialize = "choose")] + NpChoose, + /// `numpy.diff()` function + #[strum(serialize = "diff")] + NpDiff, + /// `numpy.ediff1d()` flattened first-difference helper. + #[strum(serialize = "ediff1d")] + NpEdiff1d, + /// `numpy.fill_diagonal()` in-place diagonal fill helper. + #[strum(serialize = "fill_diagonal")] + NpFillDiagonal, + /// `numpy.put()` in-place flattened index assignment helper. + #[strum(serialize = "put")] + NpPut, + /// `numpy.put_along_axis()` in-place indexed assignment helper. + #[strum(serialize = "put_along_axis")] + NpPutAlongAxis, + /// `numpy.copyto()` in-place copy helper. + #[strum(serialize = "copyto")] + NpCopyto, + /// `numpy.putmask()` in-place masked assignment helper. + #[strum(serialize = "putmask")] + NpPutmask, + /// `numpy.place()` in-place masked placement helper. + #[strum(serialize = "place")] + NpPlace, + // Note: numpy.append reuses list's Append variant + /// `numpy.vstack()` function + #[strum(serialize = "vstack")] + NpVstack, + /// `numpy.hstack()` function + #[strum(serialize = "hstack")] + NpHstack, + /// `numpy.dstack()` function + #[strum(serialize = "dstack")] + NpDstack, + /// `numpy.stack()` function + #[strum(serialize = "stack")] + NpStack, + /// `numpy.block()` nested block assembly helper. + #[strum(serialize = "block")] + NpBlock, + /// `numpy.apply_along_axis()` callable 1-D slice helper. + #[strum(serialize = "apply_along_axis")] + NpApplyAlongAxis, + /// `numpy.apply_over_axes()` repeated axis-reduction helper. + #[strum(serialize = "apply_over_axes")] + NpApplyOverAxes, + /// `numpy.piecewise()` condition-list selection helper. + #[strum(serialize = "piecewise")] + NpPiecewise, + /// `numpy.pad()` array padding helper. + #[strum(serialize = "pad")] + NpPad, + /// `numpy.unstack()` function + #[strum(serialize = "unstack")] + NpUnstack, + /// `numpy.tile()` function + #[strum(serialize = "tile")] + NpTile, + /// `numpy.repeat()` function + #[strum(serialize = "repeat")] + NpRepeat, + // Note: numpy.split reuses string's Split variant + /// `numpy.nonzero()` function + #[strum(serialize = "nonzero")] + NpNonzero, + /// `numpy.argwhere()` function + #[strum(serialize = "argwhere")] + NpArgwhere, + /// `ndarray.ravel()` method + #[strum(serialize = "ravel")] + NpRavel, + // Note: numpy.min/max/sum/sort reuse existing StaticStrings variants + // (NpMin, NpMax, NpSum, Sort) which are already defined above. + + // --- Phase 2+ numpy functions --- + /// `numpy.newaxis` constant (alias for None) + #[strum(serialize = "newaxis")] + Newaxis, + /// `numpy.float64` dtype type object + #[strum(serialize = "float64")] + NpFloat64, + /// `numpy.int64` dtype type object + #[strum(serialize = "int64")] + NpInt64, + /// `numpy.bool_` dtype type object + #[strum(serialize = "bool_")] + NpBool_, + /// `numpy.float32` dtype alias (maps to float64 internally) + #[strum(serialize = "float32")] + NpFloat32, + /// `numpy.int32` dtype alias (maps to int64 internally) + #[strum(serialize = "int32")] + NpInt32, + /// `numpy.bool` dtype alias (maps to bool_ internally) + #[strum(serialize = "bool")] + NpBool, + /// `numpy.int_` dtype alias (maps to int64 internally) + #[strum(serialize = "int_")] + NpInt_, + /// `numpy.intc` dtype alias (maps to int32 internally) + #[strum(serialize = "intc")] + NpIntc, + /// `numpy.intp` dtype alias (maps to int64 internally) + #[strum(serialize = "intp")] + NpIntp, + /// `numpy.long` dtype alias (maps to int64 internally) + #[strum(serialize = "long")] + NpLong, + /// `numpy.longlong` dtype alias (maps to int64 internally) + #[strum(serialize = "longlong")] + NpLonglong, + /// `numpy.byte` dtype alias (maps to int64 internally) + #[strum(serialize = "byte")] + NpByte, + /// `numpy.short` dtype alias (maps to int64 internally) + #[strum(serialize = "short")] + NpShort, + /// `numpy.int8` dtype alias (maps to int64 internally) + #[strum(serialize = "int8")] + NpInt8, + /// `numpy.int16` dtype alias (maps to int64 internally) + #[strum(serialize = "int16")] + NpInt16, + /// `numpy.uint` dtype alias (maps to int64 internally) + #[strum(serialize = "uint")] + NpUint, + /// `numpy.uintc` dtype alias (maps to int32 internally) + #[strum(serialize = "uintc")] + NpUintc, + /// `numpy.uintp` dtype alias (maps to int64 internally) + #[strum(serialize = "uintp")] + NpUintp, + /// `numpy.ubyte` dtype alias (maps to int64 internally) + #[strum(serialize = "ubyte")] + NpUbyte, + /// `numpy.ushort` dtype alias (maps to int64 internally) + #[strum(serialize = "ushort")] + NpUshort, + /// `numpy.uint8` dtype alias (maps to int64 internally) + #[strum(serialize = "uint8")] + NpUint8, + /// `numpy.uint16` dtype alias (maps to int64 internally) + #[strum(serialize = "uint16")] + NpUint16, + /// `numpy.uint32` dtype alias (maps to int64 internally) + #[strum(serialize = "uint32")] + NpUint32, + /// `numpy.uint64` dtype alias (maps to int64 internally) + #[strum(serialize = "uint64")] + NpUint64, + /// `numpy.ulong` dtype alias (maps to int64 internally) + #[strum(serialize = "ulong")] + NpUlong, + /// `numpy.ulonglong` dtype alias (maps to int64 internally) + #[strum(serialize = "ulonglong")] + NpUlonglong, + /// `numpy.float16` dtype alias (maps to float32 internally) + #[strum(serialize = "float16")] + NpFloat16, + /// `numpy.half` dtype alias (maps to float32 internally) + #[strum(serialize = "half")] + NpHalf, + /// `numpy.single` dtype alias (maps to float32 internally) + #[strum(serialize = "single")] + NpSingle, + /// `numpy.double` dtype alias (maps to float64 internally) + #[strum(serialize = "double")] + NpDouble, + /// `numpy.longdouble` dtype alias (maps to float64 internally) + #[strum(serialize = "longdouble")] + NpLongdouble, + /// `numpy.integer` dtype category marker. + #[strum(serialize = "integer")] + NpInteger, + /// `numpy.floating` dtype category marker. + #[strum(serialize = "floating")] + NpFloating, + /// `numpy.inexact` dtype category marker. + #[strum(serialize = "inexact")] + NpInexact, + /// Internal marker value for `numpy.integer`. + #[strum(serialize = "__monty_numpy_integer_category")] + NpIntegerCategoryMarker, + /// Internal marker value for `numpy.floating`. + #[strum(serialize = "__monty_numpy_floating_category")] + NpFloatingCategoryMarker, + /// Internal marker value for `numpy.inexact`. + #[strum(serialize = "__monty_numpy_inexact_category")] + NpInexactCategoryMarker, + /// `numpy.can_cast()` compact dtype cast predicate. + #[strum(serialize = "can_cast")] + NpCanCast, + /// `numpy.promote_types()` compact dtype promotion helper. + #[strum(serialize = "promote_types")] + NpPromoteTypes, + /// `numpy.result_type()` compact dtype result helper. + #[strum(serialize = "result_type")] + NpResultType, + /// `numpy.common_type()` compact common dtype helper. + #[strum(serialize = "common_type")] + NpCommonType, + /// `numpy.min_scalar_type()` compact scalar dtype helper. + #[strum(serialize = "min_scalar_type")] + NpMinScalarType, + /// `numpy.mintypecode()` legacy dtype character helper. + #[strum(serialize = "mintypecode")] + NpMintypecode, + /// `numpy.typename()` legacy dtype character name helper. + #[strum(serialize = "typename")] + NpTypename, + /// `numpy.typecodes` legacy dtype character mapping. + #[strum(serialize = "typecodes")] + NpTypecodes, + /// `numpy.sctypeDict` legacy scalar dtype mapping. + #[strum(serialize = "sctypeDict")] + NpSctypeDict, + /// `numpy.info()` documentation helper placeholder. + #[strum(serialize = "info")] + NpInfo, + /// `numpy.issubdtype()` compact dtype category predicate. + #[strum(serialize = "issubdtype")] + NpIssubdtype, + /// `numpy.isdtype()` compact dtype kind predicate. + #[strum(serialize = "isdtype")] + NpIsdtype, + /// `numpy.finfo()` floating dtype limit metadata. + #[strum(serialize = "finfo")] + NpFinfo, + /// `numpy.iinfo()` integer dtype limit metadata. + #[strum(serialize = "iinfo")] + NpIinfo, + /// `numpy.geterr()` floating-point error config helper. + #[strum(serialize = "geterr")] + NpGeterr, + /// `numpy.seterr()` floating-point error config helper. + #[strum(serialize = "seterr")] + NpSeterr, + /// `numpy.geterrcall()` floating-point error callback helper. + #[strum(serialize = "geterrcall")] + NpGeterrcall, + /// `numpy.seterrcall()` floating-point error callback helper. + #[strum(serialize = "seterrcall")] + NpSeterrcall, + /// `numpy.errstate()` floating-point error context helper. + #[strum(serialize = "errstate")] + NpErrstate, + /// `numpy.get_printoptions()` print config helper. + #[strum(serialize = "get_printoptions")] + NpGetPrintoptions, + /// `numpy.set_printoptions()` print config helper. + #[strum(serialize = "set_printoptions")] + NpSetPrintoptions, + /// `numpy.printoptions()` print config context helper. + #[strum(serialize = "printoptions")] + NpPrintoptions, + /// `numpy.getbufsize()` legacy buffer-size helper. + #[strum(serialize = "getbufsize")] + NpGetbufsize, + /// `numpy.setbufsize()` legacy buffer-size helper. + #[strum(serialize = "setbufsize")] + NpSetbufsize, + /// `numpy.show_runtime()` no-host runtime display helper. + #[strum(serialize = "show_runtime")] + NpShowRuntime, + /// `numpy.test()` no-op test-runner helper. + #[strum(serialize = "test")] + NpTest, + /// `numpy.little_endian` constant + #[strum(serialize = "little_endian")] + NpLittleEndian, + /// `numpy.euler_gamma` constant + #[strum(serialize = "euler_gamma")] + NpEulerGamma, + /// `numpy.arcsin()` / `numpy.asin()` function + #[strum(serialize = "arcsin")] + NpArcsin, + /// `numpy.arccos()` / `numpy.acos()` function + #[strum(serialize = "arccos")] + NpArccos, + /// `numpy.arctan()` / `numpy.atan()` function + #[strum(serialize = "arctan")] + NpArctan, + /// `numpy.arctan2()` function — two-argument arctangent + #[strum(serialize = "arctan2")] + NpArctan2, + /// `numpy.angle()` function for real-valued phase angles. + #[strum(serialize = "angle")] + NpAngle, + /// `numpy.arcsinh()` function + #[strum(serialize = "arcsinh")] + NpArcsinh, + /// `numpy.arccosh()` function + #[strum(serialize = "arccosh")] + NpArccosh, + /// `numpy.arctanh()` function + #[strum(serialize = "arctanh")] + NpArctanh, + /// `numpy.sign()` function + #[strum(serialize = "sign")] + NpSign, + /// `numpy.square()` function + #[strum(serialize = "square")] + NpSquare, + /// `numpy.reciprocal()` function + #[strum(serialize = "reciprocal")] + NpReciprocal, + /// `numpy.deg2rad()` function + #[strum(serialize = "deg2rad")] + NpDeg2rad, + /// `numpy.rad2deg()` function + #[strum(serialize = "rad2deg")] + NpRad2deg, + /// `numpy.hypot()` function — hypotenuse + #[strum(serialize = "hypot")] + NpHypot, + /// `numpy.nan_to_num()` function + #[strum(serialize = "nan_to_num")] + NpNanToNum, + /// `numpy.fmin()` function — NaN-ignoring minimum + #[strum(serialize = "fmin")] + NpFmin, + /// `numpy.fmax()` function — NaN-ignoring maximum + #[strum(serialize = "fmax")] + NpFmax, + /// `numpy.rint()` function — round to nearest integer + #[strum(serialize = "rint")] + NpRint, + /// `numpy.around()` alias for `numpy.round()`. + #[strum(serialize = "around")] + NpAround, + /// `numpy.positive()` function — unary + + #[strum(serialize = "positive")] + NpPositive, + /// `numpy.negative()` function — unary - + #[strum(serialize = "negative")] + NpNegative, + /// `numpy.logaddexp()` function. + #[strum(serialize = "logaddexp")] + NpLogaddexp, + /// `numpy.logaddexp2()` function. + #[strum(serialize = "logaddexp2")] + NpLogaddexp2, + /// `numpy.spacing()` function. + #[strum(serialize = "spacing")] + NpSpacing, + /// `numpy.signbit()` function. + #[strum(serialize = "signbit")] + NpSignbit, + /// `numpy.sinc()` function. + #[strum(serialize = "sinc")] + NpSinc, + /// `numpy.heaviside()` function. + #[strum(serialize = "heaviside")] + NpHeaviside, + /// `numpy.fix()` function. + #[strum(serialize = "fix")] + NpFix, + /// `numpy.float_power()` function. + #[strum(serialize = "float_power")] + NpFloatPower, + /// `numpy.divmod()` function. + #[strum(serialize = "divmod")] + NpDivmod, + /// `numpy.bitwise_and()` integer/boolean bitwise AND. + #[strum(serialize = "bitwise_and")] + NpBitwiseAnd, + /// `numpy.bitwise_or()` integer/boolean bitwise OR. + #[strum(serialize = "bitwise_or")] + NpBitwiseOr, + /// `numpy.bitwise_xor()` integer/boolean bitwise XOR. + #[strum(serialize = "bitwise_xor")] + NpBitwiseXor, + /// `numpy.bitwise_not()` integer/boolean bitwise inversion. + #[strum(serialize = "bitwise_not")] + NpBitwiseNot, + /// `numpy.bitwise_invert()` alias for integer/boolean bitwise inversion. + #[strum(serialize = "bitwise_invert")] + NpBitwiseInvert, + /// `numpy.invert()` alias for integer/boolean bitwise inversion. + #[strum(serialize = "invert")] + NpInvert, + /// `numpy.left_shift()` integer bit shift helper. + #[strum(serialize = "left_shift")] + NpLeftShift, + /// `numpy.right_shift()` integer bit shift helper. + #[strum(serialize = "right_shift")] + NpRightShift, + /// `numpy.bitwise_left_shift()` alias for left shift. + #[strum(serialize = "bitwise_left_shift")] + NpBitwiseLeftShift, + /// `numpy.bitwise_right_shift()` alias for right shift. + #[strum(serialize = "bitwise_right_shift")] + NpBitwiseRightShift, + /// `numpy.bitwise_count()` integer population count helper. + #[strum(serialize = "bitwise_count")] + NpBitwiseCount, + /// `numpy.packbits()` packs non-zero bits into bytes. + #[strum(serialize = "packbits")] + NpPackbits, + /// `numpy.unpackbits()` unpacks byte values into bit arrays. + #[strum(serialize = "unpackbits")] + NpUnpackbits, + /// `numpy.bartlett()` window generator. + #[strum(serialize = "bartlett")] + NpBartlett, + /// `numpy.blackman()` window generator. + #[strum(serialize = "blackman")] + NpBlackman, + /// `numpy.hamming()` window generator. + #[strum(serialize = "hamming")] + NpHamming, + /// `numpy.hanning()` window generator. + #[strum(serialize = "hanning")] + NpHanning, + /// `numpy.kaiser()` window generator. + #[strum(serialize = "kaiser")] + NpKaiser, + /// `numpy.i0()` modified Bessel function helper. + #[strum(serialize = "i0")] + NpI0, + /// `numpy.base_repr()` integer base conversion helper. + #[strum(serialize = "base_repr")] + NpBaseRepr, + /// `numpy.binary_repr()` integer binary conversion helper. + #[strum(serialize = "binary_repr")] + NpBinaryRepr, + /// `numpy.conj()` real-valued conjugate helper. + #[strum(serialize = "conj")] + NpConj, + /// `numpy.conjugate()` alias for `numpy.conj()`. + #[strum(serialize = "conjugate")] + NpConjugate, + /// `numpy.real()` real component helper. + #[strum(serialize = "real")] + NpReal, + /// `numpy.real_if_close()` real-valued identity helper for Monty's numeric subset. + #[strum(serialize = "real_if_close")] + NpRealIfClose, + /// `numpy.imag()` imaginary component helper. + #[strum(serialize = "imag")] + NpImag, + /// `numpy.isreal()` element-wise real-valued predicate. + #[strum(serialize = "isreal")] + NpIsreal, + /// `numpy.isrealobj()` object-level real-valued predicate. + #[strum(serialize = "isrealobj")] + NpIsrealobj, + /// `numpy.isposinf()` element-wise positive infinity predicate. + #[strum(serialize = "isposinf")] + NpIsposinf, + /// `numpy.isneginf()` element-wise negative infinity predicate. + #[strum(serialize = "isneginf")] + NpIsneginf, + /// `numpy.iscomplex()` element-wise complex-valued predicate. + #[strum(serialize = "iscomplex")] + NpIscomplex, + /// `numpy.iscomplexobj()` object-level complex-valued predicate. + #[strum(serialize = "iscomplexobj")] + NpIscomplexobj, + /// `numpy.isscalar()` scalar predicate. + #[strum(serialize = "isscalar")] + NpIsscalar, + /// `numpy.iterable()` iterable predicate. + #[strum(serialize = "iterable")] + NpIterable, + /// `numpy.atleast_1d()` shape helper. + #[strum(serialize = "atleast_1d")] + NpAtleast1d, + /// `numpy.atleast_2d()` shape helper. + #[strum(serialize = "atleast_2d")] + NpAtleast2d, + /// `numpy.atleast_3d()` shape helper. + #[strum(serialize = "atleast_3d")] + NpAtleast3d, + /// `numpy.diag_indices()` index helper. + #[strum(serialize = "diag_indices")] + NpDiagIndices, + /// `numpy.diag_indices_from()` index helper. + #[strum(serialize = "diag_indices_from")] + NpDiagIndicesFrom, + /// `numpy.tril_indices()` lower-triangle index helper. + #[strum(serialize = "tril_indices")] + NpTrilIndices, + /// `numpy.tril_indices_from()` lower-triangle index helper. + #[strum(serialize = "tril_indices_from")] + NpTrilIndicesFrom, + /// `numpy.triu_indices()` upper-triangle index helper. + #[strum(serialize = "triu_indices")] + NpTriuIndices, + /// `numpy.triu_indices_from()` upper-triangle index helper. + #[strum(serialize = "triu_indices_from")] + NpTriuIndicesFrom, + /// `numpy.indices()` dense coordinate grid helper. + #[strum(serialize = "indices")] + NpIndices, + /// `numpy.unravel_index()` flat-to-coordinate index helper. + #[strum(serialize = "unravel_index")] + NpUnravelIndex, + /// `numpy.ravel_multi_index()` coordinate-to-flat index helper. + #[strum(serialize = "ravel_multi_index")] + NpRavelMultiIndex, + /// `numpy.ndindex()` row-major coordinate iterator helper. + #[strum(serialize = "ndindex")] + NpNdindex, + /// `numpy.ndenumerate()` row-major index/value iterator helper. + #[strum(serialize = "ndenumerate")] + NpNdenumerate, + /// `numpy.nditer()` row-major value iterator helper. + #[strum(serialize = "nditer")] + NpNditer, + /// `numpy.nansum()` function + #[strum(serialize = "nansum")] + NpNansum, + /// `numpy.nanmean()` function + #[strum(serialize = "nanmean")] + NpNanmean, + /// `numpy.nanmin()` function + #[strum(serialize = "nanmin")] + NpNanmin, + /// `numpy.nanmax()` function + #[strum(serialize = "nanmax")] + NpNanmax, + /// `numpy.nanstd()` function + #[strum(serialize = "nanstd")] + NpNanstd, + /// `numpy.nanvar()` function + #[strum(serialize = "nanvar")] + NpNanvar, + /// `numpy.nanprod()` function + #[strum(serialize = "nanprod")] + NpNanprod, + /// `numpy.nanmedian()` function + #[strum(serialize = "nanmedian")] + NpNanmedian, + /// `numpy.nanpercentile()` function + #[strum(serialize = "nanpercentile")] + NpNanpercentile, + /// `numpy.nanquantile()` function + #[strum(serialize = "nanquantile")] + NpNanquantile, + /// `numpy.nanargmin()` function + #[strum(serialize = "nanargmin")] + NpNanargmin, + /// `numpy.nanargmax()` function + #[strum(serialize = "nanargmax")] + NpNanargmax, + /// `numpy.average()` function + #[strum(serialize = "average")] + NpAverage, + /// `numpy.percentile()` function + #[strum(serialize = "percentile")] + NpPercentile, + /// `numpy.quantile()` function + #[strum(serialize = "quantile")] + NpQuantile, + /// `numpy.histogram()` one-dimensional histogram helper. + #[strum(serialize = "histogram")] + NpHistogram, + /// `numpy.histogram2d()` two-dimensional histogram helper. + #[strum(serialize = "histogram2d")] + NpHistogram2d, + /// `numpy.histogram_bin_edges()` one-dimensional bin edge helper. + #[strum(serialize = "histogram_bin_edges")] + NpHistogramBinEdges, + /// `numpy.histogramdd()` multi-dimensional histogram helper. + #[strum(serialize = "histogramdd")] + NpHistogramdd, + /// `numpy.ptp()` function — peak to peak + #[strum(serialize = "ptp")] + NpPtp, + /// `numpy.cumprod()` function + #[strum(serialize = "cumprod")] + NpCumprod, + /// `numpy.cumulative_prod()` alias for `numpy.cumprod()`. + #[strum(serialize = "cumulative_prod")] + NpCumulativeProd, + /// `numpy.logical_and()` function + #[strum(serialize = "logical_and")] + NpLogicalAnd, + /// `numpy.logical_or()` function + #[strum(serialize = "logical_or")] + NpLogicalOr, + /// `numpy.logical_not()` function + #[strum(serialize = "logical_not")] + NpLogicalNot, + /// `numpy.logical_xor()` function + #[strum(serialize = "logical_xor")] + NpLogicalXor, + /// `numpy.allclose()` function + #[strum(serialize = "allclose")] + NpAllclose, + /// `numpy.isin()` function + #[strum(serialize = "isin")] + NpIsin, + /// `numpy.flip()` function + #[strum(serialize = "flip")] + NpFlip, + /// `numpy.fliplr()` function + #[strum(serialize = "fliplr")] + NpFliplr, + /// `numpy.flipud()` function + #[strum(serialize = "flipud")] + NpFlipud, + /// `numpy.roll()` function + #[strum(serialize = "roll")] + NpRoll, + /// `numpy.expand_dims()` function + #[strum(serialize = "expand_dims")] + NpExpandDims, + /// `numpy.squeeze()` function + #[strum(serialize = "squeeze")] + NpSqueeze, + /// `numpy.delete()` function + #[strum(serialize = "delete")] + NpDelete, + /// `numpy.diag()` function + #[strum(serialize = "diag")] + NpDiag, + /// `numpy.diagflat()` function + #[strum(serialize = "diagflat")] + NpDiagflat, + /// `numpy.diagonal()` function + #[strum(serialize = "diagonal")] + NpDiagonal, + /// `numpy.trace()` function + #[strum(serialize = "trace")] + NpTrace, + /// `numpy.flatnonzero()` function + #[strum(serialize = "flatnonzero")] + NpFlatnonzero, + /// `numpy.asarray()` function + #[strum(serialize = "asarray")] + NpAsarray, + /// `numpy.asarray_chkfinite()` finite-checking array conversion helper. + #[strum(serialize = "asarray_chkfinite")] + NpAsarrayChkfinite, + /// `numpy.ascontiguousarray()` contiguous array conversion helper. + #[strum(serialize = "ascontiguousarray")] + NpAscontiguousarray, + /// `numpy.asfortranarray()` Fortran array conversion helper. + #[strum(serialize = "asfortranarray")] + NpAsfortranarray, + /// `numpy.require()` array requirement helper. + #[strum(serialize = "require")] + NpRequire, + /// `numpy.ix_()` open mesh index helper. + #[strum(serialize = "ix_")] + NpIx_, + /// `numpy.mask_indices()` triangular mask index helper. + #[strum(serialize = "mask_indices")] + NpMaskIndices, + /// `numpy.isfortran()` memory layout predicate. + #[strum(serialize = "isfortran")] + NpIsfortran, + /// `numpy.may_share_memory()` conservative memory overlap predicate. + #[strum(serialize = "may_share_memory")] + NpMayShareMemory, + /// `numpy.shares_memory()` exact memory overlap predicate. + #[strum(serialize = "shares_memory")] + NpSharesMemory, + /// `numpy.column_stack()` function + #[strum(serialize = "column_stack")] + NpColumnStack, + /// `numpy.row_stack()` function — alias for vstack + #[strum(serialize = "row_stack")] + NpRowStack, + /// `numpy.hsplit()` function + #[strum(serialize = "hsplit")] + NpHsplit, + /// `numpy.vsplit()` function + #[strum(serialize = "vsplit")] + NpVsplit, + /// `numpy.dsplit()` function + #[strum(serialize = "dsplit")] + NpDsplit, + /// `numpy.array_split()` function + #[strum(serialize = "array_split")] + NpArraySplit, + /// `numpy.searchsorted()` function + #[strum(serialize = "searchsorted")] + NpSearchsorted, + /// `numpy.lexsort()` indirect stable sorting helper. + #[strum(serialize = "lexsort")] + NpLexsort, + /// `numpy.cov()` covariance helper. + #[strum(serialize = "cov")] + NpCov, + /// `numpy.corrcoef()` correlation coefficient helper. + #[strum(serialize = "corrcoef")] + NpCorrcoef, + /// `numpy.extract()` function + #[strum(serialize = "extract")] + NpExtract, + /// `numpy.trim_zeros()` one-dimensional zero trimming helper. + #[strum(serialize = "trim_zeros")] + NpTrimZeros, + /// `numpy.unwrap()` phase-unwrapping helper. + #[strum(serialize = "unwrap")] + NpUnwrap, + /// `numpy.intersect1d()` function + #[strum(serialize = "intersect1d")] + NpIntersect1d, + /// `numpy.union1d()` function + #[strum(serialize = "union1d")] + NpUnion1d, + /// `numpy.setdiff1d()` function + #[strum(serialize = "setdiff1d")] + NpSetdiff1d, + /// `numpy.setxor1d()` function + #[strum(serialize = "setxor1d")] + NpSetxor1d, + /// `numpy.bincount()` function + #[strum(serialize = "bincount")] + NpBincount, + /// `numpy.digitize()` function + #[strum(serialize = "digitize")] + NpDigitize, + /// `numpy.matmul()` function + #[strum(serialize = "matmul")] + NpMatmul, + /// `numpy.inner()` function + #[strum(serialize = "inner")] + NpInner, + /// `numpy.outer()` function + #[strum(serialize = "outer")] + NpOuter, + /// `numpy.vdot()` function + #[strum(serialize = "vdot")] + NpVdot, + /// `numpy.vecdot()` function + #[strum(serialize = "vecdot")] + NpVecdot, + /// `numpy.matvec()` function + #[strum(serialize = "matvec")] + NpMatvec, + /// `numpy.vecmat()` function + #[strum(serialize = "vecmat")] + NpVecmat, + /// `numpy.cross()` function + #[strum(serialize = "cross")] + NpCross, + /// `numpy.kron()` Kronecker product helper. + #[strum(serialize = "kron")] + NpKron, + /// `numpy.tensordot()` generalized tensor contraction helper. + #[strum(serialize = "tensordot")] + NpTensordot, + /// `numpy.einsum()` explicit-subscript contraction helper. + #[strum(serialize = "einsum")] + NpEinsum, + /// `numpy.einsum_path()` simple contraction planner helper. + #[strum(serialize = "einsum_path")] + NpEinsumPath, + /// `numpy.trapezoid()` function + #[strum(serialize = "trapezoid")] + NpTrapezoid, + /// `numpy.vander()` function + #[strum(serialize = "vander")] + NpVander, + /// `numpy.poly()` polynomial roots-to-coefficients helper. + #[strum(serialize = "poly")] + NpPoly, + /// `numpy.polyadd()` polynomial addition helper. + #[strum(serialize = "polyadd")] + NpPolyadd, + /// `numpy.polysub()` polynomial subtraction helper. + #[strum(serialize = "polysub")] + NpPolysub, + /// `numpy.polymul()` polynomial multiplication helper. + #[strum(serialize = "polymul")] + NpPolymul, + /// `numpy.polydiv()` polynomial division helper. + #[strum(serialize = "polydiv")] + NpPolydiv, + /// `numpy.polyint()` polynomial integration helper. + #[strum(serialize = "polyint")] + NpPolyint, + /// `numpy.polyder()` polynomial derivative helper. + #[strum(serialize = "polyder")] + NpPolyder, + /// `numpy.polyval()` polynomial evaluation helper. + #[strum(serialize = "polyval")] + NpPolyval, + /// `numpy.logspace()` function + #[strum(serialize = "logspace")] + NpLogspace, + /// `numpy.geomspace()` function + #[strum(serialize = "geomspace")] + NpGeomspace, + /// `numpy.tri()` function + #[strum(serialize = "tri")] + NpTri, + /// `numpy.tril()` function + #[strum(serialize = "tril")] + NpTril, + /// `numpy.triu()` function + #[strum(serialize = "triu")] + NpTriu, + /// `numpy.identity()` function — alias for eye + #[strum(serialize = "identity")] + NpIdentity, + /// `numpy.meshgrid()` function + #[strum(serialize = "meshgrid")] + NpMeshgrid, + /// `numpy.full_like()` function + #[strum(serialize = "full_like")] + NpFullLike, + /// `numpy.empty_like()` function + #[strum(serialize = "empty_like")] + NpEmptyLike, + /// `numpy.gradient()` function + #[strum(serialize = "gradient")] + NpGradient, + /// `numpy.convolve()` function + #[strum(serialize = "convolve")] + NpConvolve, + /// `numpy.correlate()` function + #[strum(serialize = "correlate")] + NpCorrelate, + /// `numpy.interp()` function — 1D interpolation + #[strum(serialize = "interp")] + NpInterp, + /// `numpy.select()` function + #[strum(serialize = "select")] + NpSelect, + /// `ndarray.item()` method — extract scalar from single-element array + #[strum(serialize = "item")] + NpItem, + /// `numpy.take()` function / `ndarray.take()` method — take elements at indices + #[strum(serialize = "take")] + NpTake, + /// `numpy.take_along_axis()` indexed gather helper. + #[strum(serialize = "take_along_axis")] + NpTakeAlongAxis, + /// `numpy.resize()` repeated flattened resize helper. + #[strum(serialize = "resize")] + NpResize, + /// `ndarray.fill()` method — fill array with value + #[strum(serialize = "fill")] + NpFill, + /// `numpy.compress()` function / `ndarray.compress()` method — select elements by boolean condition + #[strum(serialize = "compress")] + NpCompress, + /// `numpy.swapaxes()` function / `ndarray.swapaxes()` method + #[strum(serialize = "swapaxes")] + NpSwapaxes, + /// `numpy.permute_dims()` function — permute ndarray axes + #[strum(serialize = "permute_dims")] + NpPermuteDims, + /// `numpy.matrix_transpose()` function — swap the last two axes + #[strum(serialize = "matrix_transpose")] + NpMatrixTranspose, + /// `numpy.moveaxis()` function — move axes to new positions + #[strum(serialize = "moveaxis")] + NpMoveaxis, + /// `numpy.rollaxis()` function — roll one axis backward + #[strum(serialize = "rollaxis")] + NpRollaxis, + /// `numpy.rot90()` function — rotate a 2-D array by quarter turns + #[strum(serialize = "rot90")] + NpRot90, + /// `ndarray.nbytes` attribute + #[strum(serialize = "nbytes")] + NpNbytes, + /// `ndarray.itemsize` attribute + #[strum(serialize = "itemsize")] + NpItemsize, + /// `numpy.nancumsum()` function + #[strum(serialize = "nancumsum")] + NpNancumsum, + /// `numpy.nancumprod()` function + #[strum(serialize = "nancumprod")] + NpNancumprod, + // ========================== // gc module strings (only reachable when the `test-hooks` feature is enabled, // but interned unconditionally so the variant ordering — and therefore every @@ -612,6 +1734,112 @@ pub enum StaticStrings { Disable, /// `gc.enable()` function. Enable, + + // ========================== + // Late-added NumPy compatibility attributes. These intentionally live at + // the end of the enum so existing static string IDs stay stable. + /// `numpy.False_` value-compatible bool scalar constant. + #[strum(serialize = "False_")] + NpFalseScalar, + /// `numpy.True_` value-compatible bool scalar constant. + #[strum(serialize = "True_")] + NpTrueScalar, + /// `numpy.generic` dtype category marker. + #[strum(serialize = "generic")] + NpGeneric, + /// `numpy.number` dtype category marker. + #[strum(serialize = "number")] + NpNumber, + /// `numpy.signedinteger` dtype category marker. + #[strum(serialize = "signedinteger")] + NpSignedInteger, + /// `numpy.unsignedinteger` dtype category marker. + #[strum(serialize = "unsignedinteger")] + NpUnsignedInteger, + /// `numpy.complexfloating` dtype category marker. + #[strum(serialize = "complexfloating")] + NpComplexFloating, + /// `numpy.flexible` dtype category marker. + #[strum(serialize = "flexible")] + NpFlexible, + /// `numpy.character` dtype category marker. + #[strum(serialize = "character")] + NpCharacter, + /// Internal marker value for `numpy.generic`. + #[strum(serialize = "__monty_numpy_generic_category")] + NpGenericCategoryMarker, + /// Internal marker value for `numpy.number`. + #[strum(serialize = "__monty_numpy_number_category")] + NpNumberCategoryMarker, + /// Internal marker value for `numpy.signedinteger`. + #[strum(serialize = "__monty_numpy_signedinteger_category")] + NpSignedIntegerCategoryMarker, + /// Internal marker value for `numpy.unsignedinteger`. + #[strum(serialize = "__monty_numpy_unsignedinteger_category")] + NpUnsignedIntegerCategoryMarker, + /// Internal marker value for `numpy.complexfloating`. + #[strum(serialize = "__monty_numpy_complexfloating_category")] + NpComplexFloatingCategoryMarker, + /// Internal marker value for `numpy.flexible`. + #[strum(serialize = "__monty_numpy_flexible_category")] + NpFlexibleCategoryMarker, + /// Internal marker value for `numpy.character`. + #[strum(serialize = "__monty_numpy_character_category")] + NpCharacterCategoryMarker, + /// `numpy.index_exp` index-expression helper. + #[strum(serialize = "index_exp")] + NpIndexExp, + /// `numpy.s_` index-expression helper. + #[strum(serialize = "s_")] + NpSIndex, + /// `numpy.mgrid` dense grid index helper. + #[strum(serialize = "mgrid")] + NpMgrid, + /// `numpy.ogrid` open grid index helper. + #[strum(serialize = "ogrid")] + NpOgrid, + /// `numpy.r_` row-concatenation index helper. + #[strum(serialize = "r_")] + NpRIndex, + /// `numpy.c_` column-concatenation index helper. + #[strum(serialize = "c_")] + NpCIndex, + /// `numpy.complex64` dtype metadata marker. + #[strum(serialize = "complex64")] + NpComplex64, + /// `numpy.complex128` dtype metadata marker. + #[strum(serialize = "complex128")] + NpComplex128, + /// `numpy.cdouble` dtype metadata alias for `complex128`. + #[strum(serialize = "cdouble")] + NpCdouble, + /// `numpy.csingle` dtype metadata alias for `complex64`. + #[strum(serialize = "csingle")] + NpCsingle, + /// `numpy.clongdouble` dtype metadata marker. + #[strum(serialize = "clongdouble")] + NpClongdouble, + /// `numpy.str_` dtype metadata marker. + #[strum(serialize = "str_")] + NpStr_, + /// `numpy.bytes_` dtype metadata marker. + #[strum(serialize = "bytes_")] + NpBytes_, + /// `numpy.void` dtype metadata marker. + #[strum(serialize = "void")] + NpVoid, + /// `numpy.object_` dtype metadata marker. + #[strum(serialize = "object_")] + NpObject_, + /// `numpy.datetime64` dtype metadata marker. + #[strum(serialize = "datetime64")] + NpDatetime64, + /// `numpy.timedelta64` dtype metadata marker. + #[strum(serialize = "timedelta64")] + NpTimedelta64, + /// `numpy.ScalarType` tuple for Monty's supported scalar constructors. + #[strum(serialize = "ScalarType")] + NpScalarType, } impl StaticStrings { @@ -730,20 +1958,16 @@ pub struct InternerBuilder { } impl InternerBuilder { - /// Creates a new string interner with pre-interned strings. + /// Creates a new interner for code-specific strings, bytes, and integers. /// - /// Clones from a lazily-initialized base interner that contains all pre-interned - /// strings (``, attribute names, ASCII chars). This avoids rebuilding - /// the base set on every call. + /// ASCII and [`StaticStrings`] values use deterministic `StringId` ranges and are + /// not stored in this builder. The builder only owns dynamically interned strings + /// that are specific to the code being parsed, which keeps new parser instances + /// cheap to create. /// /// # Arguments /// * `code` - The code being parsed, used for a very rough guess at how many - /// additional strings will be interned beyond the base set. - /// - /// Pre-interns (via `BASE_INTERNER`): - /// - Index 0: `""` for module-level code - /// - Indices 1-MAX_ATTR_ID: Known attribute names (append, insert, get, join, etc.) - /// - Indices MAX_ATTR_ID+1..: ASCII single-character strings + /// dynamic strings will be interned. pub fn new(code: &str) -> Self { // Reserve capacity for code-specific strings // Rough guess: count quotes and divide by 2 (open+close per string) @@ -791,6 +2015,20 @@ impl InternerBuilder { StringId::from_ascii(s.as_bytes()[0]) } else if let Ok(ss) = StaticStrings::from_str(s) { ss.into() + } else { + self.intern_dynamic(s) + } + } + + /// Interns a string without checking the static-string table. + /// + /// Use this for arbitrary host-provided labels such as filenames. Those labels + /// still need stable `StringId` values during one compilation, but probing every + /// known Python and NumPy attribute name on predictable misses adds fixed parse + /// overhead to tiny programs. + pub(crate) fn intern_dynamic(&mut self, s: &str) -> StringId { + if s.len() == 1 { + StringId::from_ascii(s.as_bytes()[0]) } else { *self.string_map.entry(s.to_owned()).or_insert_with(|| { let string_id = self.strings.len() + INTERN_STRING_ID_OFFSET; diff --git a/crates/monty/src/modules/mod.rs b/crates/monty/src/modules/mod.rs index 9af0c39f6..b451127e2 100644 --- a/crates/monty/src/modules/mod.rs +++ b/crates/monty/src/modules/mod.rs @@ -22,6 +22,7 @@ pub(crate) mod datetime; pub(crate) mod gc; pub(crate) mod json; pub(crate) mod math; +pub(crate) mod numpy; pub(crate) mod os; pub(crate) mod pathlib; pub(crate) mod re; @@ -33,23 +34,28 @@ pub(crate) mod typing; #[derive(Debug, Clone, Copy, PartialEq, Eq, FromRepr)] pub(crate) enum StandardLib { /// The `sys` module providing system-specific parameters and functions. - Sys, + Sys = 0, /// The `typing` module providing type hints support. - Typing, + Typing = 1, /// The `asyncio` module providing async/await support (only `gather()` implemented). - Asyncio, + Asyncio = 2, /// The `pathlib` module providing object-oriented filesystem paths. - Pathlib, + Pathlib = 3, /// The `os` module providing operating system interface (only `getenv()` implemented). - Os, + Os = 4, /// The `math` module providing mathematical functions and constants. - Math, + Math = 5, /// The `json` module providing JSON parsing and serialization. - Json, + Json = 6, /// The `re` module providing regular expression matching. - Re, + Re = 7, /// The `datetime` module providing date and time types. - Datetime, + Datetime = 8, + /// The `numpy` module providing ndarray operations. + /// + /// This is explicitly numbered after the pre-existing modules so adding it + /// does not alter serialized bytecode module IDs. + Numpy = 9, /// The `gc` module exposing a single `collect()` for tests. Only present /// under the `test-hooks` feature so production sandboxes never see it. /// @@ -58,7 +64,7 @@ pub(crate) enum StandardLib { /// builds. Because it's the last variant, gating it has no effect on the /// numeric discriminants of any other module. #[cfg(feature = "test-hooks")] - Gc, + Gc = 10, } impl StandardLib { @@ -73,6 +79,7 @@ impl StandardLib { StaticStrings::Math => Some(Self::Math), StaticStrings::Json => Some(Self::Json), StaticStrings::Re => Some(Self::Re), + StaticStrings::Numpy => Some(Self::Numpy), StaticStrings::Datetime => Some(Self::Datetime), #[cfg(feature = "test-hooks")] StaticStrings::Gc => Some(Self::Gc), @@ -97,6 +104,7 @@ impl StandardLib { Self::Math => math::create_module(vm), Self::Json => json::create_module(vm), Self::Re => re::create_module(vm), + Self::Numpy => numpy::create_module(vm), Self::Datetime => datetime::create_module(vm), #[cfg(feature = "test-hooks")] Self::Gc => gc::create_module(vm), @@ -110,6 +118,7 @@ pub(crate) enum ModuleFunctions { Asyncio(asyncio::AsyncioFunctions), Json(json::JsonFunctions), Math(math::MathFunctions), + Numpy(numpy::NumpyFunctions), Os(os::OsFunctions), Re(re::ReFunctions), /// `gc` module functions — only present under the `test-hooks` feature. @@ -124,6 +133,7 @@ impl fmt::Display for ModuleFunctions { Self::Asyncio(func) => write!(f, "{func}"), Self::Json(func) => write!(f, "{func}"), Self::Math(func) => write!(f, "{func}"), + Self::Numpy(func) => write!(f, "{func}"), Self::Os(func) => write!(f, "{func}"), Self::Re(func) => write!(f, "{func}"), #[cfg(feature = "test-hooks")] @@ -142,6 +152,7 @@ impl ModuleFunctions { Self::Asyncio(functions) => asyncio::call(vm.heap, functions, args), Self::Json(functions) => json::call(vm, functions, args).map(CallResult::Value), Self::Math(functions) => math::call(vm, functions, args).map(CallResult::Value), + Self::Numpy(functions) => numpy::call(vm, functions, args), Self::Os(functions) => os::call(vm, functions, args), Self::Re(functions) => re::call(vm, functions, args), #[cfg(feature = "test-hooks")] diff --git a/crates/monty/src/modules/numpy.rs b/crates/monty/src/modules/numpy.rs new file mode 100644 index 000000000..8af1df70e --- /dev/null +++ b/crates/monty/src/modules/numpy.rs @@ -0,0 +1,14077 @@ +//! Implementation of the `numpy` module. +//! +//! Provides a subset of NumPy's array creation and manipulation functions, +//! backed by Monty's built-in `NdArray` type. This module is designed to +//! make LLM-generated numpy code run transparently in the Monty sandbox. +//! +//! # Supported functions +//! +//! ## Array creation +//! - `numpy.array(data)` — create an ndarray from a list +//! - `numpy.zeros(n)` / `numpy.zeros((m, n))` — array of zeros +//! - `numpy.ones(n)` / `numpy.ones((m, n))` — array of ones +//! - `numpy.arange([start,] stop[, step])` — evenly spaced values within a range +//! - `numpy.linspace(start, stop, num)` — evenly spaced values over an interval +//! - `numpy.full(shape, fill_value)` — array filled with a constant +//! - `numpy.eye(n)` — n×n identity matrix +//! - `numpy.empty(n)` — uninitialized array (returns zeros in Monty) +//! - `numpy.copy(a)` — copy an array +//! - `numpy.zeros_like(a)` / `numpy.ones_like(a)` — array of same shape/dtype +//! +//! ## Element-wise math +//! - `numpy.abs(a)`, `numpy.sqrt(a)`, `numpy.log(a)`, `numpy.exp(a)` +//! - `numpy.sin(a)`, `numpy.cos(a)`, `numpy.tan(a)`, `numpy.log2(a)`, `numpy.log10(a)` +//! - `numpy.ceil(a)`, `numpy.floor(a)` +//! - `numpy.power(base, exp)` — element-wise power +//! - `numpy.copysign`, `numpy.frexp`, `numpy.modf`, `numpy.ldexp`, `numpy.gcd`, `numpy.lcm` +//! - `numpy.logaddexp`, `numpy.nextafter`, `numpy.spacing`, `numpy.signbit`, `numpy.sinc` +//! - `numpy.bitwise_and`, `numpy.invert`, `numpy.left_shift`, `numpy.bitwise_count` +//! - `numpy.packbits`, `numpy.unpackbits` +//! - `numpy.i0`, `numpy.bartlett`, `numpy.blackman`, `numpy.hamming`, `numpy.hanning`, `numpy.kaiser` +//! - `numpy.base_repr`, `numpy.binary_repr` +//! - `numpy.diff(a)` — discrete differences +//! - `numpy.round(a, decimals)`, `numpy.clip(a, a_min, a_max)` +//! +//! ## Aggregation +//! - `numpy.sum(a)`, `numpy.mean(a)`, `numpy.min(a)`, `numpy.max(a)`, `numpy.std(a)` +//! - `numpy.prod(a)`, `numpy.var(a)`, `numpy.median(a)` +//! - `numpy.argmin(a)`, `numpy.argmax(a)` +//! - `numpy.count_nonzero(a)` +//! +//! ## Testing & inspection +//! - `numpy.isnan(a)`, `numpy.isinf(a)`, `numpy.isfinite(a)` +//! - `numpy.array_equal(a, b)` +//! - `numpy.array2string(a)`, `numpy.array_repr(a)`, `numpy.array_str(a)` +//! - `numpy.finfo(dtype)`, `numpy.iinfo(dtype)` +//! - `numpy.dtype(dtype)`, `numpy.astype(a, dtype)` +//! - `numpy.format_float_positional(x)`, `numpy.format_float_scientific(x)` +//! - `numpy.all(a)`, `numpy.any(a)` +//! +//! ## Selection & sorting +//! - `numpy.where(condition, x, y)`, `numpy.maximum(a, b)`, `numpy.minimum(a, b)` +//! - `numpy.sort(a)`, `numpy.unique(a)` +//! +//! ## Manipulation +//! - `numpy.reshape(a, shape)`, `numpy.transpose(a)`, `numpy.concatenate(arrays)` +//! - `numpy.append(a, values)`, `numpy.vstack(arrays)`, `numpy.hstack(arrays)` +//! - `numpy.stack(arrays)`, `numpy.tile(a, reps)`, `numpy.repeat(a, repeats)` +//! - `numpy.split(a, sections_or_indices)`, `numpy.cumsum(a)`, `numpy.dot(a, b)` +//! - `numpy.take`, `numpy.compress`, `numpy.swapaxes`, `numpy.permute_dims` +//! - `numpy.matrix_transpose`, `numpy.moveaxis`, `numpy.rollaxis`, `numpy.rot90` +//! - `numpy.block`, `numpy.vecdot`, `numpy.matvec`, `numpy.vecmat`, `numpy.trapezoid`, `numpy.vander` +//! +//! ## Search & index +//! - `numpy.nonzero(a)`, `numpy.argwhere(a)` +//! - `numpy.diag_indices`, `numpy.tril_indices`, `numpy.triu_indices` +//! - `numpy.indices`, `numpy.unravel_index`, `numpy.ravel_multi_index`, `numpy.ndindex` +//! - `numpy.ndenumerate`, `numpy.nditer`, `numpy.ix_` + +use std::{ + cmp::Ordering, + collections::BTreeMap, + f64::consts::{E, PI}, + mem, + num::FpCategory, +}; + +use num_bigint::BigInt; +use smallvec::SmallVec; + +use crate::{ + args::{ArgValues, KwargsValues}, + builtins::Builtins, + bytecode::{CallResult, VM}, + defer_drop, defer_drop_mut, + exception_private::{ExcType, RunError, RunResult, SimpleException}, + heap::{ContainsHeap, Heap, HeapData, HeapGuard, HeapId, HeapReadOutput}, + heap_traits::DropWithHeap, + intern::StaticStrings, + modules::ModuleFunctions, + resource::{ResourceError, ResourceTracker, check_array_alloc_size}, + types::{ + Dict, List, LongInt, Module, MontyIter, NamedTuple, NdArray, PyTrait, Slice, Type, allocate_tuple, + ndarray::{ + NdArrayDtype, broadcast_array_data, broadcast_pair_data, broadcast_shape, nan_last_cmp, ndarray_from_list, + promote_dtype, + }, + str::{Str, allocate_string}, + }, + value::{Marker, Value}, +}; + +/// Functions exposed by the `numpy` module. +/// +/// Each variant corresponds to a module-level function like `np.array()` or `np.zeros()`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, strum::Display, serde::Serialize, serde::Deserialize)] +#[strum(serialize_all = "lowercase")] +pub(crate) enum NumpyFunctions { + /// `numpy.array(data)` — create an ndarray from a list. + Array, + /// `numpy.array2string(a)` — bare ndarray display string. + Array2string, + /// `numpy.array_repr(a)` — ndarray repr string. + ArrayRepr, + /// `numpy.array_str(a)` — bare ndarray string. + ArrayStr, + /// `numpy.fromfunction(function, shape, dtype=float)` — call a function with coordinate arrays. + Fromfunction, + /// `numpy.fromiter(iter, dtype, count=-1)` — create a 1-D numeric array from an iterable. + Fromiter, + /// `numpy.fromstring(string, dtype=float, count=-1, sep='')` — parse text into a 1-D numeric array. + Fromstring, + /// `numpy.zeros(shape)` — create an array filled with zeros. + Zeros, + /// `numpy.ones(shape)` — create an array filled with ones. + Ones, + /// `numpy.arange([start,] stop[, step])` — evenly spaced values within a range. + Arange, + /// `numpy.linspace(start, stop, num)` — evenly spaced values over an interval. + Linspace, + /// `numpy.sum(a)` — sum of array elements. + Sum, + /// `numpy.mean(a)` — mean of array elements. + Mean, + /// `numpy.min(a)` — minimum of array elements. + Min, + /// `numpy.max(a)` — maximum of array elements. + Max, + /// `numpy.abs(a)` — element-wise absolute value. + Abs, + /// `numpy.sqrt(a)` — element-wise square root. + Sqrt, + /// `numpy.log(a)` — element-wise natural logarithm. + Log, + /// `numpy.exp(a)` — element-wise exponential. + Exp, + /// `numpy.round(a, decimals)` — element-wise rounding. + Round, + /// `numpy.clip(a, a_min, a_max)` — clip values to range. + Clip, + /// `numpy.where(condition, x, y)` — conditional selection. + Where, + /// `numpy.maximum(a, b)` — element-wise maximum. + Maximum, + /// `numpy.minimum(a, b)` — element-wise minimum. + Minimum, + /// `numpy.sort(a)` — return sorted copy of array. + Sort, + /// `numpy.unique(a)` — return sorted unique elements. + Unique, + /// `numpy.unique_values(a)` — return unique values. + UniqueValues, + /// `numpy.unique_counts(a)` — return unique values and counts. + UniqueCounts, + /// `numpy.unique_inverse(a)` — return unique values and inverse indices. + UniqueInverse, + /// `numpy.unique_all(a)` — return unique values, first indices, inverse indices, and counts. + UniqueAll, + /// `numpy.concatenate(arrays)` — join arrays along axis. + Concatenate, + /// `numpy.cumsum(a)` — cumulative sum. + Cumsum, + /// `numpy.dot(a, b)` — dot product. + Dot, + /// `numpy.ceil(a)` — element-wise ceiling. + Ceil, + /// `numpy.floor(a)` — element-wise floor. + Floor, + /// `numpy.log10(a)` — element-wise base-10 logarithm. + Log10, + /// `numpy.std(a)` — standard deviation of array elements. + Std, + /// `numpy.sin(a)` — element-wise sine. + Sin, + /// `numpy.cos(a)` — element-wise cosine. + Cos, + /// `numpy.tan(a)` — element-wise tangent. + Tan, + /// `numpy.log2(a)` — element-wise base-2 logarithm. + Log2, + /// `numpy.power(a, b)` — element-wise power. + Power, + /// `numpy.diff(a)` — n-th discrete difference. + Diff, + /// `numpy.ediff1d(a)` — flattened first-order discrete difference. + Ediff1d, + /// `numpy.full(shape, fill_value)` — array filled with a constant. + Full, + /// `numpy.eye(n)` — identity matrix. + Eye, + /// `numpy.copy(a)` — copy of an array. + NpCopy, + /// `numpy.empty(n)` — uninitialized array (returns zeros in Monty). + Empty, + /// `numpy.zeros_like(a)` — array of zeros with same shape/dtype. + ZerosLike, + /// `numpy.ones_like(a)` — array of ones with same shape/dtype. + OnesLike, + /// `numpy.isnan(a)` — element-wise NaN test. + Isnan, + /// `numpy.isinf(a)` — element-wise infinity test. + Isinf, + /// `numpy.isposinf(a)` — element-wise positive infinity test. + Isposinf, + /// `numpy.isneginf(a)` — element-wise negative infinity test. + Isneginf, + /// `numpy.isfinite(a)` — element-wise finiteness test. + Isfinite, + /// `numpy.array_equal(a, b)` — true if arrays are element-wise equal. + ArrayEqual, + /// `numpy.array_equiv(a, b)` — true if arrays are equal after broadcasting. + ArrayEquiv, + /// `numpy.count_nonzero(a)` — count of non-zero elements. + CountNonzero, + /// `numpy.all(a)` — true if all elements are truthy. + All, + /// `numpy.any(a)` — true if any element is truthy. + Any, + /// `numpy.prod(a)` — product of array elements. + Prod, + /// `numpy.var(a)` — variance of array elements. + Var, + /// `numpy.median(a)` — median of array elements. + Median, + /// `numpy.argmin(a)` — index of minimum element. + Argmin, + /// `numpy.argmax(a)` — index of maximum element. + Argmax, + /// `numpy.reshape(a, shape)` — reshape an array. + Reshape, + // Note: np.flatten doesn't exist in real NumPy — use arr.flatten() method instead + /// `numpy.transpose(a)` — transpose an array. + Transpose, + /// `numpy.take(a, indices)` — gather flattened elements by index. + Take, + /// `numpy.take_along_axis(a, indices, axis)` — gather along an axis. + TakeAlongAxis, + /// `numpy.resize(a, new_shape)` — repeat flattened data into a new shape. + Resize, + /// `numpy.compress(condition, a)` — select flattened elements by condition. + Compress, + /// `numpy.swapaxes(a, axis1, axis2)` — swap two axes. + Swapaxes, + /// `numpy.permute_dims(a, axes=None)` — permute ndarray axes. + PermuteDims, + /// `numpy.matrix_transpose(a)` — swap the last two axes. + MatrixTranspose, + /// `numpy.moveaxis(a, source, destination)` — move axes to new positions. + Moveaxis, + /// `numpy.rollaxis(a, axis, start=0)` — roll one axis backward. + Rollaxis, + /// `numpy.rot90(a, k=1)` — rotate a 2-D array by quarter turns. + Rot90, + /// `numpy.choose(a, choices)` — select values from a sequence of choices. + Choose, + /// `numpy.append(a, values)` — append values to end of array. + Append, + /// `numpy.vstack(arrays)` — stack arrays vertically. + Vstack, + /// `numpy.hstack(arrays)` — stack arrays horizontally. + Hstack, + /// `numpy.dstack(arrays)` — stack arrays along depth after promoting to 3-D. + Dstack, + /// `numpy.stack(arrays)` — stack arrays along new axis. + Stack, + /// `numpy.block(arrays)` — assemble nested numeric blocks. + Block, + /// `numpy.apply_along_axis(func1d, axis, arr)` — call a function over 1-D slices. + ApplyAlongAxis, + /// `numpy.apply_over_axes(func, a, axes)` — repeatedly reduce while preserving axes. + ApplyOverAxes, + /// `numpy.piecewise(x, condlist, funclist)` — condition-list numeric selection. + Piecewise, + /// `numpy.pad(array, pad_width, mode='constant')` — materialized array padding subset. + Pad, + /// `numpy.unstack(a, axis=0)` — split an array into a tuple along an axis. + Unstack, + /// `numpy.nonzero(a)` — indices of non-zero elements. + Nonzero, + /// `numpy.argwhere(a)` — indices where elements are non-zero. + Argwhere, + /// `numpy.tile(a, reps)` — construct by repeating array. + Tile, + /// `numpy.repeat(a, repeats)` — repeat elements of array. + Repeat, + /// `numpy.split(a, indices_or_sections)` — split array into sub-arrays. + Split, + /// `numpy.add(a, b)` — element-wise addition. + Add, + /// `numpy.subtract(a, b)` — element-wise subtraction. + Subtract, + /// `numpy.multiply(a, b)` — element-wise multiplication. + Multiply, + /// `numpy.divide(a, b)` / `numpy.true_divide(a, b)` — element-wise true division. + Divide, + /// `numpy.floor_divide(a, b)` — element-wise floor division. + FloorDivide, + /// `numpy.mod(a, b)` / `numpy.remainder(a, b)` — element-wise Python modulo. + Mod, + /// `numpy.equal(a, b)` — element-wise equality comparison. + Equal, + /// `numpy.not_equal(a, b)` — element-wise inequality comparison. + NotEqual, + /// `numpy.greater(a, b)` — element-wise greater-than comparison. + Greater, + /// `numpy.greater_equal(a, b)` — element-wise greater-or-equal comparison. + GreaterEqual, + /// `numpy.less(a, b)` — element-wise less-than comparison. + Less, + /// `numpy.less_equal(a, b)` — element-wise less-or-equal comparison. + LessEqual, + /// `numpy.shape(a)` — tuple of dimensions. + Shape, + /// `numpy.size(a)` — total number of elements. + Size, + /// `numpy.ndim(a)` — number of dimensions. + Ndim, + /// `numpy.broadcast_shapes(*shapes)` — common broadcast shape. + BroadcastShapes, + /// `numpy.broadcast_to(array, shape)` — broadcast an array into a shape. + BroadcastTo, + /// `numpy.broadcast_arrays(*arrays)` — broadcast arrays to a shared shape. + BroadcastArrays, + /// `numpy.broadcast(*arrays)` — materialized iterable subset of NumPy broadcast. + Broadcast, + /// `numpy.dtype(dtype)` — normalize a supported compact dtype marker. + Dtype, + /// `numpy.astype(a, dtype)` — cast an ndarray to a supported compact dtype. + Astype, + /// `numpy.format_float_positional(x)` — format a float without scientific notation. + FormatFloatPositional, + /// `numpy.format_float_scientific(x)` — format a float with scientific notation. + FormatFloatScientific, + + // --- Phase 3: Inverse trig, hyperbolic, remaining math --- + /// `numpy.arcsin(a)` — element-wise inverse sine. + Arcsin, + /// `numpy.arccos(a)` — element-wise inverse cosine. + Arccos, + /// `numpy.arctan(a)` — element-wise inverse tangent. + Arctan, + /// `numpy.arctan2(y, x)` — element-wise two-argument arctangent. + Arctan2, + /// `numpy.angle(z, deg=False)` — real-valued phase angle. + Angle, + /// `numpy.sinh(a)` — element-wise hyperbolic sine. + Sinh, + /// `numpy.cosh(a)` — element-wise hyperbolic cosine. + Cosh, + /// `numpy.tanh(a)` — element-wise hyperbolic tangent. + Tanh, + /// `numpy.arcsinh(a)` — element-wise inverse hyperbolic sine. + Arcsinh, + /// `numpy.arccosh(a)` — element-wise inverse hyperbolic cosine. + Arccosh, + /// `numpy.arctanh(a)` — element-wise inverse hyperbolic tangent. + Arctanh, + /// `numpy.sign(a)` — element-wise sign (-1, 0, or 1). + Sign, + /// `numpy.square(a)` — element-wise square. + Square, + /// `numpy.cbrt(a)` — element-wise cube root. + Cbrt, + /// `numpy.reciprocal(a)` — element-wise 1/x. + Reciprocal, + /// `numpy.log1p(a)` — element-wise log(1 + x). + Log1p, + /// `numpy.exp2(a)` — element-wise 2^x. + Exp2, + /// `numpy.expm1(a)` — element-wise exp(x) - 1. + Expm1, + /// `numpy.deg2rad(a)` — convert degrees to radians. + Deg2rad, + /// `numpy.rad2deg(a)` — convert radians to degrees. + Rad2deg, + /// `numpy.hypot(a, b)` — element-wise hypotenuse. + Hypot, + /// `numpy.nan_to_num(a)` — replace NaN with 0 and Inf with large finite. + NanToNum, + /// `numpy.fmin(a, b)` — element-wise minimum ignoring NaN. + Fmin, + /// `numpy.fmax(a, b)` — element-wise maximum ignoring NaN. + Fmax, + /// `numpy.fmod(a, b)` — element-wise C-style modulo. + Fmod, + /// `numpy.rint(a)` — round to nearest integer. + Rint, + /// `numpy.fabs(a)` — element-wise absolute value (float result). + Fabs, + /// `numpy.positive(a)` — element-wise unary +. + Positive, + /// `numpy.negative(a)` — element-wise unary -. + Negative, + /// `numpy.copysign(a, b)` — element-wise magnitude/sign combination. + Copysign, + /// `numpy.frexp(a)` — element-wise mantissa/exponent decomposition. + Frexp, + /// `numpy.modf(a)` — element-wise fractional/integer decomposition. + Modf, + /// `numpy.ldexp(a, exp)` — element-wise multiply by powers of two. + Ldexp, + /// `numpy.gcd(a, b)` — element-wise greatest common divisor. + Gcd, + /// `numpy.lcm(a, b)` — element-wise least common multiple. + Lcm, + /// `numpy.logaddexp(a, b)` — element-wise log(exp(a) + exp(b)). + Logaddexp, + /// `numpy.logaddexp2(a, b)` — element-wise log2(2**a + 2**b). + Logaddexp2, + /// `numpy.nextafter(a, b)` — next floating point value from a toward b. + Nextafter, + /// `numpy.spacing(a)` — distance to the nearest adjacent floating value. + Spacing, + /// `numpy.signbit(a)` — element-wise sign-bit predicate. + Signbit, + /// `numpy.sinc(a)` — normalized sinc function. + Sinc, + /// `numpy.heaviside(a, h0)` — element-wise Heaviside step function. + Heaviside, + /// `numpy.trunc(a)` — truncate toward zero. + Trunc, + /// `numpy.fix(a)` — truncate toward zero. + Fix, + /// `numpy.float_power(a, b)` — element-wise floating-point exponentiation. + FloatPower, + /// `numpy.divmod(a, b)` — element-wise floor division and modulo pair. + Divmod, + /// `numpy.bitwise_and(a, b)` — element-wise integer/boolean bitwise AND. + BitwiseAnd, + /// `numpy.bitwise_or(a, b)` — element-wise integer/boolean bitwise OR. + BitwiseOr, + /// `numpy.bitwise_xor(a, b)` — element-wise integer/boolean bitwise XOR. + BitwiseXor, + /// `numpy.bitwise_not(a)` / aliases — element-wise integer/boolean inversion. + BitwiseNot, + /// `numpy.left_shift(a, b)` — element-wise integer left shift. + LeftShift, + /// `numpy.right_shift(a, b)` — element-wise integer right shift. + RightShift, + /// `numpy.bitwise_count(a)` — count set bits in each integer's absolute value. + BitwiseCount, + /// `numpy.packbits(a)` — pack non-zero values into byte-sized integers. + Packbits, + /// `numpy.unpackbits(a)` — unpack byte-sized integers into bit arrays. + Unpackbits, + /// `numpy.bartlett(M)` — Bartlett triangular window. + Bartlett, + /// `numpy.blackman(M)` — Blackman taper window. + Blackman, + /// `numpy.hamming(M)` — Hamming window. + Hamming, + /// `numpy.hanning(M)` — Hann window using NumPy's legacy spelling. + Hanning, + /// `numpy.kaiser(M, beta)` — Kaiser window. + Kaiser, + /// `numpy.i0(x)` — modified Bessel function of the first kind, order 0. + I0, + /// `numpy.base_repr(number, base=2, padding=0)` — integer base conversion string. + BaseRepr, + /// `numpy.binary_repr(num, width=None)` — integer binary conversion string. + BinaryRepr, + /// `numpy.conj(a)` — return the real-valued conjugate. + Conj, + /// `numpy.real(a)` — return the real component. + Real, + /// `numpy.real_if_close(a)` — identity for Monty's real-valued numeric subset. + RealIfClose, + /// `numpy.imag(a)` — return the imaginary component. + Imag, + /// `numpy.isreal(a)` — element-wise predicate for real values. + Isreal, + /// `numpy.isrealobj(a)` — true when the input is not complex-valued. + Isrealobj, + /// `numpy.iscomplex(a)` — element-wise predicate for complex values. + Iscomplex, + /// `numpy.iscomplexobj(a)` — true when the input has a complex dtype. + Iscomplexobj, + /// `numpy.isscalar(a)` — true for scalar values. + Isscalar, + /// `numpy.iterable(a)` — true for values accepted by Monty's iterator protocol. + Iterable, + /// `numpy.can_cast(from_, to)` — compact dtype cast predicate. + CanCast, + /// `numpy.promote_types(type1, type2)` — compact dtype promotion helper. + PromoteTypes, + /// `numpy.result_type(*arrays_and_dtypes)` — compact dtype result helper. + ResultType, + /// `numpy.common_type(*arrays)` — compact common dtype helper. + CommonType, + /// `numpy.min_scalar_type(a)` — compact scalar dtype helper. + MinScalarType, + /// `numpy.mintypecode(typechars)` — legacy dtype character helper. + Mintypecode, + /// `numpy.typename(char)` — legacy dtype character name helper. + Typename, + /// `numpy.info(object=None, ...)` — accepted no-op documentation helper. + Info, + /// `numpy.issubdtype(arg1, arg2)` — compact dtype hierarchy predicate. + Issubdtype, + /// `numpy.isdtype(dtype, kind)` — compact dtype kind predicate. + Isdtype, + /// `numpy.finfo(dtype)` — floating dtype limit metadata. + Finfo, + /// `numpy.iinfo(dtype)` — integer dtype limit metadata. + Iinfo, + /// `numpy.geterr()` — floating-point error config snapshot. + Geterr, + /// `numpy.seterr(...)` — accepted no-op floating-point error config update. + Seterr, + /// `numpy.geterrcall()` — floating-point error callback query. + Geterrcall, + /// `numpy.seterrcall(callback)` — accepted no-op error callback update. + Seterrcall, + /// `numpy.errstate(...)` — lightweight floating-point error context placeholder. + Errstate, + /// `numpy.get_printoptions()` — print config snapshot. + GetPrintoptions, + /// `numpy.set_printoptions(...)` — accepted no-op print config update. + SetPrintoptions, + /// `numpy.printoptions(...)` — lightweight print config context placeholder. + Printoptions, + /// `numpy.getbufsize()` — legacy buffer size query. + Getbufsize, + /// `numpy.setbufsize(size)` — accepted no-op buffer size update. + Setbufsize, + /// `numpy.show_runtime()` — no-host runtime display placeholder. + ShowRuntime, + /// `numpy.test()` — no-op test-runner placeholder. + Test, + /// `numpy.atleast_1d(*arrays)` — view inputs as arrays with at least one dimension. + Atleast1d, + /// `numpy.atleast_2d(*arrays)` — view inputs as arrays with at least two dimensions. + Atleast2d, + /// `numpy.atleast_3d(*arrays)` — view inputs as arrays with at least three dimensions. + Atleast3d, + /// `numpy.diag_indices(n, ndim=2)` — indices for a diagonal in an `ndim` array. + DiagIndices, + /// `numpy.diag_indices_from(arr)` — diagonal indices matching a square array. + DiagIndicesFrom, + /// `numpy.tril_indices(n, k=0, m=None)` — lower-triangle indices. + TrilIndices, + /// `numpy.tril_indices_from(arr, k=0)` — lower-triangle indices for an array. + TrilIndicesFrom, + /// `numpy.triu_indices(n, k=0, m=None)` — upper-triangle indices. + TriuIndices, + /// `numpy.triu_indices_from(arr, k=0)` — upper-triangle indices for an array. + TriuIndicesFrom, + /// `numpy.indices(dimensions)` — dense coordinate grid arrays. + Indices, + /// `numpy.unravel_index(indices, shape)` — flat indices to coordinates. + UnravelIndex, + /// `numpy.ravel_multi_index(multi_index, dims)` — coordinates to flat indices. + RavelMultiIndex, + /// `numpy.ndindex(*shape)` — row-major coordinate tuples for a shape. + Ndindex, + /// `numpy.ndenumerate(a)` — row-major `(index, value)` pairs for an array. + Ndenumerate, + /// `numpy.nditer(a)` — row-major array scalar values. + Nditer, + + // --- Phase 4: NaN-aware aggregations and statistics --- + /// `numpy.nansum(a)` — sum ignoring NaN. + Nansum, + /// `numpy.nanmean(a)` — mean ignoring NaN. + Nanmean, + /// `numpy.nanmin(a)` — min ignoring NaN. + Nanmin, + /// `numpy.nanmax(a)` — max ignoring NaN. + Nanmax, + /// `numpy.nanstd(a)` — std ignoring NaN. + Nanstd, + /// `numpy.nanvar(a)` — var ignoring NaN. + Nanvar, + /// `numpy.nanprod(a)` — product ignoring NaN. + Nanprod, + /// `numpy.nanmedian(a)` — median ignoring NaN. + Nanmedian, + /// `numpy.nanargmin(a)` — argmin ignoring NaN. + Nanargmin, + /// `numpy.nanargmax(a)` — argmax ignoring NaN. + Nanargmax, + /// `numpy.average(a)` — weighted average (simple mean without weights). + Average, + /// `numpy.percentile(a, q)` — q-th percentile. + Percentile, + /// `numpy.quantile(a, q)` — q-th quantile (q in [0,1]). + Quantile, + /// `numpy.nanpercentile(a, q)` — q-th percentile ignoring NaN values. + Nanpercentile, + /// `numpy.nanquantile(a, q)` — q-th quantile ignoring NaN values. + Nanquantile, + /// `numpy.histogram(a, bins=10)` — one-dimensional histogram counts and edges. + Histogram, + /// `numpy.histogram2d(x, y, bins=10)` — two-dimensional histogram. + Histogram2d, + /// `numpy.histogram_bin_edges(a, bins=10)` — one-dimensional histogram edges. + HistogramBinEdges, + /// `numpy.histogramdd(sample, bins=10)` — multi-dimensional histogram. + Histogramdd, + /// `numpy.ptp(a)` — peak-to-peak (max - min). + Ptp, + /// `numpy.cumprod(a)` — cumulative product. + Cumprod, + /// `numpy.nancumsum(a)` — cumulative sum ignoring NaN. + Nancumsum, + /// `numpy.nancumprod(a)` — cumulative product ignoring NaN. + Nancumprod, + + // --- Phase 5: Logical and testing functions --- + /// `numpy.logical_and(a, b)` — element-wise logical AND. + LogicalAnd, + /// `numpy.logical_or(a, b)` — element-wise logical OR. + LogicalOr, + /// `numpy.logical_not(a)` — element-wise logical NOT. + LogicalNot, + /// `numpy.logical_xor(a, b)` — element-wise logical XOR. + LogicalXor, + /// `numpy.allclose(a, b)` — true if all elements are close. + Allclose, + /// `numpy.isclose(a, b)` — element-wise closeness test. + Isclose, + /// `numpy.isin(element, test_elements)` — element membership test. + Isin, + + // --- Phase 6: Manipulation and shape --- + /// `numpy.flip(a)` — reverse array elements. + Flip, + /// `numpy.fliplr(a)` — flip left-right (2D). + Fliplr, + /// `numpy.flipud(a)` — flip up-down (2D). + Flipud, + /// `numpy.roll(a, shift)` — roll elements along axis. + Roll, + /// `numpy.expand_dims(a, axis)` — add axis. + ExpandDims, + /// `numpy.squeeze(a)` — remove length-1 axes. + Squeeze, + /// `numpy.ravel(a)` — flatten to 1D (module-level). + Ravel, + /// `numpy.delete(arr, indices)` — delete elements. + Delete, + /// `numpy.insert(arr, index, values)` — insert values. + Insert, + /// `numpy.diag(v)` — extract diagonal or create diagonal matrix. + Diag, + /// `numpy.diagflat(v, k=0)` — create a diagonal matrix from flattened input. + Diagflat, + /// `numpy.fill_diagonal(a, val)` — fill an array diagonal in place. + FillDiagonal, + /// `numpy.put(a, ind, v)` — assign flattened positions in place. + Put, + /// `numpy.put_along_axis(a, indices, values, axis)` — assign positions along an axis. + PutAlongAxis, + /// `numpy.copyto(dst, src)` — copy values into an array in place. + Copyto, + /// `numpy.putmask(a, mask, values)` — assign positions where a mask is true. + Putmask, + /// `numpy.place(a, mask, values)` — place values sequentially where a mask is true. + Place, + /// `numpy.diagonal(a)` — return diagonal of array. + Diagonal, + /// `numpy.trace(a)` — sum of diagonal elements. + Trace, + /// `numpy.flatnonzero(a)` — non-zero indices in flattened array. + Flatnonzero, + /// `numpy.asarray(a)` — convert to array without copy if possible. + Asarray, + /// `numpy.asarray_chkfinite(a)` — convert to array and reject NaN/Inf values. + AsarrayChkfinite, + /// `numpy.ascontiguousarray(a)` — Monty ndarray conversion with C-order semantics. + Ascontiguousarray, + /// `numpy.asfortranarray(a)` — Monty ndarray conversion with Fortran-order compatibility. + Asfortranarray, + /// `numpy.require(a)` — Monty ndarray conversion ignoring unsupported layout flags. + Require, + /// `numpy.ix_(*args)` — construct open mesh index arrays from 1-D sequences. + Ix, + /// `numpy.mask_indices(n, mask_func, k=0)` — indices selected by triangular masks. + MaskIndices, + /// `numpy.isfortran(a)` — true for Fortran-contiguous arrays. + Isfortran, + /// `numpy.may_share_memory(a, b)` — conservative overlap predicate. + MayShareMemory, + /// `numpy.shares_memory(a, b)` — exact overlap predicate for Monty's ndarray refs. + SharesMemory, + /// `numpy.column_stack(arrays)` — stack 1D arrays as columns. + ColumnStack, + /// `numpy.row_stack(arrays)` — alias for vstack. + RowStack, + /// `numpy.hsplit(a, n)` — horizontal split. + Hsplit, + /// `numpy.vsplit(a, n)` — vertical split. + Vsplit, + /// `numpy.dsplit(a, n)` — depth split. + Dsplit, + /// `numpy.array_split(a, n)` — split into possibly unequal parts. + ArraySplit, + /// `numpy.full_like(a, fill_value)` — array of same shape filled with value. + FullLike, + /// `numpy.empty_like(a)` — uninitialized array of same shape. + EmptyLike, + + // --- Phase 7: Sorting, searching, set operations --- + /// `numpy.argsort(a)` — module-level argsort. + ArgsortMod, + /// `numpy.argpartition(a, kth)` — indirect partition indices for 1-D arrays. + Argpartition, + /// `numpy.partition(a, kth)` — partition values for 1-D arrays. + Partition, + /// `numpy.lexsort(keys)` — indirect stable sort over 1-D key arrays. + Lexsort, + /// `numpy.cov(m)` — covariance matrix for 1-D or row-wise 2-D input. + Cov, + /// `numpy.corrcoef(x)` — correlation matrix for 1-D or row-wise 2-D input. + Corrcoef, + /// `numpy.searchsorted(a, v)` — find insertion points. + Searchsorted, + /// `numpy.extract(condition, arr)` — extract elements by condition. + Extract, + /// `numpy.trim_zeros(filt, trim='fb')` — trim leading and/or trailing zeros. + TrimZeros, + /// `numpy.unwrap(p, discont=None)` — unwrap phase jumps in a 1-D sequence. + Unwrap, + /// `numpy.intersect1d(a, b)` — sorted unique intersection. + Intersect1d, + /// `numpy.union1d(a, b)` — sorted unique union. + Union1d, + /// `numpy.setdiff1d(a, b)` — elements in a not in b. + Setdiff1d, + /// `numpy.setxor1d(a, b)` — elements in either but not both. + Setxor1d, + /// `numpy.bincount(a)` — count occurrences of each non-negative int. + Bincount, + /// `numpy.digitize(x, bins)` — indices of bins. + Digitize, + + // --- Phase 8: Linear algebra --- + /// `numpy.matmul(a, b)` — matrix multiplication. + Matmul, + /// `numpy.inner(a, b)` — inner product. + Inner, + /// `numpy.outer(a, b)` — outer product. + Outer, + /// `numpy.vdot(a, b)` — vector dot product (flattens first). + Vdot, + /// `numpy.vecdot(a, b)` — vector dot product. + Vecdot, + /// `numpy.matvec(a, x)` — matrix-vector multiplication. + Matvec, + /// `numpy.vecmat(x, a)` — vector-matrix multiplication. + Vecmat, + /// `numpy.cross(a, b)` — cross product (3-element vectors). + Cross, + /// `numpy.kron(a, b)` — Kronecker product. + Kron, + /// `numpy.tensordot(a, b, axes=2)` — generalized real-valued tensor contraction. + Tensordot, + /// `numpy.einsum(subscripts, *operands)` — explicit-subscript real-valued contraction. + Einsum, + /// `numpy.einsum_path(subscripts, *operands)` — simple path for supported contractions. + EinsumPath, + /// `numpy.trapezoid(y, x=None, dx=1.0)` — composite trapezoidal integral. + Trapezoid, + /// `numpy.vander(x, N=None, increasing=False)` — Vandermonde matrix. + Vander, + /// `numpy.poly(seq_of_zeros)` — construct coefficients from real roots. + Poly, + /// `numpy.polyadd(a, b)` — add polynomial coefficient arrays. + Polyadd, + /// `numpy.polysub(a, b)` — subtract polynomial coefficient arrays. + Polysub, + /// `numpy.polymul(a, b)` — multiply polynomial coefficient arrays. + Polymul, + /// `numpy.polydiv(u, v)` — divide polynomial coefficient arrays. + Polydiv, + /// `numpy.polyint(p, m=1)` — integrate polynomial coefficients. + Polyint, + /// `numpy.polyder(p, m=1)` — differentiate polynomial coefficients. + Polyder, + /// `numpy.polyval(p, x)` — evaluate polynomial coefficients. + Polyval, + + // --- Phase 10: Additional creation functions --- + /// `numpy.logspace(start, stop, num)` — log-spaced values. + Logspace, + /// `numpy.geomspace(start, stop, num)` — geometrically spaced values. + Geomspace, + /// `numpy.tri(N)` — triangular array. + Tri, + /// `numpy.tril(m)` — lower triangle. + Tril, + /// `numpy.triu(m)` — upper triangle. + Triu, + /// `numpy.identity(n)` — identity matrix (alias for eye). + Identity, + /// `numpy.meshgrid(*xi)` — coordinate matrices from vectors. + Meshgrid, + /// `numpy.gradient(f)` — numerical gradient. + Gradient, + /// `numpy.convolve(a, v)` — discrete linear convolution. + Convolve, + /// `numpy.correlate(a, v)` — cross-correlation. + Correlate, + /// `numpy.interp(x, xp, fp)` — 1D linear interpolation. + Interp, + /// `numpy.select(condlist, choicelist)` — conditional selection. + Select, +} + +/// Creates the `numpy` module and allocates it on the heap. +/// +/// Registers all numpy functions as module attributes. +pub fn create_module(vm: &mut VM<'_, impl ResourceTracker>) -> Result { + let mut module = Module::new(StaticStrings::Numpy); + + for (name, func) in NUMPY_FUNCTIONS { + module.set_attr(*name, Value::ModuleFunction(ModuleFunctions::Numpy(*func)), vm); + } + + // Module-level constants + module.set_attr(StaticStrings::Pi, Value::Float(PI), vm); + module.set_attr(StaticStrings::MathE, Value::Float(E), vm); + module.set_attr(StaticStrings::MathInf, Value::Float(f64::INFINITY), vm); + module.set_attr(StaticStrings::MathNan, Value::Float(f64::NAN), vm); + module.set_attr(StaticStrings::Newaxis, Value::None, vm); + module.set_attr( + StaticStrings::NpLittleEndian, + Value::Bool(cfg!(target_endian = "little")), + vm, + ); + module.set_attr(StaticStrings::NpEulerGamma, Value::Float(0.577_215_664_901_532_9), vm); + module.set_attr(StaticStrings::NpFalseScalar, Value::Bool(false), vm); + module.set_attr(StaticStrings::NpTrueScalar, Value::Bool(true), vm); + module.set_attr( + StaticStrings::NpNdarray, + Value::Builtin(Builtins::Type(Type::NdArray)), + vm, + ); + module.set_attr( + StaticStrings::NpFlatiter, + Value::Builtin(Builtins::Type(Type::FlatIter)), + vm, + ); + module.set_attr(StaticStrings::NpUfunc, Value::Builtin(Builtins::Type(Type::Ufunc)), vm); + + // Dtype type objects — stored as interned strings that astype() recognizes. + // These allow `arr.astype(np.float64)` to work alongside `arr.astype('float64')`. + for (name, target) in NUMPY_DTYPE_ALIASES { + module.set_attr(*name, Value::InternString((*target).into()), vm); + } + for (name, target) in NUMPY_MARKER_ONLY_DTYPE_ALIASES { + module.set_attr(*name, Value::InternString((*target).into()), vm); + } + module.set_attr(StaticStrings::NpScalarType, numpy_scalar_type_tuple(vm)?, vm); + module.set_attr( + StaticStrings::NpInteger, + Value::InternString(StaticStrings::NpIntegerCategoryMarker.into()), + vm, + ); + module.set_attr( + StaticStrings::NpFloating, + Value::InternString(StaticStrings::NpFloatingCategoryMarker.into()), + vm, + ); + module.set_attr( + StaticStrings::NpInexact, + Value::InternString(StaticStrings::NpInexactCategoryMarker.into()), + vm, + ); + module.set_attr( + StaticStrings::NpGeneric, + Value::InternString(StaticStrings::NpGenericCategoryMarker.into()), + vm, + ); + module.set_attr( + StaticStrings::NpNumber, + Value::InternString(StaticStrings::NpNumberCategoryMarker.into()), + vm, + ); + module.set_attr( + StaticStrings::NpSignedInteger, + Value::InternString(StaticStrings::NpSignedIntegerCategoryMarker.into()), + vm, + ); + module.set_attr( + StaticStrings::NpUnsignedInteger, + Value::InternString(StaticStrings::NpUnsignedIntegerCategoryMarker.into()), + vm, + ); + module.set_attr( + StaticStrings::NpComplexFloating, + Value::InternString(StaticStrings::NpComplexFloatingCategoryMarker.into()), + vm, + ); + module.set_attr( + StaticStrings::NpFlexible, + Value::InternString(StaticStrings::NpFlexibleCategoryMarker.into()), + vm, + ); + module.set_attr( + StaticStrings::NpCharacter, + Value::InternString(StaticStrings::NpCharacterCategoryMarker.into()), + vm, + ); + for name in [ + StaticStrings::NpIndexExp, + StaticStrings::NpSIndex, + StaticStrings::NpMgrid, + StaticStrings::NpOgrid, + StaticStrings::NpRIndex, + StaticStrings::NpCIndex, + ] { + module.set_attr(name, Value::Marker(Marker(name)), vm); + } + module.set_attr(StaticStrings::NpTypecodes, numpy_typecodes_dict(vm)?, vm); + module.set_attr(StaticStrings::NpSctypeDict, numpy_sctype_dict(vm)?, vm); + + vm.heap.allocate(HeapData::Module(module)) +} + +impl NumpyFunctions { + /// Returns whether this module function behaves like a NumPy ufunc. + /// + /// Monty does not expose a separate callable object wrapper for ufuncs; the + /// implemented ufunc surface is represented by the real module functions + /// that already perform elementwise broadcasting and dtype-aware scalar + /// conversion. This predicate backs `isinstance(np.add, np.ufunc)` without + /// claiming non-ufunc helpers such as `np.array` have ufunc semantics. + pub(crate) const fn is_ufunc_like(self) -> bool { + matches!( + self, + Self::Add + | Self::Subtract + | Self::Multiply + | Self::Divide + | Self::FloorDivide + | Self::Mod + | Self::Equal + | Self::NotEqual + | Self::Greater + | Self::GreaterEqual + | Self::Less + | Self::LessEqual + | Self::Abs + | Self::Sqrt + | Self::Log + | Self::Exp + | Self::Maximum + | Self::Minimum + | Self::Ceil + | Self::Floor + | Self::Log10 + | Self::Sin + | Self::Cos + | Self::Tan + | Self::Log2 + | Self::Power + | Self::Isnan + | Self::Isinf + | Self::Isposinf + | Self::Isneginf + | Self::Isfinite + | Self::Arcsin + | Self::Arccos + | Self::Arctan + | Self::Arctan2 + | Self::Sinh + | Self::Cosh + | Self::Tanh + | Self::Arcsinh + | Self::Arccosh + | Self::Arctanh + | Self::Sign + | Self::Square + | Self::Cbrt + | Self::Reciprocal + | Self::Log1p + | Self::Exp2 + | Self::Expm1 + | Self::Deg2rad + | Self::Rad2deg + | Self::Hypot + | Self::Fmin + | Self::Fmax + | Self::Fmod + | Self::Rint + | Self::Fabs + | Self::Positive + | Self::Negative + | Self::Copysign + | Self::Frexp + | Self::Modf + | Self::Ldexp + | Self::Gcd + | Self::Lcm + | Self::Logaddexp + | Self::Logaddexp2 + | Self::Nextafter + | Self::Spacing + | Self::Signbit + | Self::Heaviside + | Self::Trunc + | Self::Fix + | Self::FloatPower + | Self::Divmod + | Self::BitwiseAnd + | Self::BitwiseOr + | Self::BitwiseXor + | Self::BitwiseNot + | Self::LeftShift + | Self::RightShift + | Self::BitwiseCount + | Self::LogicalAnd + | Self::LogicalOr + | Self::LogicalNot + | Self::LogicalXor + ) + } +} + +/// Handles subscription on NumPy compatibility marker objects such as `np.s_`. +/// +/// NumPy exposes several index-trick attributes as stateful Python objects with +/// custom `__getitem__` methods. Monty represents those attributes as immediate +/// marker values so we can support the useful subscription behavior without +/// adding a new heap object kind or touching the heap safety boundary. +pub(crate) fn numpy_marker_getitem( + marker: StaticStrings, + key: &Value, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult> { + let value = match marker { + StaticStrings::NpIndexExp | StaticStrings::NpSIndex => key.clone_with_heap(vm), + StaticStrings::NpMgrid => numpy_mgrid_getitem(key, vm)?, + StaticStrings::NpOgrid => numpy_ogrid_getitem(key, vm)?, + StaticStrings::NpRIndex => numpy_r_index_getitem(key, vm)?, + StaticStrings::NpCIndex => numpy_c_index_getitem(key, vm)?, + _ => return Ok(None), + }; + Ok(Some(value)) +} + +/// Owned numeric item parsed from an index-trick subscription key. +struct IndexTrickInput { + /// Flattened numeric data represented by this key item. + data: Vec, + /// Shape associated with `data`, using Monty's row-major ndarray layout. + shape: Vec, + /// Numeric dtype inferred for this item. + dtype: NdArrayDtype, + /// Whether this item came from slice syntax such as `0:3`. + is_slice_range: bool, +} + +/// `numpy.mgrid[...]` — dense coordinate grids from numeric slice syntax. +fn numpy_mgrid_getitem(key: &Value, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let ranges = collect_grid_ranges(key, "numpy.mgrid", vm)?; + if ranges.len() == 1 { + let range = ranges.into_iter().next().expect("one range"); + return allocate_ndarray_from_data(range.data, range.shape, range.dtype, vm); + } + + let axis_lengths: Vec = ranges.iter().map(|range| range.data.len()).collect(); + let grid_len = checked_shape_product(&axis_lengths, "numpy.mgrid")?; + let ndim = ranges.len(); + let total = grid_len + .checked_mul(ndim) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.mgrid() result is too large"))?; + check_array_alloc_size(total, vm.heap.tracker())?; + + let mut shape = Vec::with_capacity(ndim + 1); + shape.push(ndim); + shape.extend(axis_lengths.iter().copied()); + + let strides = row_major_strides_for_shape(&axis_lengths); + let mut data = Vec::with_capacity(total); + for axis in 0..ndim { + let values = &ranges[axis].data; + for flat_index in 0..grid_len { + let coord = if axis_lengths[axis] == 0 { + 0 + } else { + (flat_index / strides[axis]) % axis_lengths[axis] + }; + data.push(values[coord]); + } + } + allocate_ndarray_from_data(data, shape, NdArrayDtype::Int64, vm) +} + +/// `numpy.ogrid[...]` — sparse/open coordinate grids from numeric slice syntax. +fn numpy_ogrid_getitem(key: &Value, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let ranges = collect_grid_ranges(key, "numpy.ogrid", vm)?; + if ranges.len() == 1 { + let range = ranges.into_iter().next().expect("one range"); + return allocate_ndarray_from_data(range.data, range.shape, range.dtype, vm); + } + + let ndim = ranges.len(); + let mut values: SmallVec<[Value; 3]> = SmallVec::new(); + for (axis, range) in ranges.into_iter().enumerate() { + check_array_alloc_size(range.data.len(), vm.heap.tracker())?; + let mut shape = vec![1; ndim]; + shape[axis] = range.data.len(); + values.push(allocate_ndarray_from_data(range.data, shape, range.dtype, vm)?); + } + allocate_tuple(values, vm.heap).map_err(Into::into) +} + +/// `numpy.r_[...]` — concatenate numeric slices/scalars/arrays into one vector. +fn numpy_r_index_getitem(key: &Value, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let inputs = collect_index_trick_inputs(key, "numpy.r_", vm)?; + let total = inputs + .iter() + .try_fold(0usize, |acc, input| acc.checked_add(input.data.len())) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.r_() result is too large"))?; + check_array_alloc_size(total, vm.heap.tracker())?; + + let dtype = inputs + .iter() + .fold(NdArrayDtype::Bool, |dtype, input| promote_dtype(dtype, input.dtype)); + let mut data = Vec::with_capacity(total); + for input in inputs { + data.extend(input.data); + } + allocate_ndarray_from_data(data, vec![total], dtype, vm) +} + +/// `numpy.c_[...]` — column-stack numeric slices/scalars/arrays. +fn numpy_c_index_getitem(key: &Value, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let inputs = collect_index_trick_inputs(key, "numpy.c_", vm)?; + let columns = inputs + .into_iter() + .map(column_block_from_input) + .collect::>>()?; + let Some(first) = columns.first() else { + return allocate_ndarray_from_data(Vec::new(), vec![0, 0], NdArrayDtype::Float64, vm); + }; + let rows = first.rows; + let total_cols = columns + .iter() + .try_fold(0usize, |acc, block| { + if block.rows == rows { + acc.checked_add(block.cols) + } else { + None + } + }) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.c_() input dimensions must match"))?; + let total = rows + .checked_mul(total_cols) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.c_() result is too large"))?; + check_array_alloc_size(total, vm.heap.tracker())?; + + let dtype = columns + .iter() + .fold(NdArrayDtype::Bool, |dtype, block| promote_dtype(dtype, block.dtype)); + let mut data = Vec::with_capacity(total); + for row in 0..rows { + for block in &columns { + let start = row * block.cols; + data.extend_from_slice(&block.data[start..start + block.cols]); + } + } + allocate_ndarray_from_data(data, vec![rows, total_cols], dtype, vm) +} + +/// A row-major block used by `numpy.c_` column stacking. +struct ColumnBlock { + /// Row-major data for this block. + data: Vec, + /// Number of rows in the block. + rows: usize, + /// Number of columns contributed by the block. + cols: usize, + /// Numeric dtype inferred for the block. + dtype: NdArrayDtype, +} + +/// Converts one parsed index-trick item into a 2-D block for `numpy.c_`. +fn column_block_from_input(input: IndexTrickInput) -> RunResult { + match input.shape.as_slice() { + [] => Ok(ColumnBlock { + data: input.data, + rows: 1, + cols: 1, + dtype: input.dtype, + }), + [rows] => Ok(ColumnBlock { + data: input.data, + rows: *rows, + cols: 1, + dtype: input.dtype, + }), + [rows, cols] => Ok(ColumnBlock { + data: input.data, + rows: *rows, + cols: *cols, + dtype: input.dtype, + }), + _ => Err(ExcType::type_error( + "numpy.c_() only supports scalar, 1-D, or 2-D inputs", + )), + } +} + +/// Collects the key items supplied to an index-trick subscription. +fn collect_index_trick_inputs( + key: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult> { + if let Value::Ref(id) = key + && let HeapData::Tuple(tuple) = vm.heap.get(*id) + { + return tuple + .as_slice() + .iter() + .map(|item| index_trick_input_from_value(item, name, vm)) + .collect(); + } + Ok(vec![index_trick_input_from_value(key, name, vm)?]) +} + +/// Collects only slice-derived ranges for `mgrid` and `ogrid`. +fn collect_grid_ranges(key: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult> { + let ranges = collect_index_trick_inputs(key, name, vm)?; + if ranges.iter().all(|range| range.is_slice_range) { + Ok(ranges) + } else { + Err(ExcType::type_error(format!("{name} requires integer slice ranges"))) + } +} + +/// Parses one item from a NumPy index-trick subscription key. +fn index_trick_input_from_value( + value: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + if let Value::Ref(id) = value + && let HeapData::Slice(slice) = vm.heap.get(*id) + { + return range_from_index_slice(slice, name, vm.heap.tracker()); + } + if let Ok(arr) = ndarray_from_value(value, name, vm) { + Ok(IndexTrickInput { + data: arr.data().to_vec(), + shape: arr.shape().to_vec(), + dtype: arr.dtype(), + is_slice_range: false, + }) + } else { + let (value, dtype) = numeric_scalar_info(value, name, vm)?; + Ok(IndexTrickInput { + data: vec![value], + shape: Vec::new(), + dtype, + is_slice_range: false, + }) + } +} + +/// Converts a Python slice into the integer range used by NumPy index tricks. +fn range_from_index_slice(slice: &Slice, name: &str, tracker: &impl ResourceTracker) -> RunResult { + let start = slice.start.unwrap_or(0); + let stop = slice + .stop + .ok_or_else(|| ExcType::type_error(format!("{name} slice stop is required")))?; + let step = slice.step.unwrap_or(1); + if step == 0 { + return Err(ExcType::value_error_slice_step_zero()); + } + let len = index_slice_len(start, stop, step)?; + check_array_alloc_size(len, tracker)?; + let mut data = Vec::with_capacity(len); + let mut value = start; + if step > 0 { + while value < stop { + data.push(i64_to_f64(value)); + value = value.saturating_add(step); + } + } else { + while value > stop { + data.push(i64_to_f64(value)); + value = value.saturating_add(step); + } + } + Ok(IndexTrickInput { + data, + shape: vec![len], + dtype: NdArrayDtype::Int64, + is_slice_range: true, + }) +} + +/// Computes the number of elements produced by an index-trick integer slice. +fn index_slice_len(start: i64, stop: i64, step: i64) -> RunResult { + let start = i128::from(start); + let stop = i128::from(stop); + let step = i128::from(step); + let len = if step > 0 { + if start >= stop { + 0 + } else { + ((stop - start - 1) / step) + 1 + } + } else if start <= stop { + 0 + } else { + ((start - stop - 1) / -step) + 1 + }; + usize::try_from(len).map_err(|_| SimpleException::new_msg(ExcType::ValueError, "index range is too large").into()) +} + +/// Returns row-major strides for coordinate generation. +fn row_major_strides_for_shape(shape: &[usize]) -> Vec { + let mut strides = vec![1usize; shape.len()]; + for axis in (0..shape.len()).rev() { + if axis + 1 < shape.len() { + strides[axis] = strides[axis + 1].saturating_mul(shape[axis + 1]); + } + } + strides +} + +/// Allocates an ndarray value from already-owned numeric data. +fn allocate_ndarray_from_data( + data: Vec, + shape: Vec, + dtype: NdArrayDtype, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + Ok(Value::Ref( + vm.heap.allocate(HeapData::NdArray(NdArray::new(data, shape, dtype)))?, + )) +} + +/// Builds NumPy's legacy `typecodes` dictionary for code that inspects dtype families. +fn numpy_typecodes_dict(vm: &mut VM<'_, impl ResourceTracker>) -> Result { + let pairs = [ + ("Character", "c"), + ("Integer", "bhilqnp"), + ("UnsignedInteger", "BHILQNP"), + ("Float", "efdg"), + ("Complex", "FDG"), + ("AllInteger", "bBhHiIlLqQnNpP"), + ("AllFloat", "efdgFDG"), + ("Datetime", "Mm"), + ("All", "?bhilqnpBHILQNPefdgFDGSUVOMm"), + ] + .into_iter() + .map(|(key, value)| { + let key = Value::Ref(vm.heap.allocate(HeapData::Str(Str::new(key.to_string())))?); + let value = Value::Ref(vm.heap.allocate(HeapData::Str(Str::new(value.to_string())))?); + Ok((key, value)) + }) + .collect::, ResourceError>>()?; + let dict = Dict::from_pairs(pairs, vm).expect("numpy.typecodes uses hashable string literal keys"); + Ok(Value::Ref(vm.heap.allocate(HeapData::Dict(dict))?)) +} + +/// Builds NumPy's legacy scalar-type dictionary for Monty's compact dtype aliases. +fn numpy_sctype_dict(vm: &mut VM<'_, impl ResourceTracker>) -> Result { + let pairs = NUMPY_SCTYPE_DICT + .iter() + .map(|(key, value)| { + let key = Value::Ref(vm.heap.allocate(HeapData::Str(Str::new((*key).to_string())))?); + let value = Value::InternString((*value).into()); + Ok((key, value)) + }) + .collect::, ResourceError>>()?; + let dict = Dict::from_pairs(pairs, vm).expect("numpy.sctypeDict uses hashable string literal keys"); + Ok(Value::Ref(vm.heap.allocate(HeapData::Dict(dict))?)) +} + +/// Builds NumPy's `ScalarType` tuple for scalar constructors Monty can expose safely. +/// +/// CPython NumPy includes every NumPy scalar class here. Monty only has real +/// Python scalar constructors plus metadata-only dtype markers, so this tuple +/// intentionally contains only actual callable type objects that `isinstance` +/// can evaluate without pretending marker strings are runtime scalar classes. +fn numpy_scalar_type_tuple(vm: &VM<'_, impl ResourceTracker>) -> Result { + allocate_tuple( + SmallVec::from_vec(vec![ + Value::Builtin(Builtins::Type(Type::Int)), + Value::Builtin(Builtins::Type(Type::Float)), + Value::Builtin(Builtins::Type(Type::Bool)), + Value::Builtin(Builtins::Type(Type::Bytes)), + Value::Builtin(Builtins::Type(Type::Str)), + ]), + vm.heap, + ) +} + +/// NumPy dtype attributes supported by Monty's compact numeric ndarray model. +/// +/// Many NumPy dtype names are aliases for platform-sized or narrower integer +/// and floating point types. Monty currently stores only bool, int64, and +/// float64 arrays, but narrow integer aliases still keep their public marker +/// names so hierarchy predicates such as `issubdtype(np.uint64, +/// np.unsignedinteger)` can preserve NumPy's type-family distinction. +const NUMPY_DTYPE_ALIASES: &[(StaticStrings, StaticStrings)] = &[ + (StaticStrings::NpFloat64, StaticStrings::NpFloat64), + (StaticStrings::NpDouble, StaticStrings::NpFloat64), + (StaticStrings::NpLongdouble, StaticStrings::NpFloat64), + (StaticStrings::NpFloat32, StaticStrings::NpFloat32), + (StaticStrings::NpFloat16, StaticStrings::NpFloat32), + (StaticStrings::NpHalf, StaticStrings::NpFloat32), + (StaticStrings::NpSingle, StaticStrings::NpFloat32), + (StaticStrings::NpInt64, StaticStrings::NpInt64), + (StaticStrings::NpInt_, StaticStrings::NpInt64), + (StaticStrings::NpIntp, StaticStrings::NpInt64), + (StaticStrings::NpLong, StaticStrings::NpInt64), + (StaticStrings::NpLonglong, StaticStrings::NpInt64), + (StaticStrings::NpByte, StaticStrings::NpByte), + (StaticStrings::NpShort, StaticStrings::NpShort), + (StaticStrings::NpInt8, StaticStrings::NpInt8), + (StaticStrings::NpInt16, StaticStrings::NpInt16), + (StaticStrings::NpUint, StaticStrings::NpUint), + (StaticStrings::NpUintp, StaticStrings::NpUintp), + (StaticStrings::NpUbyte, StaticStrings::NpUbyte), + (StaticStrings::NpUshort, StaticStrings::NpUshort), + (StaticStrings::NpUint8, StaticStrings::NpUint8), + (StaticStrings::NpUint16, StaticStrings::NpUint16), + (StaticStrings::NpUint32, StaticStrings::NpUint32), + (StaticStrings::NpUint64, StaticStrings::NpUint64), + (StaticStrings::NpUlong, StaticStrings::NpUlong), + (StaticStrings::NpUlonglong, StaticStrings::NpUlonglong), + (StaticStrings::NpInt32, StaticStrings::NpInt32), + (StaticStrings::NpIntc, StaticStrings::NpInt32), + (StaticStrings::NpUintc, StaticStrings::NpUintc), + (StaticStrings::NpBool_, StaticStrings::NpBool_), + (StaticStrings::NpBool, StaticStrings::NpBool_), +]; + +/// Metadata-only dtype attributes that do not imply ndarray storage support. +/// +/// These public names let dtype predicates and promotion helpers recognize +/// NumPy scalar families such as complex, string, object, and datetime. They +/// are deliberately kept out of [`NUMPY_DTYPE_ALIASES`] so constructors and +/// `astype()` continue to reject storage dtypes Monty does not implement. +const NUMPY_MARKER_ONLY_DTYPE_ALIASES: &[(StaticStrings, StaticStrings)] = &[ + (StaticStrings::NpComplex64, StaticStrings::NpComplex64), + (StaticStrings::NpComplex128, StaticStrings::NpComplex128), + (StaticStrings::NpCdouble, StaticStrings::NpComplex128), + (StaticStrings::NpCsingle, StaticStrings::NpComplex64), + (StaticStrings::NpClongdouble, StaticStrings::NpClongdouble), + (StaticStrings::NpStr_, StaticStrings::NpStr_), + (StaticStrings::NpBytes_, StaticStrings::NpBytes_), + (StaticStrings::NpVoid, StaticStrings::NpVoid), + (StaticStrings::NpObject_, StaticStrings::NpObject_), + (StaticStrings::NpDatetime64, StaticStrings::NpDatetime64), + (StaticStrings::NpTimedelta64, StaticStrings::NpTimedelta64), +]; + +/// Name-to-dtype aliases exposed through `numpy.sctypeDict`. +const NUMPY_SCTYPE_DICT: &[(&str, StaticStrings)] = &[ + ("bool", StaticStrings::NpBool_), + ("bool_", StaticStrings::NpBool_), + ("int", StaticStrings::NpInt64), + ("int_", StaticStrings::NpInt64), + ("int8", StaticStrings::NpInt64), + ("int16", StaticStrings::NpInt64), + ("int32", StaticStrings::NpInt32), + ("int64", StaticStrings::NpInt64), + ("uint", StaticStrings::NpInt64), + ("uint8", StaticStrings::NpInt64), + ("uint16", StaticStrings::NpInt64), + ("uint32", StaticStrings::NpInt64), + ("uint64", StaticStrings::NpInt64), + ("float", StaticStrings::NpFloat64), + ("float16", StaticStrings::NpFloat32), + ("float32", StaticStrings::NpFloat32), + ("float64", StaticStrings::NpFloat64), + ("half", StaticStrings::NpFloat32), + ("single", StaticStrings::NpFloat32), + ("double", StaticStrings::NpFloat64), + ("longdouble", StaticStrings::NpFloat64), + ("byte", StaticStrings::NpInt64), + ("short", StaticStrings::NpInt64), + ("long", StaticStrings::NpInt64), + ("longlong", StaticStrings::NpInt64), + ("ubyte", StaticStrings::NpInt64), + ("ushort", StaticStrings::NpInt64), + ("ulong", StaticStrings::NpInt64), + ("ulonglong", StaticStrings::NpInt64), + ("complex64", StaticStrings::NpComplex64), + ("complex128", StaticStrings::NpComplex128), + ("cdouble", StaticStrings::NpComplex128), + ("csingle", StaticStrings::NpComplex64), + ("clongdouble", StaticStrings::NpClongdouble), + ("str", StaticStrings::NpStr_), + ("str_", StaticStrings::NpStr_), + ("bytes", StaticStrings::NpBytes_), + ("bytes_", StaticStrings::NpBytes_), + ("void", StaticStrings::NpVoid), + ("object", StaticStrings::NpObject_), + ("object_", StaticStrings::NpObject_), + ("datetime64", StaticStrings::NpDatetime64), + ("timedelta64", StaticStrings::NpTimedelta64), +]; + +/// Static mapping of attribute names to numpy functions for module creation. +const NUMPY_FUNCTIONS: &[(StaticStrings, NumpyFunctions)] = &[ + (StaticStrings::NpArray, NumpyFunctions::Array), + (StaticStrings::NpArray2string, NumpyFunctions::Array2string), + (StaticStrings::NpArrayRepr, NumpyFunctions::ArrayRepr), + (StaticStrings::NpArrayStr, NumpyFunctions::ArrayStr), + (StaticStrings::NpFromfunction, NumpyFunctions::Fromfunction), + (StaticStrings::NpFromiter, NumpyFunctions::Fromiter), + (StaticStrings::NpFromstring, NumpyFunctions::Fromstring), + (StaticStrings::NpAsanyarray, NumpyFunctions::Asarray), + (StaticStrings::NpZeros, NumpyFunctions::Zeros), + (StaticStrings::NpOnes, NumpyFunctions::Ones), + (StaticStrings::Add, NumpyFunctions::Add), + (StaticStrings::NpSubtract, NumpyFunctions::Subtract), + (StaticStrings::NpMultiply, NumpyFunctions::Multiply), + (StaticStrings::NpDivide, NumpyFunctions::Divide), + (StaticStrings::NpTrueDivide, NumpyFunctions::Divide), // alias + (StaticStrings::NpFloorDivide, NumpyFunctions::FloorDivide), + (StaticStrings::NpMod, NumpyFunctions::Mod), + (StaticStrings::Remainder, NumpyFunctions::Mod), // alias + (StaticStrings::NpEqual, NumpyFunctions::Equal), + (StaticStrings::NpNotEqual, NumpyFunctions::NotEqual), + (StaticStrings::NpGreater, NumpyFunctions::Greater), + (StaticStrings::NpGreaterEqual, NumpyFunctions::GreaterEqual), + (StaticStrings::NpLess, NumpyFunctions::Less), + (StaticStrings::NpLessEqual, NumpyFunctions::LessEqual), + (StaticStrings::NpArange, NumpyFunctions::Arange), + (StaticStrings::NpLinspace, NumpyFunctions::Linspace), + (StaticStrings::NpSum, NumpyFunctions::Sum), + (StaticStrings::Mean, NumpyFunctions::Mean), + (StaticStrings::NpMin, NumpyFunctions::Min), + (StaticStrings::NpAmin, NumpyFunctions::Min), // alias + (StaticStrings::NpMax, NumpyFunctions::Max), + (StaticStrings::NpAmax, NumpyFunctions::Max), // alias + (StaticStrings::Abs, NumpyFunctions::Abs), + (StaticStrings::Absolute, NumpyFunctions::Abs), // alias + (StaticStrings::Sqrt, NumpyFunctions::Sqrt), + (StaticStrings::Log, NumpyFunctions::Log), + (StaticStrings::Exp, NumpyFunctions::Exp), + (StaticStrings::Round, NumpyFunctions::Round), + (StaticStrings::NpAround, NumpyFunctions::Round), // alias + (StaticStrings::Clip, NumpyFunctions::Clip), + (StaticStrings::NpWhere, NumpyFunctions::Where), + (StaticStrings::Maximum, NumpyFunctions::Maximum), + (StaticStrings::Minimum, NumpyFunctions::Minimum), + (StaticStrings::Sort, NumpyFunctions::Sort), + (StaticStrings::Unique, NumpyFunctions::Unique), + (StaticStrings::NpUniqueValues, NumpyFunctions::UniqueValues), + (StaticStrings::NpUniqueCounts, NumpyFunctions::UniqueCounts), + (StaticStrings::NpUniqueInverse, NumpyFunctions::UniqueInverse), + (StaticStrings::NpUniqueAll, NumpyFunctions::UniqueAll), + (StaticStrings::Concatenate, NumpyFunctions::Concatenate), + (StaticStrings::NpConcat, NumpyFunctions::Concatenate), // alias + (StaticStrings::Cumsum, NumpyFunctions::Cumsum), + (StaticStrings::NpCumulativeSum, NumpyFunctions::Cumsum), // alias + (StaticStrings::Dot, NumpyFunctions::Dot), + (StaticStrings::Ceil, NumpyFunctions::Ceil), + (StaticStrings::Floor, NumpyFunctions::Floor), + (StaticStrings::Log10, NumpyFunctions::Log10), + (StaticStrings::Std, NumpyFunctions::Std), + (StaticStrings::Sin, NumpyFunctions::Sin), + (StaticStrings::Cos, NumpyFunctions::Cos), + (StaticStrings::Tan, NumpyFunctions::Tan), + (StaticStrings::Log2, NumpyFunctions::Log2), + (StaticStrings::NpPower, NumpyFunctions::Power), + (StaticStrings::Pow, NumpyFunctions::Power), // alias + (StaticStrings::NpDiff, NumpyFunctions::Diff), + (StaticStrings::NpEdiff1d, NumpyFunctions::Ediff1d), + (StaticStrings::NpFull, NumpyFunctions::Full), + (StaticStrings::NpEye, NumpyFunctions::Eye), + (StaticStrings::Copy, NumpyFunctions::NpCopy), + (StaticStrings::NpEmpty, NumpyFunctions::Empty), + (StaticStrings::NpZerosLike, NumpyFunctions::ZerosLike), + (StaticStrings::NpOnesLike, NumpyFunctions::OnesLike), + (StaticStrings::Isnan, NumpyFunctions::Isnan), + (StaticStrings::Isinf, NumpyFunctions::Isinf), + (StaticStrings::NpIsposinf, NumpyFunctions::Isposinf), + (StaticStrings::NpIsneginf, NumpyFunctions::Isneginf), + (StaticStrings::Isfinite, NumpyFunctions::Isfinite), + (StaticStrings::NpArrayEqual, NumpyFunctions::ArrayEqual), + (StaticStrings::NpArrayEquiv, NumpyFunctions::ArrayEquiv), + (StaticStrings::NpCountNonzero, NumpyFunctions::CountNonzero), + (StaticStrings::NpAll, NumpyFunctions::All), + (StaticStrings::NpAny, NumpyFunctions::Any), + (StaticStrings::NpProd, NumpyFunctions::Prod), + (StaticStrings::NpVar, NumpyFunctions::Var), + (StaticStrings::NpMedian, NumpyFunctions::Median), + (StaticStrings::Argmin, NumpyFunctions::Argmin), + (StaticStrings::Argmax, NumpyFunctions::Argmax), + (StaticStrings::Reshape, NumpyFunctions::Reshape), + // np.flatten doesn't exist in real NumPy + (StaticStrings::NpTranspose, NumpyFunctions::Transpose), + (StaticStrings::NpTake, NumpyFunctions::Take), + (StaticStrings::NpTakeAlongAxis, NumpyFunctions::TakeAlongAxis), + (StaticStrings::NpResize, NumpyFunctions::Resize), + (StaticStrings::NpCompress, NumpyFunctions::Compress), + (StaticStrings::NpSwapaxes, NumpyFunctions::Swapaxes), + (StaticStrings::NpPermuteDims, NumpyFunctions::PermuteDims), + (StaticStrings::NpMatrixTranspose, NumpyFunctions::MatrixTranspose), + (StaticStrings::NpMoveaxis, NumpyFunctions::Moveaxis), + (StaticStrings::NpRollaxis, NumpyFunctions::Rollaxis), + (StaticStrings::NpRot90, NumpyFunctions::Rot90), + (StaticStrings::NpChoose, NumpyFunctions::Choose), + (StaticStrings::NpFillDiagonal, NumpyFunctions::FillDiagonal), + (StaticStrings::NpPut, NumpyFunctions::Put), + (StaticStrings::NpPutAlongAxis, NumpyFunctions::PutAlongAxis), + (StaticStrings::NpCopyto, NumpyFunctions::Copyto), + (StaticStrings::NpPutmask, NumpyFunctions::Putmask), + (StaticStrings::NpPlace, NumpyFunctions::Place), + (StaticStrings::Append, NumpyFunctions::Append), + (StaticStrings::NpVstack, NumpyFunctions::Vstack), + (StaticStrings::NpHstack, NumpyFunctions::Hstack), + (StaticStrings::NpDstack, NumpyFunctions::Dstack), + (StaticStrings::NpStack, NumpyFunctions::Stack), + (StaticStrings::NpBlock, NumpyFunctions::Block), + (StaticStrings::NpApplyAlongAxis, NumpyFunctions::ApplyAlongAxis), + (StaticStrings::NpApplyOverAxes, NumpyFunctions::ApplyOverAxes), + (StaticStrings::NpPiecewise, NumpyFunctions::Piecewise), + (StaticStrings::NpPad, NumpyFunctions::Pad), + (StaticStrings::NpUnstack, NumpyFunctions::Unstack), + (StaticStrings::NpNonzero, NumpyFunctions::Nonzero), + (StaticStrings::NpArgwhere, NumpyFunctions::Argwhere), + (StaticStrings::NpTile, NumpyFunctions::Tile), + (StaticStrings::NpRepeat, NumpyFunctions::Repeat), + (StaticStrings::Split, NumpyFunctions::Split), + (StaticStrings::NpShape, NumpyFunctions::Shape), + (StaticStrings::NpSize, NumpyFunctions::Size), + (StaticStrings::NpNdim, NumpyFunctions::Ndim), + (StaticStrings::NpBroadcastShapes, NumpyFunctions::BroadcastShapes), + (StaticStrings::NpBroadcastTo, NumpyFunctions::BroadcastTo), + (StaticStrings::NpBroadcastArrays, NumpyFunctions::BroadcastArrays), + (StaticStrings::NpBroadcast, NumpyFunctions::Broadcast), + (StaticStrings::Dtype, NumpyFunctions::Dtype), + (StaticStrings::NpAstype, NumpyFunctions::Astype), + ( + StaticStrings::NpFormatFloatPositional, + NumpyFunctions::FormatFloatPositional, + ), + ( + StaticStrings::NpFormatFloatScientific, + NumpyFunctions::FormatFloatScientific, + ), + // Phase 3: Inverse trig, hyperbolic, remaining math + (StaticStrings::NpArcsin, NumpyFunctions::Arcsin), + (StaticStrings::Asin, NumpyFunctions::Arcsin), // alias + (StaticStrings::NpArccos, NumpyFunctions::Arccos), + (StaticStrings::Acos, NumpyFunctions::Arccos), // alias + (StaticStrings::NpArctan, NumpyFunctions::Arctan), + (StaticStrings::Atan, NumpyFunctions::Arctan), // alias + (StaticStrings::NpArctan2, NumpyFunctions::Arctan2), + (StaticStrings::Atan2, NumpyFunctions::Arctan2), // alias + (StaticStrings::NpAngle, NumpyFunctions::Angle), + (StaticStrings::Sinh, NumpyFunctions::Sinh), + (StaticStrings::Cosh, NumpyFunctions::Cosh), + (StaticStrings::Tanh, NumpyFunctions::Tanh), + (StaticStrings::NpArcsinh, NumpyFunctions::Arcsinh), + (StaticStrings::Asinh, NumpyFunctions::Arcsinh), // alias + (StaticStrings::NpArccosh, NumpyFunctions::Arccosh), + (StaticStrings::Acosh, NumpyFunctions::Arccosh), // alias + (StaticStrings::NpArctanh, NumpyFunctions::Arctanh), + (StaticStrings::Atanh, NumpyFunctions::Arctanh), // alias + (StaticStrings::NpSign, NumpyFunctions::Sign), + (StaticStrings::NpSquare, NumpyFunctions::Square), + (StaticStrings::Cbrt, NumpyFunctions::Cbrt), + (StaticStrings::NpReciprocal, NumpyFunctions::Reciprocal), + (StaticStrings::Log1p, NumpyFunctions::Log1p), + (StaticStrings::Exp2, NumpyFunctions::Exp2), + (StaticStrings::Expm1, NumpyFunctions::Expm1), + (StaticStrings::NpDeg2rad, NumpyFunctions::Deg2rad), + (StaticStrings::NpRad2deg, NumpyFunctions::Rad2deg), + (StaticStrings::Degrees, NumpyFunctions::Rad2deg), // alias + (StaticStrings::Radians, NumpyFunctions::Deg2rad), // alias + (StaticStrings::NpHypot, NumpyFunctions::Hypot), + (StaticStrings::NpNanToNum, NumpyFunctions::NanToNum), + (StaticStrings::NpFmin, NumpyFunctions::Fmin), + (StaticStrings::NpFmax, NumpyFunctions::Fmax), + (StaticStrings::Fmod, NumpyFunctions::Fmod), + (StaticStrings::NpRint, NumpyFunctions::Rint), + (StaticStrings::Fabs, NumpyFunctions::Fabs), + (StaticStrings::NpPositive, NumpyFunctions::Positive), + (StaticStrings::NpNegative, NumpyFunctions::Negative), + (StaticStrings::Copysign, NumpyFunctions::Copysign), + (StaticStrings::Frexp, NumpyFunctions::Frexp), + (StaticStrings::Modf, NumpyFunctions::Modf), + (StaticStrings::Ldexp, NumpyFunctions::Ldexp), + (StaticStrings::Gcd, NumpyFunctions::Gcd), + (StaticStrings::Lcm, NumpyFunctions::Lcm), + (StaticStrings::NpLogaddexp, NumpyFunctions::Logaddexp), + (StaticStrings::NpLogaddexp2, NumpyFunctions::Logaddexp2), + (StaticStrings::Nextafter, NumpyFunctions::Nextafter), + (StaticStrings::NpSpacing, NumpyFunctions::Spacing), + (StaticStrings::NpSignbit, NumpyFunctions::Signbit), + (StaticStrings::NpSinc, NumpyFunctions::Sinc), + (StaticStrings::NpHeaviside, NumpyFunctions::Heaviside), + (StaticStrings::Trunc, NumpyFunctions::Trunc), + (StaticStrings::NpFix, NumpyFunctions::Fix), + (StaticStrings::NpFloatPower, NumpyFunctions::FloatPower), + (StaticStrings::NpDivmod, NumpyFunctions::Divmod), + (StaticStrings::NpBitwiseAnd, NumpyFunctions::BitwiseAnd), + (StaticStrings::NpBitwiseOr, NumpyFunctions::BitwiseOr), + (StaticStrings::NpBitwiseXor, NumpyFunctions::BitwiseXor), + (StaticStrings::NpBitwiseNot, NumpyFunctions::BitwiseNot), + (StaticStrings::NpBitwiseInvert, NumpyFunctions::BitwiseNot), // alias + (StaticStrings::NpInvert, NumpyFunctions::BitwiseNot), // alias + (StaticStrings::NpLeftShift, NumpyFunctions::LeftShift), + (StaticStrings::NpBitwiseLeftShift, NumpyFunctions::LeftShift), // alias + (StaticStrings::NpRightShift, NumpyFunctions::RightShift), + (StaticStrings::NpBitwiseRightShift, NumpyFunctions::RightShift), // alias + (StaticStrings::NpBitwiseCount, NumpyFunctions::BitwiseCount), + (StaticStrings::NpPackbits, NumpyFunctions::Packbits), + (StaticStrings::NpUnpackbits, NumpyFunctions::Unpackbits), + (StaticStrings::NpBartlett, NumpyFunctions::Bartlett), + (StaticStrings::NpBlackman, NumpyFunctions::Blackman), + (StaticStrings::NpHamming, NumpyFunctions::Hamming), + (StaticStrings::NpHanning, NumpyFunctions::Hanning), + (StaticStrings::NpKaiser, NumpyFunctions::Kaiser), + (StaticStrings::NpI0, NumpyFunctions::I0), + (StaticStrings::NpBaseRepr, NumpyFunctions::BaseRepr), + (StaticStrings::NpBinaryRepr, NumpyFunctions::BinaryRepr), + // Real-only aliases and introspection helpers + (StaticStrings::NpConj, NumpyFunctions::Conj), + (StaticStrings::NpConjugate, NumpyFunctions::Conj), // alias + (StaticStrings::NpReal, NumpyFunctions::Real), + (StaticStrings::NpRealIfClose, NumpyFunctions::RealIfClose), + (StaticStrings::NpImag, NumpyFunctions::Imag), + (StaticStrings::NpIsreal, NumpyFunctions::Isreal), + (StaticStrings::NpIsrealobj, NumpyFunctions::Isrealobj), + (StaticStrings::NpIscomplex, NumpyFunctions::Iscomplex), + (StaticStrings::NpIscomplexobj, NumpyFunctions::Iscomplexobj), + (StaticStrings::NpIsscalar, NumpyFunctions::Isscalar), + (StaticStrings::NpIterable, NumpyFunctions::Iterable), + (StaticStrings::NpCanCast, NumpyFunctions::CanCast), + (StaticStrings::NpPromoteTypes, NumpyFunctions::PromoteTypes), + (StaticStrings::NpResultType, NumpyFunctions::ResultType), + (StaticStrings::NpCommonType, NumpyFunctions::CommonType), + (StaticStrings::NpMinScalarType, NumpyFunctions::MinScalarType), + (StaticStrings::NpMintypecode, NumpyFunctions::Mintypecode), + (StaticStrings::NpTypename, NumpyFunctions::Typename), + (StaticStrings::NpInfo, NumpyFunctions::Info), + (StaticStrings::NpIssubdtype, NumpyFunctions::Issubdtype), + (StaticStrings::NpIsdtype, NumpyFunctions::Isdtype), + (StaticStrings::NpFinfo, NumpyFunctions::Finfo), + (StaticStrings::NpIinfo, NumpyFunctions::Iinfo), + (StaticStrings::NpGeterr, NumpyFunctions::Geterr), + (StaticStrings::NpSeterr, NumpyFunctions::Seterr), + (StaticStrings::NpGeterrcall, NumpyFunctions::Geterrcall), + (StaticStrings::NpSeterrcall, NumpyFunctions::Seterrcall), + (StaticStrings::NpErrstate, NumpyFunctions::Errstate), + (StaticStrings::NpGetPrintoptions, NumpyFunctions::GetPrintoptions), + (StaticStrings::NpSetPrintoptions, NumpyFunctions::SetPrintoptions), + (StaticStrings::NpPrintoptions, NumpyFunctions::Printoptions), + (StaticStrings::NpGetbufsize, NumpyFunctions::Getbufsize), + (StaticStrings::NpSetbufsize, NumpyFunctions::Setbufsize), + (StaticStrings::NpShowRuntime, NumpyFunctions::ShowRuntime), + (StaticStrings::NpTest, NumpyFunctions::Test), + (StaticStrings::NpAtleast1d, NumpyFunctions::Atleast1d), + (StaticStrings::NpAtleast2d, NumpyFunctions::Atleast2d), + (StaticStrings::NpAtleast3d, NumpyFunctions::Atleast3d), + (StaticStrings::NpDiagIndices, NumpyFunctions::DiagIndices), + (StaticStrings::NpDiagIndicesFrom, NumpyFunctions::DiagIndicesFrom), + (StaticStrings::NpTrilIndices, NumpyFunctions::TrilIndices), + (StaticStrings::NpTrilIndicesFrom, NumpyFunctions::TrilIndicesFrom), + (StaticStrings::NpTriuIndices, NumpyFunctions::TriuIndices), + (StaticStrings::NpTriuIndicesFrom, NumpyFunctions::TriuIndicesFrom), + (StaticStrings::NpIndices, NumpyFunctions::Indices), + (StaticStrings::NpUnravelIndex, NumpyFunctions::UnravelIndex), + (StaticStrings::NpRavelMultiIndex, NumpyFunctions::RavelMultiIndex), + (StaticStrings::NpNdindex, NumpyFunctions::Ndindex), + (StaticStrings::NpNdenumerate, NumpyFunctions::Ndenumerate), + (StaticStrings::NpNditer, NumpyFunctions::Nditer), + // Phase 4: NaN-aware aggregations and statistics + (StaticStrings::NpNansum, NumpyFunctions::Nansum), + (StaticStrings::NpNanmean, NumpyFunctions::Nanmean), + (StaticStrings::NpNanmin, NumpyFunctions::Nanmin), + (StaticStrings::NpNanmax, NumpyFunctions::Nanmax), + (StaticStrings::NpNanstd, NumpyFunctions::Nanstd), + (StaticStrings::NpNanvar, NumpyFunctions::Nanvar), + (StaticStrings::NpNanprod, NumpyFunctions::Nanprod), + (StaticStrings::NpNanmedian, NumpyFunctions::Nanmedian), + (StaticStrings::NpNanargmin, NumpyFunctions::Nanargmin), + (StaticStrings::NpNanargmax, NumpyFunctions::Nanargmax), + (StaticStrings::NpAverage, NumpyFunctions::Average), + (StaticStrings::NpPercentile, NumpyFunctions::Percentile), + (StaticStrings::NpQuantile, NumpyFunctions::Quantile), + (StaticStrings::NpNanpercentile, NumpyFunctions::Nanpercentile), + (StaticStrings::NpNanquantile, NumpyFunctions::Nanquantile), + (StaticStrings::NpHistogram, NumpyFunctions::Histogram), + (StaticStrings::NpHistogram2d, NumpyFunctions::Histogram2d), + (StaticStrings::NpHistogramBinEdges, NumpyFunctions::HistogramBinEdges), + (StaticStrings::NpHistogramdd, NumpyFunctions::Histogramdd), + (StaticStrings::NpPtp, NumpyFunctions::Ptp), + (StaticStrings::NpCumprod, NumpyFunctions::Cumprod), + (StaticStrings::NpCumulativeProd, NumpyFunctions::Cumprod), // alias + (StaticStrings::NpNancumsum, NumpyFunctions::Nancumsum), + (StaticStrings::NpNancumprod, NumpyFunctions::Nancumprod), + // Phase 5: Logical and testing + (StaticStrings::NpLogicalAnd, NumpyFunctions::LogicalAnd), + (StaticStrings::NpLogicalOr, NumpyFunctions::LogicalOr), + (StaticStrings::NpLogicalNot, NumpyFunctions::LogicalNot), + (StaticStrings::NpLogicalXor, NumpyFunctions::LogicalXor), + (StaticStrings::NpAllclose, NumpyFunctions::Allclose), + (StaticStrings::Isclose, NumpyFunctions::Isclose), + (StaticStrings::NpIsin, NumpyFunctions::Isin), + // Phase 6: Manipulation and shape + (StaticStrings::NpFlip, NumpyFunctions::Flip), + (StaticStrings::NpFliplr, NumpyFunctions::Fliplr), + (StaticStrings::NpFlipud, NumpyFunctions::Flipud), + (StaticStrings::NpRoll, NumpyFunctions::Roll), + (StaticStrings::NpExpandDims, NumpyFunctions::ExpandDims), + (StaticStrings::NpSqueeze, NumpyFunctions::Squeeze), + (StaticStrings::NpRavel, NumpyFunctions::Ravel), + (StaticStrings::NpDelete, NumpyFunctions::Delete), + (StaticStrings::Insert, NumpyFunctions::Insert), + (StaticStrings::NpDiag, NumpyFunctions::Diag), + (StaticStrings::NpDiagflat, NumpyFunctions::Diagflat), + (StaticStrings::NpDiagonal, NumpyFunctions::Diagonal), + (StaticStrings::NpTrace, NumpyFunctions::Trace), + (StaticStrings::NpFlatnonzero, NumpyFunctions::Flatnonzero), + (StaticStrings::NpAsarray, NumpyFunctions::Asarray), + (StaticStrings::NpAsarrayChkfinite, NumpyFunctions::AsarrayChkfinite), + (StaticStrings::NpAscontiguousarray, NumpyFunctions::Ascontiguousarray), + (StaticStrings::NpAsfortranarray, NumpyFunctions::Asfortranarray), + (StaticStrings::NpRequire, NumpyFunctions::Require), + (StaticStrings::NpIx_, NumpyFunctions::Ix), + (StaticStrings::NpMaskIndices, NumpyFunctions::MaskIndices), + (StaticStrings::NpIsfortran, NumpyFunctions::Isfortran), + (StaticStrings::NpMayShareMemory, NumpyFunctions::MayShareMemory), + (StaticStrings::NpSharesMemory, NumpyFunctions::SharesMemory), + (StaticStrings::NpColumnStack, NumpyFunctions::ColumnStack), + (StaticStrings::NpRowStack, NumpyFunctions::RowStack), + (StaticStrings::NpHsplit, NumpyFunctions::Hsplit), + (StaticStrings::NpVsplit, NumpyFunctions::Vsplit), + (StaticStrings::NpDsplit, NumpyFunctions::Dsplit), + (StaticStrings::NpArraySplit, NumpyFunctions::ArraySplit), + (StaticStrings::NpFullLike, NumpyFunctions::FullLike), + (StaticStrings::NpEmptyLike, NumpyFunctions::EmptyLike), + // Phase 7: Sorting, searching, set ops + (StaticStrings::NpArgsort, NumpyFunctions::ArgsortMod), + (StaticStrings::NpArgpartition, NumpyFunctions::Argpartition), + (StaticStrings::Partition, NumpyFunctions::Partition), + (StaticStrings::NpLexsort, NumpyFunctions::Lexsort), + (StaticStrings::NpCov, NumpyFunctions::Cov), + (StaticStrings::NpCorrcoef, NumpyFunctions::Corrcoef), + (StaticStrings::NpSearchsorted, NumpyFunctions::Searchsorted), + (StaticStrings::NpExtract, NumpyFunctions::Extract), + (StaticStrings::NpTrimZeros, NumpyFunctions::TrimZeros), + (StaticStrings::NpUnwrap, NumpyFunctions::Unwrap), + (StaticStrings::NpIntersect1d, NumpyFunctions::Intersect1d), + (StaticStrings::NpUnion1d, NumpyFunctions::Union1d), + (StaticStrings::NpSetdiff1d, NumpyFunctions::Setdiff1d), + (StaticStrings::NpSetxor1d, NumpyFunctions::Setxor1d), + (StaticStrings::NpBincount, NumpyFunctions::Bincount), + (StaticStrings::NpDigitize, NumpyFunctions::Digitize), + // Phase 8: Linear algebra + (StaticStrings::NpMatmul, NumpyFunctions::Matmul), + (StaticStrings::NpInner, NumpyFunctions::Inner), + (StaticStrings::NpOuter, NumpyFunctions::Outer), + (StaticStrings::NpVdot, NumpyFunctions::Vdot), + (StaticStrings::NpVecdot, NumpyFunctions::Vecdot), + (StaticStrings::NpMatvec, NumpyFunctions::Matvec), + (StaticStrings::NpVecmat, NumpyFunctions::Vecmat), + (StaticStrings::NpCross, NumpyFunctions::Cross), + (StaticStrings::NpKron, NumpyFunctions::Kron), + (StaticStrings::NpTensordot, NumpyFunctions::Tensordot), + (StaticStrings::NpEinsum, NumpyFunctions::Einsum), + (StaticStrings::NpEinsumPath, NumpyFunctions::EinsumPath), + (StaticStrings::NpTrapezoid, NumpyFunctions::Trapezoid), + (StaticStrings::NpVander, NumpyFunctions::Vander), + (StaticStrings::NpPoly, NumpyFunctions::Poly), + (StaticStrings::NpPolyadd, NumpyFunctions::Polyadd), + (StaticStrings::NpPolysub, NumpyFunctions::Polysub), + (StaticStrings::NpPolymul, NumpyFunctions::Polymul), + (StaticStrings::NpPolydiv, NumpyFunctions::Polydiv), + (StaticStrings::NpPolyint, NumpyFunctions::Polyint), + (StaticStrings::NpPolyder, NumpyFunctions::Polyder), + (StaticStrings::NpPolyval, NumpyFunctions::Polyval), + // Phase 10: Additional creation and numerical + (StaticStrings::NpLogspace, NumpyFunctions::Logspace), + (StaticStrings::NpGeomspace, NumpyFunctions::Geomspace), + (StaticStrings::NpTri, NumpyFunctions::Tri), + (StaticStrings::NpTril, NumpyFunctions::Tril), + (StaticStrings::NpTriu, NumpyFunctions::Triu), + (StaticStrings::NpIdentity, NumpyFunctions::Identity), + (StaticStrings::NpMeshgrid, NumpyFunctions::Meshgrid), + (StaticStrings::NpGradient, NumpyFunctions::Gradient), + (StaticStrings::NpConvolve, NumpyFunctions::Convolve), + (StaticStrings::NpCorrelate, NumpyFunctions::Correlate), + (StaticStrings::NpInterp, NumpyFunctions::Interp), + (StaticStrings::NpSelect, NumpyFunctions::Select), +]; + +/// Dispatches a call to a `numpy` module function. +pub(super) fn call( + vm: &mut VM<'_, impl ResourceTracker>, + function: NumpyFunctions, + args: ArgValues, +) -> RunResult { + match function { + NumpyFunctions::Array => call_array(vm, args).map(CallResult::Value), + NumpyFunctions::Array2string => call_array2string(vm, args).map(CallResult::Value), + NumpyFunctions::ArrayRepr => call_array_repr(vm, args).map(CallResult::Value), + NumpyFunctions::ArrayStr => call_array_str(vm, args).map(CallResult::Value), + NumpyFunctions::Fromfunction => call_fromfunction(vm, args).map(CallResult::Value), + NumpyFunctions::Fromiter => call_fromiter(vm, args).map(CallResult::Value), + NumpyFunctions::Fromstring => call_fromstring(vm, args).map(CallResult::Value), + NumpyFunctions::Zeros => call_zeros(vm, args).map(CallResult::Value), + NumpyFunctions::Ones => call_ones(vm, args).map(CallResult::Value), + NumpyFunctions::Arange => call_arange(vm, args).map(CallResult::Value), + NumpyFunctions::Linspace => call_linspace(vm, args).map(CallResult::Value), + NumpyFunctions::Sum => call_aggregate(vm, args, NdArray::sum, "numpy.sum").map(CallResult::Value), + NumpyFunctions::Mean => call_aggregate(vm, args, NdArray::mean, "numpy.mean").map(CallResult::Value), + NumpyFunctions::Min => call_aggregate_result(vm, args, NdArray::min_val, "numpy.min").map(CallResult::Value), + NumpyFunctions::Max => call_aggregate_result(vm, args, NdArray::max_val, "numpy.max").map(CallResult::Value), + NumpyFunctions::Abs => call_elementwise(vm, args, f64::abs, "numpy.abs", None).map(CallResult::Value), + NumpyFunctions::Sqrt => { + call_elementwise(vm, args, f64::sqrt, "numpy.sqrt", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Log => { + call_elementwise(vm, args, f64::ln, "numpy.log", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Exp => { + call_elementwise(vm, args, f64::exp, "numpy.exp", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Ceil => { + call_elementwise(vm, args, f64::ceil, "numpy.ceil", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Floor => { + call_elementwise(vm, args, f64::floor, "numpy.floor", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Log10 => { + call_elementwise(vm, args, f64::log10, "numpy.log10", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Round => call_round(vm, args).map(CallResult::Value), + NumpyFunctions::Clip => call_clip(vm, args).map(CallResult::Value), + NumpyFunctions::Where => call_where(vm, args).map(CallResult::Value), + NumpyFunctions::Maximum => call_pairwise(vm, args, f64::max, "numpy.maximum").map(CallResult::Value), + NumpyFunctions::Minimum => call_pairwise(vm, args, f64::min, "numpy.minimum").map(CallResult::Value), + NumpyFunctions::Sort => call_sort(vm, args).map(CallResult::Value), + NumpyFunctions::Unique => call_unique(vm, args).map(CallResult::Value), + NumpyFunctions::UniqueValues => call_unique_result(vm, args, UniqueResultKind::Values).map(CallResult::Value), + NumpyFunctions::UniqueCounts => call_unique_result(vm, args, UniqueResultKind::Counts).map(CallResult::Value), + NumpyFunctions::UniqueInverse => call_unique_result(vm, args, UniqueResultKind::Inverse).map(CallResult::Value), + NumpyFunctions::UniqueAll => call_unique_result(vm, args, UniqueResultKind::All).map(CallResult::Value), + NumpyFunctions::Concatenate => call_concatenate(vm, args).map(CallResult::Value), + NumpyFunctions::Cumsum => call_cumsum(vm, args).map(CallResult::Value), + NumpyFunctions::Dot => call_dot(vm, args).map(CallResult::Value), + NumpyFunctions::Std => call_aggregate(vm, args, NdArray::std_dev, "numpy.std").map(CallResult::Value), + NumpyFunctions::Sin => { + call_elementwise(vm, args, f64::sin, "numpy.sin", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Cos => { + call_elementwise(vm, args, f64::cos, "numpy.cos", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Tan => { + call_elementwise(vm, args, f64::tan, "numpy.tan", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Log2 => { + call_elementwise(vm, args, f64::log2, "numpy.log2", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Power => call_power(vm, args).map(CallResult::Value), + NumpyFunctions::Diff => call_diff(vm, args).map(CallResult::Value), + NumpyFunctions::Ediff1d => call_ediff1d(vm, args).map(CallResult::Value), + NumpyFunctions::Full => call_full(vm, args).map(CallResult::Value), + NumpyFunctions::Eye => call_eye(vm, args).map(CallResult::Value), + NumpyFunctions::NpCopy => call_copy(vm, args).map(CallResult::Value), + NumpyFunctions::Empty => call_empty(vm, args).map(CallResult::Value), + NumpyFunctions::ZerosLike => call_like(vm, args, 0.0, "numpy.zeros_like").map(CallResult::Value), + NumpyFunctions::OnesLike => call_like(vm, args, 1.0, "numpy.ones_like").map(CallResult::Value), + NumpyFunctions::Isnan => call_bool_test(vm, args, f64::is_nan, "numpy.isnan").map(CallResult::Value), + NumpyFunctions::Isinf => call_bool_test(vm, args, f64::is_infinite, "numpy.isinf").map(CallResult::Value), + NumpyFunctions::Isposinf => call_bool_test(vm, args, f64_is_pos_inf, "numpy.isposinf").map(CallResult::Value), + NumpyFunctions::Isneginf => call_bool_test(vm, args, f64_is_neg_inf, "numpy.isneginf").map(CallResult::Value), + NumpyFunctions::Isfinite => call_bool_test(vm, args, f64::is_finite, "numpy.isfinite").map(CallResult::Value), + NumpyFunctions::ArrayEqual => call_array_equal(vm, args).map(CallResult::Value), + NumpyFunctions::ArrayEquiv => call_array_equiv(vm, args).map(CallResult::Value), + NumpyFunctions::CountNonzero => call_count_nonzero(vm, args).map(CallResult::Value), + NumpyFunctions::All => call_all(vm, args).map(CallResult::Value), + NumpyFunctions::Any => call_any(vm, args).map(CallResult::Value), + NumpyFunctions::Prod => call_prod(vm, args).map(CallResult::Value), + NumpyFunctions::Var => call_aggregate(vm, args, NdArray::var, "numpy.var").map(CallResult::Value), + NumpyFunctions::Median => call_median(vm, args).map(CallResult::Value), + NumpyFunctions::Argmin => call_argmin_mod(vm, args).map(CallResult::Value), + NumpyFunctions::Argmax => call_argmax_mod(vm, args).map(CallResult::Value), + NumpyFunctions::Reshape => call_reshape_mod(vm, args).map(CallResult::Value), + // np.flatten doesn't exist in real NumPy + NumpyFunctions::Transpose => call_transpose_mod(vm, args).map(CallResult::Value), + NumpyFunctions::Take => call_take_mod(vm, args).map(CallResult::Value), + NumpyFunctions::TakeAlongAxis => call_take_along_axis(vm, args).map(CallResult::Value), + NumpyFunctions::Resize => call_resize(vm, args).map(CallResult::Value), + NumpyFunctions::Compress => call_compress_mod(vm, args).map(CallResult::Value), + NumpyFunctions::Swapaxes => call_swapaxes_mod(vm, args).map(CallResult::Value), + NumpyFunctions::PermuteDims => call_permute_dims(vm, args).map(CallResult::Value), + NumpyFunctions::MatrixTranspose => call_matrix_transpose(vm, args).map(CallResult::Value), + NumpyFunctions::Moveaxis => call_moveaxis(vm, args).map(CallResult::Value), + NumpyFunctions::Rollaxis => call_rollaxis(vm, args).map(CallResult::Value), + NumpyFunctions::Rot90 => call_rot90(vm, args).map(CallResult::Value), + NumpyFunctions::Choose => call_choose(vm, args).map(CallResult::Value), + NumpyFunctions::FillDiagonal => call_fill_diagonal(vm, args).map(CallResult::Value), + NumpyFunctions::Put => call_put(vm, args).map(CallResult::Value), + NumpyFunctions::PutAlongAxis => call_put_along_axis(vm, args).map(CallResult::Value), + NumpyFunctions::Copyto => call_copyto(vm, args).map(CallResult::Value), + NumpyFunctions::Putmask => call_putmask(vm, args).map(CallResult::Value), + NumpyFunctions::Place => call_place(vm, args).map(CallResult::Value), + NumpyFunctions::Append => call_append(vm, args).map(CallResult::Value), + NumpyFunctions::Vstack => call_vstack(vm, args).map(CallResult::Value), + NumpyFunctions::Hstack => call_hstack(vm, args).map(CallResult::Value), + NumpyFunctions::Dstack => call_dstack(vm, args).map(CallResult::Value), + // Note: np.stack with axis=0 is equivalent to np.vstack for 1D inputs. + // For 2D+ inputs, np.stack creates a new axis, which differs from vstack. + // We only support the 1D case which is the LLM-common pattern. + NumpyFunctions::Stack => call_vstack(vm, args).map(CallResult::Value), + NumpyFunctions::Block => call_block(vm, args).map(CallResult::Value), + NumpyFunctions::ApplyAlongAxis => call_apply_along_axis(vm, args).map(CallResult::Value), + NumpyFunctions::ApplyOverAxes => call_apply_over_axes(vm, args).map(CallResult::Value), + NumpyFunctions::Piecewise => call_piecewise(vm, args).map(CallResult::Value), + NumpyFunctions::Pad => call_pad(vm, args).map(CallResult::Value), + NumpyFunctions::Unstack => call_unstack(vm, args).map(CallResult::Value), + NumpyFunctions::Nonzero => call_nonzero(vm, args).map(CallResult::Value), + NumpyFunctions::Argwhere => call_argwhere(vm, args).map(CallResult::Value), + NumpyFunctions::Tile => call_tile(vm, args).map(CallResult::Value), + NumpyFunctions::Repeat => call_repeat(vm, args).map(CallResult::Value), + NumpyFunctions::Split => call_split(vm, args).map(CallResult::Value), + NumpyFunctions::Add => { + call_numeric_binop(vm, args, |a, b| a + b, "numpy.add", BinopResult::Promoted).map(CallResult::Value) + } + NumpyFunctions::Subtract => { + call_numeric_binop(vm, args, |a, b| a - b, "numpy.subtract", BinopResult::Promoted).map(CallResult::Value) + } + NumpyFunctions::Multiply => { + call_numeric_binop(vm, args, |a, b| a * b, "numpy.multiply", BinopResult::Promoted).map(CallResult::Value) + } + NumpyFunctions::Divide => { + call_numeric_binop(vm, args, |a, b| a / b, "numpy.divide", BinopResult::Float).map(CallResult::Value) + } + NumpyFunctions::FloorDivide => call_numeric_binop( + vm, + args, + |a, b| (a / b).floor(), + "numpy.floor_divide", + BinopResult::Promoted, + ) + .map(CallResult::Value), + NumpyFunctions::Mod => { + call_numeric_binop(vm, args, py_mod, "numpy.mod", BinopResult::Promoted).map(CallResult::Value) + } + NumpyFunctions::Equal => { + call_numeric_binop(vm, args, eq_to_f64, "numpy.equal", BinopResult::Bool).map(CallResult::Value) + } + NumpyFunctions::NotEqual => { + call_numeric_binop(vm, args, ne_to_f64, "numpy.not_equal", BinopResult::Bool).map(CallResult::Value) + } + NumpyFunctions::Greater => call_numeric_binop( + vm, + args, + |a, b| if a > b { 1.0 } else { 0.0 }, + "numpy.greater", + BinopResult::Bool, + ) + .map(CallResult::Value), + NumpyFunctions::GreaterEqual => call_numeric_binop( + vm, + args, + |a, b| if a >= b { 1.0 } else { 0.0 }, + "numpy.greater_equal", + BinopResult::Bool, + ) + .map(CallResult::Value), + NumpyFunctions::Less => call_numeric_binop( + vm, + args, + |a, b| if a < b { 1.0 } else { 0.0 }, + "numpy.less", + BinopResult::Bool, + ) + .map(CallResult::Value), + NumpyFunctions::LessEqual => call_numeric_binop( + vm, + args, + |a, b| if a <= b { 1.0 } else { 0.0 }, + "numpy.less_equal", + BinopResult::Bool, + ) + .map(CallResult::Value), + NumpyFunctions::Shape => call_shape(vm, args).map(CallResult::Value), + NumpyFunctions::Size => call_size(vm, args).map(CallResult::Value), + NumpyFunctions::Ndim => call_ndim(vm, args).map(CallResult::Value), + NumpyFunctions::BroadcastShapes => call_broadcast_shapes(vm, args).map(CallResult::Value), + NumpyFunctions::BroadcastTo => call_broadcast_to(vm, args).map(CallResult::Value), + NumpyFunctions::BroadcastArrays => call_broadcast_arrays(vm, args).map(CallResult::Value), + NumpyFunctions::Broadcast => call_broadcast(vm, args).map(CallResult::Value), + NumpyFunctions::Dtype => call_dtype(vm, args).map(CallResult::Value), + NumpyFunctions::Astype => call_astype(vm, args).map(CallResult::Value), + NumpyFunctions::FormatFloatPositional => call_format_float_positional(vm, args).map(CallResult::Value), + NumpyFunctions::FormatFloatScientific => call_format_float_scientific(vm, args).map(CallResult::Value), + // Phase 3: Inverse trig, hyperbolic, remaining math + NumpyFunctions::Arcsin => { + call_elementwise(vm, args, f64::asin, "numpy.arcsin", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Arccos => { + call_elementwise(vm, args, f64::acos, "numpy.arccos", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Arctan => { + call_elementwise(vm, args, f64::atan, "numpy.arctan", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Arctan2 => { + call_numeric_binop(vm, args, f64::atan2, "numpy.arctan2", BinopResult::Float).map(CallResult::Value) + } + NumpyFunctions::Angle => call_angle(vm, args).map(CallResult::Value), + NumpyFunctions::Sinh => { + call_elementwise(vm, args, f64::sinh, "numpy.sinh", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Cosh => { + call_elementwise(vm, args, f64::cosh, "numpy.cosh", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Tanh => { + call_elementwise(vm, args, f64::tanh, "numpy.tanh", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Arcsinh => { + call_elementwise(vm, args, f64::asinh, "numpy.arcsinh", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Arccosh => { + call_elementwise(vm, args, f64::acosh, "numpy.arccosh", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Arctanh => { + call_elementwise(vm, args, f64::atanh, "numpy.arctanh", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Sign => { + // numpy.sign returns 0.0 for 0.0, unlike Rust's signum which returns 1.0 + call_elementwise( + vm, + args, + |x| if x == 0.0 { 0.0 } else { x.signum() }, + "numpy.sign", + None, + ) + .map(CallResult::Value) + } + NumpyFunctions::Square => call_elementwise(vm, args, |x| x * x, "numpy.square", None).map(CallResult::Value), + NumpyFunctions::Cbrt => { + call_elementwise(vm, args, f64::cbrt, "numpy.cbrt", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Reciprocal => { + call_elementwise(vm, args, |x| 1.0 / x, "numpy.reciprocal", Some(NdArrayDtype::Float64)) + .map(CallResult::Value) + } + NumpyFunctions::Log1p => { + call_elementwise(vm, args, f64::ln_1p, "numpy.log1p", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Exp2 => { + call_elementwise(vm, args, f64::exp2, "numpy.exp2", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Expm1 => { + call_elementwise(vm, args, f64::exp_m1, "numpy.expm1", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Deg2rad => { + call_elementwise(vm, args, f64::to_radians, "numpy.deg2rad", Some(NdArrayDtype::Float64)) + .map(CallResult::Value) + } + NumpyFunctions::Rad2deg => { + call_elementwise(vm, args, f64::to_degrees, "numpy.rad2deg", Some(NdArrayDtype::Float64)) + .map(CallResult::Value) + } + NumpyFunctions::Hypot => call_pairwise(vm, args, f64::hypot, "numpy.hypot").map(CallResult::Value), + NumpyFunctions::NanToNum => call_nan_to_num(vm, args).map(CallResult::Value), + NumpyFunctions::Fmin => call_pairwise( + vm, + args, + |a, b| { + if a.is_nan() { + b + } else if b.is_nan() { + a + } else { + a.min(b) + } + }, + "numpy.fmin", + ) + .map(CallResult::Value), + NumpyFunctions::Fmax => call_pairwise( + vm, + args, + |a, b| { + if a.is_nan() { + b + } else if b.is_nan() { + a + } else { + a.max(b) + } + }, + "numpy.fmax", + ) + .map(CallResult::Value), + NumpyFunctions::Fmod => call_pairwise(vm, args, |a, b| a % b, "numpy.fmod").map(CallResult::Value), + NumpyFunctions::Rint => call_elementwise( + vm, + args, + f64::round_ties_even, + "numpy.rint", + Some(NdArrayDtype::Float64), + ) + .map(CallResult::Value), + NumpyFunctions::Fabs => { + call_elementwise(vm, args, f64::abs, "numpy.fabs", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Positive => call_elementwise(vm, args, |x| x, "numpy.positive", None).map(CallResult::Value), + NumpyFunctions::Negative => call_elementwise(vm, args, |x| -x, "numpy.negative", None).map(CallResult::Value), + NumpyFunctions::Copysign => { + call_numeric_binop(vm, args, f64::copysign, "numpy.copysign", BinopResult::Float).map(CallResult::Value) + } + NumpyFunctions::Frexp => call_unary_tuple_func( + vm, + args, + numpy_frexp, + "numpy.frexp", + NdArrayDtype::Float64, + NdArrayDtype::Int64, + ) + .map(CallResult::Value), + NumpyFunctions::Modf => call_unary_tuple_func( + vm, + args, + numpy_modf, + "numpy.modf", + NdArrayDtype::Float64, + NdArrayDtype::Float64, + ) + .map(CallResult::Value), + NumpyFunctions::Ldexp => call_ldexp(vm, args).map(CallResult::Value), + NumpyFunctions::Gcd => call_integer_binop(vm, args, numpy_gcd, "numpy.gcd").map(CallResult::Value), + NumpyFunctions::Lcm => call_integer_binop(vm, args, numpy_lcm, "numpy.lcm").map(CallResult::Value), + NumpyFunctions::Logaddexp => { + call_numeric_binop(vm, args, numpy_logaddexp, "numpy.logaddexp", BinopResult::Float).map(CallResult::Value) + } + NumpyFunctions::Logaddexp2 => { + call_numeric_binop(vm, args, numpy_logaddexp2, "numpy.logaddexp2", BinopResult::Float) + .map(CallResult::Value) + } + NumpyFunctions::Nextafter => { + call_numeric_binop(vm, args, libm::nextafter, "numpy.nextafter", BinopResult::Float).map(CallResult::Value) + } + NumpyFunctions::Spacing => { + call_elementwise(vm, args, numpy_spacing, "numpy.spacing", Some(NdArrayDtype::Float64)) + .map(CallResult::Value) + } + NumpyFunctions::Signbit => { + call_elementwise(vm, args, signbit_as_f64, "numpy.signbit", Some(NdArrayDtype::Bool)).map(CallResult::Value) + } + NumpyFunctions::Sinc => { + call_elementwise(vm, args, numpy_sinc, "numpy.sinc", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Heaviside => { + call_numeric_binop(vm, args, numpy_heaviside, "numpy.heaviside", BinopResult::Float).map(CallResult::Value) + } + NumpyFunctions::Trunc => { + call_elementwise(vm, args, f64::trunc, "numpy.trunc", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::Fix => { + call_elementwise(vm, args, f64::trunc, "numpy.fix", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::FloatPower => { + call_numeric_binop(vm, args, f64::powf, "numpy.float_power", BinopResult::Float).map(CallResult::Value) + } + NumpyFunctions::Divmod => call_numeric_tuple_binop( + vm, + args, + numpy_divmod, + "numpy.divmod", + BinopResult::Promoted, + BinopResult::Promoted, + ) + .map(CallResult::Value), + NumpyFunctions::BitwiseAnd => { + call_bitwise_binop(vm, args, IntegerBitwiseOp::And, "numpy.bitwise_and").map(CallResult::Value) + } + NumpyFunctions::BitwiseOr => { + call_bitwise_binop(vm, args, IntegerBitwiseOp::Or, "numpy.bitwise_or").map(CallResult::Value) + } + NumpyFunctions::BitwiseXor => { + call_bitwise_binop(vm, args, IntegerBitwiseOp::Xor, "numpy.bitwise_xor").map(CallResult::Value) + } + NumpyFunctions::BitwiseNot => call_bitwise_not(vm, args).map(CallResult::Value), + NumpyFunctions::LeftShift => { + call_bitwise_binop(vm, args, IntegerBitwiseOp::LeftShift, "numpy.left_shift").map(CallResult::Value) + } + NumpyFunctions::RightShift => { + call_bitwise_binop(vm, args, IntegerBitwiseOp::RightShift, "numpy.right_shift").map(CallResult::Value) + } + NumpyFunctions::BitwiseCount => call_bitwise_count(vm, args).map(CallResult::Value), + NumpyFunctions::Packbits => call_packbits(vm, args).map(CallResult::Value), + NumpyFunctions::Unpackbits => call_unpackbits(vm, args).map(CallResult::Value), + NumpyFunctions::Bartlett => { + call_window(vm, args, WindowKind::Bartlett, "numpy.bartlett").map(CallResult::Value) + } + NumpyFunctions::Blackman => { + call_window(vm, args, WindowKind::Blackman, "numpy.blackman").map(CallResult::Value) + } + NumpyFunctions::Hamming => call_window(vm, args, WindowKind::Hamming, "numpy.hamming").map(CallResult::Value), + NumpyFunctions::Hanning => call_window(vm, args, WindowKind::Hanning, "numpy.hanning").map(CallResult::Value), + NumpyFunctions::Kaiser => call_kaiser(vm, args).map(CallResult::Value), + NumpyFunctions::I0 => { + call_elementwise(vm, args, numpy_i0, "numpy.i0", Some(NdArrayDtype::Float64)).map(CallResult::Value) + } + NumpyFunctions::BaseRepr => call_base_repr(vm, args).map(CallResult::Value), + NumpyFunctions::BinaryRepr => call_binary_repr(vm, args).map(CallResult::Value), + NumpyFunctions::Conj => call_real_identity(vm, args, "numpy.conj").map(CallResult::Value), + NumpyFunctions::Real => call_real_identity(vm, args, "numpy.real").map(CallResult::Value), + NumpyFunctions::RealIfClose => call_real_if_close(vm, args).map(CallResult::Value), + NumpyFunctions::Imag => call_imag(vm, args).map(CallResult::Value), + NumpyFunctions::Isreal => call_realness_elementwise(vm, args, true, "numpy.isreal").map(CallResult::Value), + NumpyFunctions::Isrealobj => call_realness_object(vm, args, true, "numpy.isrealobj").map(CallResult::Value), + NumpyFunctions::Iscomplex => { + call_realness_elementwise(vm, args, false, "numpy.iscomplex").map(CallResult::Value) + } + NumpyFunctions::Iscomplexobj => { + call_realness_object(vm, args, false, "numpy.iscomplexobj").map(CallResult::Value) + } + NumpyFunctions::Isscalar => call_isscalar(vm, args).map(CallResult::Value), + NumpyFunctions::Iterable => call_iterable(vm, args).map(CallResult::Value), + NumpyFunctions::CanCast => call_can_cast(vm, args).map(CallResult::Value), + NumpyFunctions::PromoteTypes => call_promote_types(vm, args).map(CallResult::Value), + NumpyFunctions::ResultType => call_result_type(vm, args).map(CallResult::Value), + NumpyFunctions::CommonType => call_common_type(vm, args).map(CallResult::Value), + NumpyFunctions::MinScalarType => call_min_scalar_type(vm, args).map(CallResult::Value), + NumpyFunctions::Mintypecode => call_mintypecode(vm, args).map(CallResult::Value), + NumpyFunctions::Typename => call_typename(vm, args).map(CallResult::Value), + NumpyFunctions::Info => Ok(CallResult::Value(call_info(vm, args))), + NumpyFunctions::Issubdtype => call_issubdtype(vm, args).map(CallResult::Value), + NumpyFunctions::Isdtype => call_isdtype(vm, args).map(CallResult::Value), + NumpyFunctions::Finfo => call_finfo(vm, args).map(CallResult::Value), + NumpyFunctions::Iinfo => call_iinfo(vm, args).map(CallResult::Value), + NumpyFunctions::Geterr => call_geterr(vm, args).map(CallResult::Value), + NumpyFunctions::Seterr => call_seterr(vm, args).map(CallResult::Value), + NumpyFunctions::Geterrcall => call_geterrcall(vm, args).map(CallResult::Value), + NumpyFunctions::Seterrcall => call_seterrcall(vm, args).map(CallResult::Value), + NumpyFunctions::Errstate => call_errstate(vm, args).map(CallResult::Value), + NumpyFunctions::GetPrintoptions => call_get_printoptions(vm, args).map(CallResult::Value), + NumpyFunctions::SetPrintoptions => Ok(CallResult::Value(call_set_printoptions(vm, args))), + NumpyFunctions::Printoptions => call_printoptions(vm, args).map(CallResult::Value), + NumpyFunctions::Getbufsize => call_getbufsize(vm, args).map(CallResult::Value), + NumpyFunctions::Setbufsize => call_setbufsize(vm, args).map(CallResult::Value), + NumpyFunctions::ShowRuntime => Ok(CallResult::Value(call_show_runtime(vm, args))), + NumpyFunctions::Test => Ok(CallResult::Value(call_test(vm, args))), + NumpyFunctions::Atleast1d => call_atleast_nd(vm, args, 1, "numpy.atleast_1d").map(CallResult::Value), + NumpyFunctions::Atleast2d => call_atleast_nd(vm, args, 2, "numpy.atleast_2d").map(CallResult::Value), + NumpyFunctions::Atleast3d => call_atleast_nd(vm, args, 3, "numpy.atleast_3d").map(CallResult::Value), + NumpyFunctions::DiagIndices => call_diag_indices(vm, args).map(CallResult::Value), + NumpyFunctions::DiagIndicesFrom => call_diag_indices_from(vm, args).map(CallResult::Value), + NumpyFunctions::TrilIndices => { + call_triangle_indices(vm, args, TriangleKind::Lower, "numpy.tril_indices").map(CallResult::Value) + } + NumpyFunctions::TrilIndicesFrom => { + call_triangle_indices_from(vm, args, TriangleKind::Lower, "numpy.tril_indices_from").map(CallResult::Value) + } + NumpyFunctions::TriuIndices => { + call_triangle_indices(vm, args, TriangleKind::Upper, "numpy.triu_indices").map(CallResult::Value) + } + NumpyFunctions::TriuIndicesFrom => { + call_triangle_indices_from(vm, args, TriangleKind::Upper, "numpy.triu_indices_from").map(CallResult::Value) + } + NumpyFunctions::Indices => call_indices(vm, args).map(CallResult::Value), + NumpyFunctions::UnravelIndex => call_unravel_index(vm, args).map(CallResult::Value), + NumpyFunctions::RavelMultiIndex => call_ravel_multi_index(vm, args).map(CallResult::Value), + NumpyFunctions::Ndindex => call_ndindex(vm, args).map(CallResult::Value), + NumpyFunctions::Ndenumerate => call_ndenumerate(vm, args).map(CallResult::Value), + NumpyFunctions::Nditer => call_nditer(vm, args).map(CallResult::Value), + // Phase 4: NaN-aware aggregations and statistics + NumpyFunctions::Nansum => call_nan_aggregate(vm, args, nan_sum, "numpy.nansum").map(CallResult::Value), + NumpyFunctions::Nanmean => call_nan_aggregate(vm, args, nan_mean, "numpy.nanmean").map(CallResult::Value), + NumpyFunctions::Nanmin => call_nan_aggregate(vm, args, nan_min, "numpy.nanmin").map(CallResult::Value), + NumpyFunctions::Nanmax => call_nan_aggregate(vm, args, nan_max, "numpy.nanmax").map(CallResult::Value), + NumpyFunctions::Nanstd => call_nan_aggregate(vm, args, nan_std, "numpy.nanstd").map(CallResult::Value), + NumpyFunctions::Nanvar => call_nan_aggregate(vm, args, nan_var, "numpy.nanvar").map(CallResult::Value), + NumpyFunctions::Nanprod => call_nan_aggregate(vm, args, nan_prod, "numpy.nanprod").map(CallResult::Value), + NumpyFunctions::Nanmedian => call_nan_aggregate(vm, args, nan_median, "numpy.nanmedian").map(CallResult::Value), + NumpyFunctions::Nanargmin => call_nan_argmin(vm, args).map(CallResult::Value), + NumpyFunctions::Nanargmax => call_nan_argmax(vm, args).map(CallResult::Value), + NumpyFunctions::Average => call_aggregate(vm, args, NdArray::mean, "numpy.average").map(CallResult::Value), + NumpyFunctions::Percentile => call_percentile(vm, args).map(CallResult::Value), + NumpyFunctions::Quantile => call_quantile(vm, args).map(CallResult::Value), + NumpyFunctions::Nanpercentile => call_nanpercentile(vm, args).map(CallResult::Value), + NumpyFunctions::Nanquantile => call_nanquantile(vm, args).map(CallResult::Value), + NumpyFunctions::Histogram => call_histogram(vm, args).map(CallResult::Value), + NumpyFunctions::Histogram2d => call_histogram2d(vm, args).map(CallResult::Value), + NumpyFunctions::HistogramBinEdges => call_histogram_bin_edges(vm, args).map(CallResult::Value), + NumpyFunctions::Histogramdd => call_histogramdd(vm, args).map(CallResult::Value), + NumpyFunctions::Ptp => call_ptp(vm, args).map(CallResult::Value), + NumpyFunctions::Cumprod => call_cumprod(vm, args).map(CallResult::Value), + NumpyFunctions::Nancumsum => call_nancumop(vm, args, true, "numpy.nancumsum").map(CallResult::Value), + NumpyFunctions::Nancumprod => call_nancumop(vm, args, false, "numpy.nancumprod").map(CallResult::Value), + // Phase 5: Logical and testing + NumpyFunctions::LogicalAnd => { + call_logical_binop(vm, args, |a, b| a && b, "numpy.logical_and").map(CallResult::Value) + } + NumpyFunctions::LogicalOr => { + call_logical_binop(vm, args, |a, b| a || b, "numpy.logical_or").map(CallResult::Value) + } + NumpyFunctions::LogicalNot => call_logical_not(vm, args).map(CallResult::Value), + NumpyFunctions::LogicalXor => { + call_logical_binop(vm, args, |a, b| a ^ b, "numpy.logical_xor").map(CallResult::Value) + } + NumpyFunctions::Allclose => call_allclose(vm, args).map(CallResult::Value), + NumpyFunctions::Isclose => call_isclose(vm, args).map(CallResult::Value), + NumpyFunctions::Isin => call_isin(vm, args).map(CallResult::Value), + // Phase 6: Manipulation and shape + NumpyFunctions::Flip => call_flip(vm, args).map(CallResult::Value), + NumpyFunctions::Fliplr => call_fliplr(vm, args).map(CallResult::Value), + NumpyFunctions::Flipud => call_flipud(vm, args).map(CallResult::Value), + NumpyFunctions::Roll => call_roll(vm, args).map(CallResult::Value), + NumpyFunctions::ExpandDims => call_expand_dims(vm, args).map(CallResult::Value), + NumpyFunctions::Squeeze => call_squeeze(vm, args).map(CallResult::Value), + NumpyFunctions::Ravel => call_ravel_mod(vm, args).map(CallResult::Value), + NumpyFunctions::Delete => call_delete(vm, args).map(CallResult::Value), + NumpyFunctions::Insert => call_insert(vm, args).map(CallResult::Value), + NumpyFunctions::Diag => call_diag(vm, args).map(CallResult::Value), + NumpyFunctions::Diagflat => call_diagflat(vm, args).map(CallResult::Value), + NumpyFunctions::Diagonal => call_diagonal(vm, args).map(CallResult::Value), + NumpyFunctions::Trace => call_trace(vm, args).map(CallResult::Value), + NumpyFunctions::Flatnonzero => call_flatnonzero(vm, args).map(CallResult::Value), + NumpyFunctions::Asarray => call_asarray(vm, args).map(CallResult::Value), + NumpyFunctions::AsarrayChkfinite => call_asarray_chkfinite(vm, args).map(CallResult::Value), + NumpyFunctions::Ascontiguousarray | NumpyFunctions::Asfortranarray | NumpyFunctions::Require => { + call_asarray_compat(vm, args).map(CallResult::Value) + } + NumpyFunctions::Ix => call_ix(vm, args).map(CallResult::Value), + NumpyFunctions::MaskIndices => call_mask_indices(vm, args).map(CallResult::Value), + NumpyFunctions::Isfortran => call_isfortran(vm, args).map(CallResult::Value), + NumpyFunctions::MayShareMemory => { + call_memory_overlap(vm, args, "numpy.may_share_memory").map(CallResult::Value) + } + NumpyFunctions::SharesMemory => call_memory_overlap(vm, args, "numpy.shares_memory").map(CallResult::Value), + NumpyFunctions::ColumnStack => call_column_stack(vm, args).map(CallResult::Value), + NumpyFunctions::RowStack => call_vstack(vm, args).map(CallResult::Value), // alias + NumpyFunctions::Hsplit => call_hsplit(vm, args).map(CallResult::Value), + NumpyFunctions::Vsplit => call_vsplit(vm, args).map(CallResult::Value), + NumpyFunctions::Dsplit => call_dsplit(vm, args).map(CallResult::Value), + NumpyFunctions::ArraySplit => call_array_split(vm, args).map(CallResult::Value), + NumpyFunctions::FullLike => call_full_like(vm, args).map(CallResult::Value), + NumpyFunctions::EmptyLike => call_like(vm, args, 0.0, "numpy.empty_like").map(CallResult::Value), + // Phase 7: Sorting, searching, set ops + NumpyFunctions::ArgsortMod => call_argsort_mod(vm, args).map(CallResult::Value), + NumpyFunctions::Argpartition => call_argpartition(vm, args).map(CallResult::Value), + NumpyFunctions::Partition => call_partition(vm, args).map(CallResult::Value), + NumpyFunctions::Lexsort => call_lexsort(vm, args).map(CallResult::Value), + NumpyFunctions::Cov => call_cov(vm, args).map(CallResult::Value), + NumpyFunctions::Corrcoef => call_corrcoef(vm, args).map(CallResult::Value), + NumpyFunctions::Searchsorted => call_searchsorted(vm, args).map(CallResult::Value), + NumpyFunctions::Extract => call_extract(vm, args).map(CallResult::Value), + NumpyFunctions::TrimZeros => call_trim_zeros(vm, args).map(CallResult::Value), + NumpyFunctions::Unwrap => call_unwrap(vm, args).map(CallResult::Value), + NumpyFunctions::Intersect1d => { + call_set_op(vm, args, SetOp::Intersect, "numpy.intersect1d").map(CallResult::Value) + } + NumpyFunctions::Union1d => call_set_op(vm, args, SetOp::Union, "numpy.union1d").map(CallResult::Value), + NumpyFunctions::Setdiff1d => call_set_op(vm, args, SetOp::Diff, "numpy.setdiff1d").map(CallResult::Value), + NumpyFunctions::Setxor1d => call_set_op(vm, args, SetOp::Xor, "numpy.setxor1d").map(CallResult::Value), + NumpyFunctions::Bincount => call_bincount(vm, args).map(CallResult::Value), + NumpyFunctions::Digitize => call_digitize(vm, args).map(CallResult::Value), + // Phase 8: Linear algebra + NumpyFunctions::Matmul => call_matmul(vm, args).map(CallResult::Value), + NumpyFunctions::Inner => call_dot(vm, args).map(CallResult::Value), // For 1D, inner = dot + NumpyFunctions::Outer => call_outer(vm, args).map(CallResult::Value), + NumpyFunctions::Vdot => call_dot(vm, args).map(CallResult::Value), // vdot flattens first, same as dot for 1D + NumpyFunctions::Vecdot => call_dot(vm, args).map(CallResult::Value), // 1D vector subset + NumpyFunctions::Matvec | NumpyFunctions::Vecmat => call_matmul(vm, args).map(CallResult::Value), + NumpyFunctions::Cross => call_cross(vm, args).map(CallResult::Value), + NumpyFunctions::Kron => call_kron(vm, args).map(CallResult::Value), + NumpyFunctions::Tensordot => call_tensordot(vm, args).map(CallResult::Value), + NumpyFunctions::Einsum => call_einsum(vm, args).map(CallResult::Value), + NumpyFunctions::EinsumPath => call_einsum_path(vm, args).map(CallResult::Value), + NumpyFunctions::Trapezoid => call_trapezoid(vm, args).map(CallResult::Value), + NumpyFunctions::Vander => call_vander(vm, args).map(CallResult::Value), + NumpyFunctions::Poly => call_poly(vm, args).map(CallResult::Value), + NumpyFunctions::Polyadd => { + call_poly_binary(vm, args, "numpy.polyadd", |lhs, rhs| lhs + rhs).map(CallResult::Value) + } + NumpyFunctions::Polysub => { + call_poly_binary(vm, args, "numpy.polysub", |lhs, rhs| lhs - rhs).map(CallResult::Value) + } + NumpyFunctions::Polymul => call_polymul(vm, args).map(CallResult::Value), + NumpyFunctions::Polydiv => call_polydiv(vm, args).map(CallResult::Value), + NumpyFunctions::Polyint => call_polyint(vm, args).map(CallResult::Value), + NumpyFunctions::Polyder => call_polyder(vm, args).map(CallResult::Value), + NumpyFunctions::Polyval => call_polyval(vm, args).map(CallResult::Value), + // Phase 10: Additional creation and numerical + NumpyFunctions::Logspace => call_logspace(vm, args).map(CallResult::Value), + NumpyFunctions::Geomspace => call_geomspace(vm, args).map(CallResult::Value), + NumpyFunctions::Tri => call_tri(vm, args).map(CallResult::Value), + NumpyFunctions::Tril => call_tril(vm, args).map(CallResult::Value), + NumpyFunctions::Triu => call_triu(vm, args).map(CallResult::Value), + NumpyFunctions::Identity => call_eye(vm, args).map(CallResult::Value), // alias + NumpyFunctions::Meshgrid => call_meshgrid(vm, args).map(CallResult::Value), + NumpyFunctions::Gradient => call_gradient(vm, args).map(CallResult::Value), + NumpyFunctions::Convolve => call_convolve(vm, args).map(CallResult::Value), + NumpyFunctions::Correlate => call_correlate(vm, args).map(CallResult::Value), + NumpyFunctions::Interp => call_interp(vm, args).map(CallResult::Value), + NumpyFunctions::Select => call_select(vm, args).map(CallResult::Value), + } +} + +// =========================== +// Array creation functions +// =========================== + +/// `numpy.array(data)` — create an ndarray from a list or nested list. +fn call_array(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.array", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_list(arg, vm.heap)?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.fromfunction(function, shape, dtype=float, **kwargs)` — call a function with coordinate arrays. +fn call_fromfunction(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (mut pos, kwargs) = args.into_parts(); + let Some(function) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.fromfunction", 2, 0)); + }; + defer_drop!(function, vm); + let Some(shape_value) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.fromfunction", 2, 1)); + }; + defer_drop!(shape_value, vm); + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.fromfunction", 2, 3)); + } + pos.drop_with_heap(vm); + + let parsed = parse_fromfunction_kwargs(kwargs, vm)?; + defer_drop_mut!(parsed, vm); + let shape = extract_shape_from_value(shape_value, "numpy.fromfunction", vm)?; + let coordinate_args = coordinate_arrays_for_shape(&shape, parsed.dtype, "numpy.fromfunction", vm)?; + let function_kwargs = mem::replace(&mut parsed.extra_kwargs, KwargsValues::Empty); + let function_args = args_from_vec_and_kwargs(coordinate_args, function_kwargs); + vm.evaluate_function("numpy.fromfunction", function, function_args) +} + +/// Parsed keyword state for `fromfunction()`. +/// +/// The dtype is consumed immediately while unknown keywords are preserved and +/// forwarded to the user callable. The custom drop implementation protects +/// forwarded keyword values until they are moved into the callable argument list. +struct ParsedFromFunctionKwargs { + /// Coordinate-array dtype requested by `dtype=`. + dtype: CompactDtype, + /// Keyword arguments that should be passed through to the user callable. + extra_kwargs: KwargsValues, +} + +impl DropWithHeap for ParsedFromFunctionKwargs { + fn drop_with_heap(self, heap: &mut H) { + self.extra_kwargs.drop_with_heap(heap); + } +} + +/// Parses `fromfunction()` keyword arguments, preserving callable kwargs. +fn parse_fromfunction_kwargs( + kwargs: KwargsValues, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + let extra_pairs = Vec::<(Value, Value)>::new(); + let mut extra_guard = HeapGuard::new(extra_pairs, vm); + let (extra_pairs, vm) = extra_guard.as_parts_mut(); + + let mut dtype = CompactDtype::Float64; + let mut dtype_seen = false; + let mut like_seen = false; + for (key, value) in kwargs_iter { + let Some(keyword_name) = key.as_either_str(vm.heap) else { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + let key_str = keyword_name.as_str(vm.interns); + if key_str == "dtype" { + key.drop_with_heap(vm); + defer_drop!(value, vm); + if dtype_seen { + return Err(ExcType::type_error_multiple_values("numpy.fromfunction", "dtype")); + } + dtype_seen = true; + dtype = dtype_meta_from_optional_dtype_value(value, "numpy.fromfunction", vm)?; + } else if key_str == "like" { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + if like_seen { + return Err(ExcType::type_error_multiple_values("numpy.fromfunction", "like")); + } + like_seen = true; + } else { + extra_pairs.push((key, value)); + } + } + + let (extra_pairs, vm) = extra_guard.into_parts(); + Ok(ParsedFromFunctionKwargs { + dtype, + extra_kwargs: kwargs_from_pairs(extra_pairs, vm)?, + }) +} + +/// `numpy.fromiter(iter, dtype, count=-1)` — build a one-dimensional numeric array. +fn call_fromiter(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (mut pos, kwargs) = args.into_parts(); + let Some(iterable) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.fromiter", 2, 0)); + }; + let iterable = Some(iterable); + defer_drop_mut!(iterable, vm); + let dtype_pos = pos.next(); + defer_drop_mut!(dtype_pos, vm); + let count_pos = pos.next(); + defer_drop_mut!(count_pos, vm); + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.fromiter", 3, 4)); + } + pos.drop_with_heap(vm); + + let parsed = parse_fromiter_kwargs(kwargs, vm)?; + let dtype = match (dtype_pos.as_ref(), parsed.dtype) { + (Some(_), Some(_)) => return Err(ExcType::type_error_multiple_values("numpy.fromiter", "dtype")), + (Some(value), None) => dtype_meta_from_optional_dtype_value(value, "numpy.fromiter", vm)?, + (None, Some(dtype)) => dtype, + (None, None) => { + return Err(SimpleException::new_msg( + ExcType::TypeError, + "fromiter() missing required argument 'dtype' (pos 2)", + ) + .into()); + } + }; + let count = match (count_pos.as_ref(), parsed.count) { + (Some(_), Some(_)) => return Err(ExcType::type_error_multiple_values("numpy.fromiter", "count")), + (Some(value), None) => value_to_i64_arg(value, "numpy.fromiter", "count")?, + (None, Some(count)) => count, + (None, None) => -1, + }; + + let iter = MontyIter::new(iterable.take().expect("fromiter iterable is still owned"), vm)?; + defer_drop_mut!(iter, vm); + let limit = fromiter_count_limit(count, "numpy.fromiter")?; + let capacity = limit.unwrap_or_else(|| iter.size_hint(vm.heap)); + check_array_alloc_size(capacity, vm.heap.tracker())?; + let data = collect_fromiter_data(iter, dtype, limit, vm)?; + let len = data.len(); + let arr = NdArray::new(data, vec![len], ndarray_dtype_from_compact(dtype)); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// Parsed keyword state for `fromiter()`. +struct ParsedFromIterKwargs { + /// Optional dtype parsed from `dtype=`. + dtype: Option, + /// Optional element limit parsed from `count=`. + count: Option, +} + +/// Parses the narrow keyword surface supported by `fromiter()`. +fn parse_fromiter_kwargs( + kwargs: KwargsValues, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + + let mut dtype = None; + let mut count = None; + let mut like_seen = false; + for (key, value) in kwargs_iter { + let Some(keyword_name) = key.as_either_str(vm.heap) else { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + let key_str = keyword_name.as_str(vm.interns); + if key_str == "dtype" { + key.drop_with_heap(vm); + defer_drop!(value, vm); + if dtype.is_some() { + return Err(ExcType::type_error_multiple_values("numpy.fromiter", "dtype")); + } + dtype = Some(dtype_meta_from_optional_dtype_value(value, "numpy.fromiter", vm)?); + } else if key_str == "count" { + key.drop_with_heap(vm); + defer_drop!(value, vm); + if count.is_some() { + return Err(ExcType::type_error_multiple_values("numpy.fromiter", "count")); + } + count = Some(value_to_i64_arg(value, "numpy.fromiter", "count")?); + } else if key_str == "like" { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + if like_seen { + return Err(ExcType::type_error_multiple_values("numpy.fromiter", "like")); + } + like_seen = true; + } else { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + return Err(ExcType::type_error(format!( + "'{key_str}' is an invalid keyword argument for numpy.fromiter()" + ))); + } + } + + Ok(ParsedFromIterKwargs { dtype, count }) +} + +/// Converts NumPy's `count` argument into an optional read limit. +fn fromiter_count_limit(count: i64, name: &str) -> RunResult> { + match count.cmp(&-1) { + Ordering::Equal => Ok(None), + Ordering::Less => Ok(Some(0)), + Ordering::Greater => Ok(Some(i64_to_nonnegative_usize(count, name, "count")?)), + } +} + +/// Consumes an iterator into a numeric backing vector for `fromiter()`. +fn collect_fromiter_data( + iter: &mut MontyIter, + dtype: CompactDtype, + limit: Option, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult> { + let mut data = Vec::with_capacity(limit.unwrap_or_else(|| iter.size_hint(vm.heap))); + if let Some(limit) = limit { + for _ in 0..limit { + let Some(item) = iter.for_next(vm)? else { + return Err(SimpleException::new_msg( + ExcType::ValueError, + format!( + "iterator too short: Expected {limit} but iterator had only {} items.", + data.len() + ), + ) + .into()); + }; + defer_drop!(item, vm); + data.push(cast_value_to_compact_dtype(item, dtype, "numpy.fromiter", vm)?); + } + } else { + while let Some(item) = iter.for_next(vm)? { + defer_drop!(item, vm); + data.push(cast_value_to_compact_dtype(item, dtype, "numpy.fromiter", vm)?); + } + } + Ok(data) +} + +/// `numpy.fromstring(string, dtype=float, count=-1, sep='')` — parse separated text into a 1-D array. +fn call_fromstring(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (mut pos, kwargs) = args.into_parts(); + let Some(string_value) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.fromstring", 1, 0)); + }; + defer_drop!(string_value, vm); + let dtype_pos = pos.next(); + defer_drop_mut!(dtype_pos, vm); + let count_pos = pos.next(); + defer_drop_mut!(count_pos, vm); + let sep_pos = pos.next(); + defer_drop_mut!(sep_pos, vm); + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.fromstring", 1, 4)); + } + pos.drop_with_heap(vm); + + let parsed = parse_fromstring_kwargs(kwargs, vm)?; + let dtype = match (dtype_pos.as_ref(), parsed.dtype) { + (Some(_), Some(_)) => return Err(ExcType::type_error_multiple_values("numpy.fromstring", "dtype")), + (Some(value), None) => dtype_meta_from_optional_dtype_value(value, "numpy.fromstring", vm)?, + (None, Some(dtype)) => dtype, + (None, None) => CompactDtype::Float64, + }; + let count = match (count_pos.as_ref(), parsed.count) { + (Some(_), Some(_)) => return Err(ExcType::type_error_multiple_values("numpy.fromstring", "count")), + (Some(value), None) => value_to_i64_arg(value, "numpy.fromstring", "count")?, + (None, Some(count)) => count, + (None, None) => -1, + }; + let sep = match (sep_pos.as_ref(), parsed.sep) { + (Some(_), Some(_)) => return Err(ExcType::type_error_multiple_values("numpy.fromstring", "sep")), + (Some(value), None) => string_from_value(value, "numpy.fromstring", vm)?, + (None, Some(sep)) => sep, + (None, None) => String::new(), + }; + if sep.is_empty() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "The binary mode of fromstring is removed, use frombuffer instead", + ) + .into()); + } + + let text = string_from_value(string_value, "numpy.fromstring", vm)?; + let limit = fromstring_count_limit(count, "numpy.fromstring")?; + if let Some(limit) = limit { + check_array_alloc_size(limit, vm.heap.tracker())?; + } + let data = parse_fromstring_data(&text, dtype, limit, &sep)?; + if limit.is_none() { + check_array_alloc_size(data.len(), vm.heap.tracker())?; + } + let len = data.len(); + let arr = NdArray::new(data, vec![len], ndarray_dtype_from_compact(dtype)); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// Parsed keyword state for `fromstring()`. +struct ParsedFromStringKwargs { + /// Optional dtype parsed from `dtype=`. + dtype: Option, + /// Optional element limit parsed from `count=`. + count: Option, + /// Optional text separator parsed from `sep=`. + sep: Option, +} + +/// Parses the compact keyword surface supported by `fromstring()`. +fn parse_fromstring_kwargs( + kwargs: KwargsValues, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + + let mut dtype = None; + let mut count = None; + let mut sep = None; + let mut like_seen = false; + for (key, value) in kwargs_iter { + let Some(keyword_name) = key.as_either_str(vm.heap) else { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + let key_str = keyword_name.as_str(vm.interns); + if key_str == "dtype" { + key.drop_with_heap(vm); + defer_drop!(value, vm); + if dtype.is_some() { + return Err(ExcType::type_error_multiple_values("numpy.fromstring", "dtype")); + } + dtype = Some(dtype_meta_from_optional_dtype_value(value, "numpy.fromstring", vm)?); + } else if key_str == "count" { + key.drop_with_heap(vm); + defer_drop!(value, vm); + if count.is_some() { + return Err(ExcType::type_error_multiple_values("numpy.fromstring", "count")); + } + count = Some(value_to_i64_arg(value, "numpy.fromstring", "count")?); + } else if key_str == "sep" { + key.drop_with_heap(vm); + defer_drop!(value, vm); + if sep.is_some() { + return Err(ExcType::type_error_multiple_values("numpy.fromstring", "sep")); + } + sep = Some(string_from_value(value, "numpy.fromstring", vm)?); + } else if key_str == "like" { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + if like_seen { + return Err(ExcType::type_error_multiple_values("numpy.fromstring", "like")); + } + like_seen = true; + } else { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + return Err(ExcType::type_error(format!( + "'{key_str}' is an invalid keyword argument for numpy.fromstring()" + ))); + } + } + + Ok(ParsedFromStringKwargs { dtype, count, sep }) +} + +/// Converts NumPy's `count` argument into an optional text-token limit. +fn fromstring_count_limit(count: i64, name: &str) -> RunResult> { + if count < 0 { + Ok(None) + } else { + Ok(Some(i64_to_nonnegative_usize(count, name, "count")?)) + } +} + +/// Parses separated text tokens into Monty's compact ndarray backing values. +fn parse_fromstring_data(text: &str, dtype: CompactDtype, limit: Option, sep: &str) -> RunResult> { + let mut data = Vec::with_capacity(limit.unwrap_or(0)); + if limit == Some(0) { + return Ok(data); + } + + if sep.chars().all(char::is_whitespace) { + for token in text.split_whitespace() { + data.push(parse_fromstring_token(token, dtype)?); + if Some(data.len()) == limit { + break; + } + } + } else { + let mut parts = text.split(sep).peekable(); + while let Some(part) = parts.next() { + if Some(data.len()) == limit { + break; + } + let token = part.trim(); + if token.is_empty() { + if text.trim().is_empty() || parts.peek().is_none() && text.ends_with(sep) { + break; + } + return Err(fromstring_unmatched_data_error()); + } + data.push(parse_fromstring_token(token, dtype)?); + } + } + Ok(data) +} + +/// Parses one numeric text token using the requested compact dtype. +fn parse_fromstring_token(token: &str, dtype: CompactDtype) -> RunResult { + let value = token.parse::().map_err(|_| fromstring_unmatched_data_error())?; + match dtype { + CompactDtype::Bool => Ok(if value == 0.0 { 0.0 } else { 1.0 }), + CompactDtype::Int => Ok(i64_to_f64(f64_to_i64(value))), + CompactDtype::Float32 | CompactDtype::Float64 => Ok(value), + } +} + +/// Error used when text-mode `fromstring()` cannot consume the next token. +fn fromstring_unmatched_data_error() -> RunError { + SimpleException::new_msg( + ExcType::ValueError, + "string or file could not be read to its end due to unmatched data", + ) + .into() +} + +/// Builds the coordinate arrays passed as positional arguments to `fromfunction()`. +fn coordinate_arrays_for_shape( + shape: &[usize], + dtype: CompactDtype, + name: &str, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult> { + let ndim = shape.len(); + let total = checked_shape_product(shape, name)?; + check_array_alloc_size(total.saturating_mul(ndim), vm.heap.tracker())?; + + let mut arrays = Vec::with_capacity(ndim); + if ndim > 0 { + for axis in 0..ndim { + let stride = checked_shape_product(&shape[axis + 1..], name)?; + let mut data = Vec::with_capacity(total); + for flat in 0..total { + let coord = if shape[axis] == 0 { + 0 + } else { + (flat / stride) % shape[axis] + }; + data.push(coordinate_value_for_dtype(coord, dtype)); + } + let arr = NdArray::new(data, shape.to_vec(), ndarray_dtype_from_compact(dtype)); + arrays.push(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)); + } + } + Ok(arrays) +} + +/// Casts one coordinate value to the compact dtype requested by `fromfunction()`. +fn coordinate_value_for_dtype(coord: usize, dtype: CompactDtype) -> f64 { + match dtype { + CompactDtype::Bool => { + if coord == 0 { + 0.0 + } else { + 1.0 + } + } + CompactDtype::Int | CompactDtype::Float32 | CompactDtype::Float64 => usize_to_f64(coord), + } +} + +/// Wraps positional and keyword values in the compact `ArgValues` representation. +fn args_from_vec_and_kwargs(mut args: Vec, kwargs: KwargsValues) -> ArgValues { + if kwargs.is_empty() { + match args.len() { + 0 => ArgValues::Empty, + 1 => ArgValues::One(args.pop().expect("one positional argument")), + 2 => { + let second = args.pop().expect("second positional argument"); + let first = args.pop().expect("first positional argument"); + ArgValues::Two(first, second) + } + _ => ArgValues::ArgsKargs { + args, + kwargs: KwargsValues::Empty, + }, + } + } else if args.is_empty() { + ArgValues::Kwargs(kwargs) + } else { + ArgValues::ArgsKargs { args, kwargs } + } +} + +/// Creates keyword values from owned `(key, value)` pairs. +fn kwargs_from_pairs(pairs: Vec<(Value, Value)>, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + if pairs.is_empty() { + Ok(KwargsValues::Empty) + } else { + Dict::from_pairs(pairs, vm).map(KwargsValues::Dict) + } +} + +/// Parses NumPy dtype arguments that accept Python type constructors and `None`. +fn dtype_meta_from_optional_dtype_value( + value: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + match value { + Value::None => Ok(CompactDtype::Float64), + Value::Builtin(Builtins::Type(Type::Bool)) => Ok(CompactDtype::Bool), + Value::Builtin(Builtins::Type(Type::Int)) => Ok(CompactDtype::Int), + Value::Builtin(Builtins::Type(Type::Float)) => Ok(CompactDtype::Float64), + _ => dtype_meta_from_dtype_value(value, name, vm), + } +} + +/// Converts one Python scalar into the requested compact ndarray backing value. +fn cast_value_to_compact_dtype( + value: &Value, + dtype: CompactDtype, + _name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + let value = to_f64(value, vm)?; + match dtype { + CompactDtype::Bool => Ok(if value == 0.0 { 0.0 } else { 1.0 }), + CompactDtype::Int => { + #[expect( + clippy::cast_possible_truncation, + reason = "fromiter integer conversion follows NumPy's scalar cast behavior" + )] + let value = value as i64; + Ok(i64_to_f64(value)) + } + CompactDtype::Float32 | CompactDtype::Float64 => Ok(value), + } +} + +/// Maps compact dtype metadata to Monty's current ndarray storage dtype. +fn ndarray_dtype_from_compact(dtype: CompactDtype) -> NdArrayDtype { + match dtype { + CompactDtype::Bool => NdArrayDtype::Bool, + CompactDtype::Int => NdArrayDtype::Int64, + CompactDtype::Float32 | CompactDtype::Float64 => NdArrayDtype::Float64, + } +} + +/// `numpy.array2string(a)` — format an ndarray without the `array(...)` wrapper. +fn call_array2string(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arr = array_display_arg(vm, args, "numpy.array2string")?; + let mut output = String::new(); + arr.array_str_fmt_inner(&mut output)?; + allocate_string(output, vm.heap) +} + +/// `numpy.array_repr(a)` — return the ndarray repr string. +fn call_array_repr(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arr = array_display_arg(vm, args, "numpy.array_repr")?; + let mut output = String::new(); + arr.py_repr_fmt_inner(&mut output)?; + allocate_string(output, vm.heap) +} + +/// `numpy.array_str(a)` — return NumPy's bare ndarray string. +fn call_array_str(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arr = array_display_arg(vm, args, "numpy.array_str")?; + let mut output = String::new(); + arr.array_str_fmt_inner(&mut output)?; + allocate_string(output, vm.heap) +} + +/// `numpy.format_float_positional(x, ...)` — format a real scalar without exponent notation. +fn call_format_float_positional(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (value, options) = parse_float_format_args(args, "numpy.format_float_positional", true, true, vm)?; + allocate_string(format_float_positional_value(value, options), vm.heap) +} + +/// `numpy.format_float_scientific(x, ...)` — format a real scalar with exponent notation. +fn call_format_float_scientific(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (value, options) = parse_float_format_args(args, "numpy.format_float_scientific", false, false, vm)?; + allocate_string(format_float_scientific_value(value, options), vm.heap) +} + +/// Formatting options shared by NumPy's pure float-format helper functions. +/// +/// Monty implements the deterministic scalar subset that is useful for generated +/// Python: precision control, optional leading sign, simple trimming, and the +/// padding knobs NumPy exposes for array-printer internals. Unsupported object +/// formatting and locale concerns are intentionally outside this numeric subset. +#[derive(Clone, Copy)] +struct FloatFormatOptions { + /// Number of fractional or mantissa digits requested by the caller. + precision: Option, + /// Whether trailing insignificant zeros should be removed. + unique: bool, + /// Whether positional precision counts digits after the decimal point. + fractional: bool, + /// NumPy trim mode; Monty models the common keep/remove behavior. + trim: char, + /// Whether positive values should include a leading `+`. + sign: bool, + /// Minimum width to the left of the decimal point. + pad_left: Option, + /// Minimum width to the right of the decimal point for positional format. + pad_right: Option, + /// Minimum number of digits after the decimal point. + min_digits: Option, +} + +impl Default for FloatFormatOptions { + fn default() -> Self { + Self { + precision: None, + unique: true, + fractional: true, + trim: 'k', + sign: false, + pad_left: None, + pad_right: None, + min_digits: None, + } + } +} + +/// Parses the positional and keyword arguments accepted by the float-format helpers. +fn parse_float_format_args( + args: ArgValues, + name: &'static str, + allow_fractional: bool, + allow_pad_right: bool, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult<(f64, FloatFormatOptions)> { + let (pos, kwargs) = args.into_parts(); + defer_drop_mut!(pos, vm); + + let value = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 1, 0))?; + defer_drop!(value, vm); + let precision_pos = pos.next(); + defer_drop_mut!(precision_pos, vm); + if pos.len() != 0 { + return Err(ExcType::type_error_at_most(name, 2, 2 + pos.len())); + } + + let mut options = FloatFormatOptions::default(); + let mut precision_seen = false; + if let Some(precision) = precision_pos.as_ref() { + options.precision = optional_usize_argument(precision, name, "precision")?; + precision_seen = true; + } + + let kwargs = kwargs.into_iter(); + defer_drop_mut!(kwargs, vm); + for (key, value) in kwargs { + defer_drop!(key, vm); + defer_drop!(value, vm); + let Some(keyword_name) = key.as_either_str(vm.heap) else { + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + let key_str = keyword_name.as_str(vm.interns); + match key_str { + "precision" => { + if precision_seen { + return Err(ExcType::type_error_multiple_values(name, "precision")); + } + options.precision = optional_usize_argument(value, name, "precision")?; + precision_seen = true; + } + "unique" => options.unique = bool_argument(value, name, "unique")?, + "fractional" if allow_fractional => options.fractional = bool_argument(value, name, "fractional")?, + "trim" => options.trim = trim_argument(value, name, vm)?, + "sign" => options.sign = bool_argument(value, name, "sign")?, + "pad_left" => options.pad_left = optional_usize_argument(value, name, "pad_left")?, + "pad_right" if allow_pad_right => options.pad_right = optional_usize_argument(value, name, "pad_right")?, + "min_digits" => options.min_digits = optional_usize_argument(value, name, "min_digits")?, + _ => return Err(ExcType::type_error_unexpected_keyword(name, key_str)), + } + } + + Ok((to_f64(value, vm)?, options)) +} + +/// Extracts an optional non-negative integer formatting argument. +fn optional_usize_argument(value: &Value, name: &str, arg_name: &str) -> RunResult> { + if matches!(value, Value::None) { + Ok(None) + } else { + value_to_nonnegative_usize(value, name, arg_name).map(Some) + } +} + +/// Extracts a boolean formatting argument. +fn bool_argument(value: &Value, name: &str, arg_name: &str) -> RunResult { + match value { + Value::Bool(value) => Ok(*value), + _ => Err(ExcType::type_error(format!("{name}() {arg_name} must be a bool"))), + } +} + +/// Extracts NumPy's one-character float-format trim mode. +fn trim_argument(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + let text = string_from_value(value, name, vm)?; + let mut chars = text.chars(); + let Some(trim) = chars.next() else { + return Err(SimpleException::new_msg(ExcType::ValueError, "Trim mode must not be empty").into()); + }; + if chars.next().is_some() || !matches!(trim, 'k' | '.' | '0' | '-') { + Err(SimpleException::new_msg(ExcType::ValueError, "Trim mode must be one of 'k', '.', '0', '-'").into()) + } else { + Ok(trim) + } +} + +/// Formats one float using NumPy's positional helper subset. +fn format_float_positional_value(value: f64, options: FloatFormatOptions) -> String { + let sign = float_sign_prefix(value, options.sign); + let mut body = if let Some(special) = nonfinite_float_body(value) { + special + } else { + positional_finite_body(value.abs(), options) + }; + body = apply_min_digits(body, options.min_digits, None); + let result = format!("{sign}{body}"); + apply_float_padding(result, options.pad_left, options.pad_right, None) +} + +/// Formats one float using NumPy's scientific helper subset. +fn format_float_scientific_value(value: f64, options: FloatFormatOptions) -> String { + let sign = float_sign_prefix(value, options.sign); + let mut body = if let Some(special) = nonfinite_float_body(value) { + special + } else { + scientific_finite_body(value.abs(), options) + }; + body = apply_min_digits(body, options.min_digits, Some('e')); + let result = format!("{sign}{body}"); + apply_float_padding(result, options.pad_left, None, Some('e')) +} + +/// Returns the sign prefix for a formatted float. +fn float_sign_prefix(value: f64, force_positive: bool) -> &'static str { + if value.is_sign_negative() { + "-" + } else if force_positive { + "+" + } else { + "" + } +} + +/// Returns the stable body for NaN and infinity, if the value is not finite. +fn nonfinite_float_body(value: f64) -> Option { + if value.is_nan() { + Some("nan".to_string()) + } else if value.is_infinite() { + Some("inf".to_string()) + } else { + None + } +} + +/// Formats a finite value without exponent notation. +fn positional_finite_body(value: f64, options: FloatFormatOptions) -> String { + let mut body = if let Some(precision) = options.precision { + let precision = if options.fractional { + precision + } else { + fractional_digits_for_significant_precision(value, precision) + }; + format!("{value:.precision$}") + } else { + value.to_string() + }; + if body.contains('e') || body.contains('E') { + body = format!("{value:.15}"); + } + if options.unique { + trim_float_fraction(body, options.trim) + } else { + ensure_decimal_point(body) + } +} + +/// Converts significant-digit precision into fractional digits for positional output. +fn fractional_digits_for_significant_precision(value: f64, precision: usize) -> usize { + let whole_digits = value.trunc().to_string().trim_start_matches('-').len(); + precision.saturating_sub(whole_digits) +} + +/// Formats a finite value with exponent notation and NumPy-style exponent width. +fn scientific_finite_body(value: f64, options: FloatFormatOptions) -> String { + let raw = if let Some(precision) = options.precision { + format!("{value:.precision$e}") + } else { + format!("{value:e}") + }; + let (mantissa, exponent) = raw + .split_once('e') + .expect("Rust scientific formatting always contains an exponent"); + let mantissa = if options.unique { + trim_float_fraction(mantissa.to_string(), options.trim) + } else { + ensure_decimal_point(mantissa.to_string()) + }; + format!("{mantissa}{}", format_exponent(exponent)) +} + +/// Formats an exponent as `e+00`/`e-00`, matching NumPy's helper output. +fn format_exponent(exponent: &str) -> String { + let value = exponent.parse::().unwrap_or(0); + let sign = if value < 0 { '-' } else { '+' }; + format!("e{sign}{:02}", value.abs()) +} + +/// Ensures a finite float body has a decimal point. +fn ensure_decimal_point(mut body: String) -> String { + if !body.contains('.') { + body.push('.'); + } + body +} + +/// Trims trailing zeros from a finite float body according to NumPy trim modes. +fn trim_float_fraction(mut body: String, trim: char) -> String { + if let Some(dot) = body.find('.') { + while body.ends_with('0') { + body.pop(); + } + if body.len() == dot + 1 && trim == '-' { + body.pop(); + } + } + ensure_decimal_point(body) +} + +/// Pads the fractional part to satisfy `min_digits`. +fn apply_min_digits(mut body: String, min_digits: Option, exponent_marker: Option) -> String { + let Some(min_digits) = min_digits else { + return body; + }; + let exponent_index = exponent_marker.and_then(|marker| body.find(marker)); + let fraction_end = exponent_index.unwrap_or(body.len()); + if body[..fraction_end].find('.').is_none() { + body.insert(fraction_end, '.'); + } + let dot = body[..fraction_end].find('.').expect("decimal point inserted above"); + let digits = fraction_end.saturating_sub(dot + 1); + if digits < min_digits { + let zeros = "0".repeat(min_digits - digits); + body.insert_str(fraction_end, &zeros); + } + body +} + +/// Applies NumPy's left/right padding controls to an already formatted float string. +fn apply_float_padding( + mut text: String, + pad_left: Option, + pad_right: Option, + exponent_marker: Option, +) -> String { + if let Some(width) = pad_left { + let left_len = float_left_width(&text, exponent_marker); + if width > left_len { + text.insert_str(0, &" ".repeat(width - left_len)); + } + } + if let Some(width) = pad_right { + let right_len = float_right_width(&text, exponent_marker); + if width > right_len { + text.push_str(&" ".repeat(width - right_len)); + } + } + text +} + +/// Counts characters before the decimal point or exponent for `pad_left`. +fn float_left_width(text: &str, exponent_marker: Option) -> usize { + let end = exponent_marker + .and_then(|marker| text.find(marker)) + .unwrap_or(text.len()); + text[..end].find('.').unwrap_or(end) +} + +/// Counts characters after the decimal point and before the exponent for `pad_right`. +fn float_right_width(text: &str, exponent_marker: Option) -> usize { + let end = exponent_marker + .and_then(|marker| text.find(marker)) + .unwrap_or(text.len()); + text[..end].find('.').map_or(0, |dot| end - dot - 1) +} + +/// Extracts the ndarray argument for display helpers and ignores optional print-only arguments. +fn array_display_arg(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues, name: &str) -> RunResult { + let (mut pos, kwargs) = args.into_parts(); + let Some(arg) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least(name, 1, 0)); + }; + defer_drop!(arg, vm); + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + ndarray_from_value(arg, name, vm) +} + +/// `numpy.zeros(shape)` — create an array of zeros with the given shape. +/// +/// Accepts an integer for 1D or a tuple/list for multi-dimensional shapes. +fn call_zeros(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.zeros", vm.heap)?; + let shape = extract_shape(arg, "numpy.zeros", vm)?; + let total: usize = shape.iter().product(); + check_array_alloc_size(total, vm.heap.tracker())?; + let arr = NdArray::new(vec![0.0; total], shape, NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.ones(shape)` — create an array of ones with the given shape. +/// +/// Accepts an integer for 1D or a tuple/list for multi-dimensional shapes. +fn call_ones(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.ones", vm.heap)?; + let shape = extract_shape(arg, "numpy.ones", vm)?; + let total: usize = shape.iter().product(); + check_array_alloc_size(total, vm.heap.tracker())?; + let arr = NdArray::new(vec![1.0; total], shape, NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.arange([start,] stop[, step])` — evenly spaced values within a range. +/// +/// Supports 1, 2, or 3 arguments matching NumPy's behavior: +/// - `arange(stop)` — values from 0 to stop with step 1 +/// - `arange(start, stop)` — values from start to stop with step 1 +/// - `arange(start, stop, step)` — values from start to stop with given step +fn call_arange(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.arange", vm.heap)?; + defer_drop_mut!(pos, vm); + + let first = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.arange() requires at least 1 argument"))?; + defer_drop!(first, vm); + let second = pos.next(); + let third = pos.next(); + + for extra in pos { + extra.drop_with_heap(vm); + } + + let (start, stop, step) = match (&second, &third) { + (None, None) => (0.0, to_f64(first, vm)?, 1.0), + (Some(stop_val), None) => (to_f64(first, vm)?, to_f64(stop_val, vm)?, 1.0), + (Some(stop_val), Some(step_val)) => (to_f64(first, vm)?, to_f64(stop_val, vm)?, to_f64(step_val, vm)?), + (None, Some(_)) => unreachable!("third arg without second"), + }; + + second.drop_with_heap(vm); + third.drop_with_heap(vm); + + if step == 0.0 { + return Err(SimpleException::new_msg(ExcType::ValueError, "step must not be zero").into()); + } + + // Pre-check allocation size before building the Vec. + // Estimate element count the same way NumPy does: ceil((stop - start) / step). + let estimated_len = ((stop - start) / step).ceil().max(0.0); + #[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "estimated_len is non-negative and capped by usize::MAX" + )] + let estimated_count = if estimated_len.is_finite() { + (estimated_len as u64).min(usize::MAX as u64) as usize + } else { + 0 + }; + check_array_alloc_size(estimated_count, vm.heap.tracker())?; + + let mut data = Vec::new(); + let mut val = start; + if step > 0.0 { + while val < stop { + data.push(val); + val += step; + } + } else { + while val > stop { + data.push(val); + val += step; + } + } + + let has_float = start.fract() != 0.0 || stop.fract() != 0.0 || step.fract() != 0.0; + let dtype = if has_float { + NdArrayDtype::Float64 + } else { + NdArrayDtype::Int64 + }; + let len = data.len(); + let arr = NdArray::new(data, vec![len], dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.linspace(start, stop, num)` — evenly spaced values over an interval. +/// +/// Returns `num` values including both endpoints. +fn call_linspace(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.linspace", vm.heap)?; + defer_drop_mut!(pos, vm); + + let start_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.linspace() requires 3 arguments"))?; + defer_drop!(start_val, vm); + let stop_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.linspace() requires 3 arguments"))?; + defer_drop!(stop_val, vm); + let num_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.linspace() requires 3 arguments"))?; + defer_drop!(num_val, vm); + + for extra in pos { + extra.drop_with_heap(vm); + } + + let start = to_f64(start_val, vm)?; + let stop = to_f64(stop_val, vm)?; + let num = match num_val { + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "num is checked non-negative above" + )] + Value::Int(n) => { + if *n < 0 { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "Number of samples, num, must be non-negative.", + ) + .into()); + } + *n as usize + } + _ => { + return Err(ExcType::type_error("num must be an integer")); + } + }; + + check_array_alloc_size(num, vm.heap.tracker())?; + + let data = if num == 0 { + Vec::new() + } else if num == 1 { + vec![start] + } else { + let step = (stop - start) / (num - 1) as f64; + (0..num).map(|i| start + step * i as f64).collect() + }; + + let len = data.len(); + let arr = NdArray::new(data, vec![len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +// =========================== +// Aggregate functions +// =========================== + +/// Helper for aggregate functions like `numpy.sum(a)` that return a float. +/// +/// Accepts both ndarray and plain list arguments — lists are auto-converted to +/// a temporary NdArray, matching real NumPy's behavior of `np.mean([1,2,3])`. +fn call_aggregate( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + f: fn(&NdArray) -> f64, + name: &str, +) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + let Value::Ref(heap_id) = arg else { + return Err(ExcType::type_error(format!( + "{name}() requires an array or list argument" + ))); + }; + match vm.heap.get(*heap_id) { + HeapData::NdArray(arr) => Ok(Value::Float(f(arr))), + HeapData::List(list) => { + let tmp = list_to_ndarray(list, name)?; + Ok(Value::Float(f(&tmp))) + } + _ => Err(ExcType::type_error(format!( + "{name}() requires an array or list argument" + ))), + } +} + +/// Helper for aggregate functions that can fail (min/max on empty arrays). +/// +/// Accepts both ndarray and plain list arguments — lists are auto-converted. +fn call_aggregate_result( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + f: fn(&NdArray) -> RunResult, + name: &str, +) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + let Value::Ref(heap_id) = arg else { + return match arg { + Value::Bool(_) | Value::Int(_) | Value::Float(_) => Ok(arg.clone_immediate()), + _ => Err(ExcType::type_error(format!( + "{name}() requires an array, list, or scalar argument" + ))), + }; + }; + match vm.heap.get(*heap_id) { + HeapData::NdArray(arr) => Ok(Value::Float(f(arr)?)), + HeapData::List(list) => { + let tmp = list_to_ndarray(list, name)?; + Ok(Value::Float(f(&tmp)?)) + } + _ => Err(ExcType::type_error(format!( + "{name}() requires an array or list argument" + ))), + } +} + +/// Converts a `List` of numeric values to a 1-D `NdArray`. +/// +/// Used by aggregate functions to accept plain lists like `np.mean([1, 2, 3])` +/// in addition to ndarray arguments. +fn list_to_ndarray(list: &List, name: &str) -> RunResult { + let mut has_float = false; + let mut has_int = false; + let mut has_bool = false; + let data: Vec = list + .as_slice() + .iter() + .map(|v| match v { + Value::Int(i) => { + has_int = true; + Ok(*i as f64) + } + Value::Float(f) => { + has_float = true; + Ok(*f) + } + Value::Bool(b) => { + has_bool = true; + Ok(if *b { 1.0 } else { 0.0 }) + } + _ => Err(ExcType::type_error(format!("{name}() list elements must be numeric"))), + }) + .collect::>>()?; + let len = data.len(); + let dtype = if has_float { + NdArrayDtype::Float64 + } else if has_int { + NdArrayDtype::Int64 + } else if has_bool { + NdArrayDtype::Bool + } else { + NdArrayDtype::Float64 + }; + Ok(NdArray::new(data, vec![len], dtype)) +} + +// =========================== +// Element-wise functions +// =========================== + +/// Helper for element-wise unary functions like `numpy.abs(a)`, `numpy.sqrt(a)`, etc. +/// +/// Accepts both ndarray and plain list arguments — lists are auto-converted to +/// a temporary NdArray, matching real NumPy's behavior of `np.abs([1, -2, 3])`. +fn call_elementwise( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + f: fn(f64) -> f64, + name: &str, + result_dtype: Option, +) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + let Value::Ref(heap_id) = arg else { + let (value, source_dtype) = numeric_scalar_info(arg, name, vm)?; + let dtype = result_dtype.unwrap_or(source_dtype); + return Ok(scalar_from_f64(f(value), dtype)); + }; + let (data, shape, source_dtype) = match vm.heap.get(*heap_id) { + HeapData::NdArray(arr) => ( + arr.data().iter().map(|&v| f(v)).collect::>(), + arr.shape().to_vec(), + arr.dtype(), + ), + HeapData::List(list) => { + let tmp = list_to_ndarray(list, name)?; + let data = tmp.data().iter().map(|&v| f(v)).collect::>(); + let shape = tmp.shape().to_vec(); + let dtype = tmp.dtype(); + (data, shape, dtype) + } + _ => { + return Err(ExcType::type_error(format!( + "{name}() requires an array or list argument" + ))); + } + }; + let dtype = result_dtype.unwrap_or(source_dtype); + let new_arr = NdArray::new(data, shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(new_arr))?)) +} + +/// `numpy.round(a, decimals=0)` — element-wise rounding. +fn call_round(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, decimals_val) = args.get_one_two_args("numpy.round", vm.heap)?; + defer_drop!(arr_val, vm); + + let decimals = match decimals_val { + #[expect(clippy::cast_possible_truncation, reason = "decimals value from user input")] + Some(Value::Int(n)) => n as i32, + Some(other) => { + other.drop_with_heap(vm); + return Err(ExcType::type_error("decimals must be an integer")); + } + None => 0, + }; + + let factor = 10f64.powi(decimals); + let Value::Ref(heap_id) = arr_val else { + let (value, _) = numeric_scalar_info(arr_val, "numpy.round", vm)?; + return Ok(Value::Float(round_to_decimals(value, factor))); + }; + let (data, shape) = match vm.heap.get(*heap_id) { + HeapData::NdArray(arr) => ( + arr.data() + .iter() + .map(|&v| round_to_decimals(v, factor)) + .collect::>(), + arr.shape().to_vec(), + ), + HeapData::List(_) => { + let arr = ndarray_from_list(arr_val, vm.heap)?; + ( + arr.data() + .iter() + .map(|&v| round_to_decimals(v, factor)) + .collect::>(), + arr.shape().to_vec(), + ) + } + _ => { + return Err(ExcType::type_error( + "numpy.round() requires an array, list, or scalar argument", + )); + } + }; + + let new_arr = NdArray::new(data, shape, NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(new_arr))?)) +} + +/// `numpy.clip(a, a_min, a_max)` — clip (limit) array values to a range. +fn call_clip(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.clip", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.clip() requires 3 arguments"))?; + defer_drop!(arr_val, vm); + let min_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.clip() requires 3 arguments"))?; + defer_drop!(min_val, vm); + let max_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.clip() requires 3 arguments"))?; + defer_drop!(max_val, vm); + + for extra in pos { + extra.drop_with_heap(vm); + } + + let a_min = to_f64(min_val, vm)?; + let a_max = to_f64(max_val, vm)?; + + let Value::Ref(heap_id) = arr_val else { + return Err(ExcType::type_error( + "numpy.clip() requires an ndarray as the first argument", + )); + }; + let HeapData::NdArray(arr) = vm.heap.get(*heap_id) else { + return Err(ExcType::type_error( + "numpy.clip() requires an ndarray as the first argument", + )); + }; + + let data: Vec = arr.data().iter().map(|&v| v.clamp(a_min, a_max)).collect(); + let dtype = arr.dtype(); + let shape = arr.shape().to_vec(); + + let new_arr = NdArray::new(data, shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(new_arr))?)) +} + +/// `numpy.where(condition, x, y)` — conditional element selection. +fn call_where(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.where", vm.heap)?; + defer_drop_mut!(pos, vm); + + let cond_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.where() requires 3 arguments"))?; + defer_drop!(cond_val, vm); + let x_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.where() requires 3 arguments"))?; + defer_drop!(x_val, vm); + let y_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.where() requires 3 arguments"))?; + defer_drop!(y_val, vm); + + for extra in pos { + extra.drop_with_heap(vm); + } + + let cond_arr = ndarray_or_scalar_from_value(cond_val, "numpy.where", vm)?; + let x_arr = ndarray_or_scalar_from_value(x_val, "numpy.where", vm)?; + let y_arr = ndarray_or_scalar_from_value(y_val, "numpy.where", vm)?; + let shape = broadcast_shape(&[cond_arr.shape(), x_arr.shape(), y_arr.shape()], "numpy.where")?; + let cond_data = broadcast_array_data( + cond_arr.data(), + cond_arr.shape(), + &shape, + "numpy.where", + vm.heap.tracker(), + )?; + let x_data = broadcast_array_data(x_arr.data(), x_arr.shape(), &shape, "numpy.where", vm.heap.tracker())?; + let y_data = broadcast_array_data(y_arr.data(), y_arr.shape(), &shape, "numpy.where", vm.heap.tracker())?; + + let data: Vec = cond_data + .iter() + .zip(x_data.iter().zip(y_data.iter())) + .map(|(&c, (&x, &y))| if c == 0.0 { y } else { x }) + .collect(); + + let dtype = promote_dtype(x_arr.dtype(), y_arr.dtype()); + + let new_arr = NdArray::new(data, shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(new_arr))?)) +} + +/// Helper for element-wise binary functions like `numpy.maximum(a, b)`. +fn call_pairwise( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + f: fn(f64, f64) -> f64, + name: &str, +) -> RunResult { + call_numeric_binop(vm, args, f, name, BinopResult::Promoted) +} + +/// Result dtype policy for NumPy binary ufunc-style helpers. +#[derive(Clone, Copy)] +enum BinopResult { + /// Preserve NumPy-like int/float promotion for arithmetic operations. + Promoted, + /// Force float output, as true division does. + Float, + /// Force boolean output for comparison ufuncs. + Bool, +} + +/// Shared implementation for common binary NumPy ufuncs. +/// +/// Supports ndarray, list, and scalar inputs with NumPy-style shape broadcasting. +fn call_numeric_binop( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + f: impl Fn(f64, f64) -> f64, + name: &str, + result: BinopResult, +) -> RunResult { + let (a_val, b_val) = args.get_two_args(name, vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + + let a_arr = ndarray_or_scalar_from_value(a_val, name, vm)?; + let b_arr = ndarray_or_scalar_from_value(b_val, name, vm)?; + let dtype = binop_dtype(result, a_arr.dtype(), b_arr.dtype()); + + if a_arr.shape().is_empty() && b_arr.shape().is_empty() { + Ok(scalar_from_f64(f(a_arr.data()[0], b_arr.data()[0]), dtype)) + } else { + let (left, right, shape) = broadcast_pair_data( + a_arr.data(), + a_arr.shape(), + b_arr.data(), + b_arr.shape(), + name, + vm.heap.tracker(), + )?; + let data: Vec = left.iter().zip(right.iter()).map(|(&a, &b)| f(a, b)).collect(); + let arr = NdArray::new(data, shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } +} + +/// Computes the dtype for a binary ufunc result from the operation policy. +fn binop_dtype(result: BinopResult, a: NdArrayDtype, b: NdArrayDtype) -> NdArrayDtype { + match result { + BinopResult::Promoted => promote_dtype(a, b), + BinopResult::Float => NdArrayDtype::Float64, + BinopResult::Bool => NdArrayDtype::Bool, + } +} + +/// Python-compatible modulo: result has the same sign as the divisor. +fn py_mod(a: f64, b: f64) -> f64 { + let r = a % b; + if r != 0.0 && ((r > 0.0) != (b > 0.0)) { r + b } else { r } +} + +/// Converts a boolean comparison result to the f64 backing value for bool arrays. +fn bool_to_f64(value: bool) -> f64 { + if value { 1.0 } else { 0.0 } +} + +/// Equality comparison for NumPy-style numeric ufuncs. +/// +/// `partial_cmp` preserves NumPy's NaN behavior without using direct float +/// equality: NaN does not compare equal to itself. +fn eq_to_f64(a: f64, b: f64) -> f64 { + bool_to_f64(a.partial_cmp(&b) == Some(Ordering::Equal)) +} + +/// Inequality comparison for NumPy-style numeric ufuncs. +fn ne_to_f64(a: f64, b: f64) -> f64 { + bool_to_f64(a.partial_cmp(&b) != Some(Ordering::Equal)) +} + +// =========================== +// Sorting and unique functions +// =========================== + +/// `numpy.sort(a)` — return a sorted copy of the array. +fn call_sort(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.sort", vm.heap)?; + defer_drop!(arg, vm); + let Value::Ref(heap_id) = arg else { + return Err(ExcType::type_error("numpy.sort() requires an ndarray argument")); + }; + let HeapData::NdArray(arr) = vm.heap.get(*heap_id) else { + return Err(ExcType::type_error("numpy.sort() requires an ndarray argument")); + }; + + let mut data = arr.data().to_vec(); + let dtype = arr.dtype(); + let shape = arr.shape().to_vec(); + data.sort_by(nan_last_cmp); + let new_arr = NdArray::new(data, shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(new_arr))?)) +} + +/// `numpy.unique(a)` — return the sorted unique elements of an array. +fn call_unique(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.unique", vm.heap)?; + defer_drop!(arg, vm); + let Value::Ref(heap_id) = arg else { + return Err(ExcType::type_error("numpy.unique() requires an ndarray argument")); + }; + let HeapData::NdArray(arr) = vm.heap.get(*heap_id) else { + return Err(ExcType::type_error("numpy.unique() requires an ndarray argument")); + }; + + let mut data = arr.data().to_vec(); + let dtype = arr.dtype(); + data.sort_by(nan_last_cmp); + data.dedup(); + let len = data.len(); + let new_arr = NdArray::new(data, vec![len], dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(new_arr))?)) +} + +/// Return shape for the Array API style `unique_*` helpers. +#[derive(Clone, Copy)] +enum UniqueResultKind { + /// `unique_values(x)` returns only the unique values ndarray. + Values, + /// `unique_counts(x)` returns values plus occurrence counts. + Counts, + /// `unique_inverse(x)` returns values plus inverse indices. + Inverse, + /// `unique_all(x)` returns values, first indices, inverse indices, and counts. + All, +} + +/// Precomputed unique-result arrays shared by the `unique_*` wrappers. +struct UniqueAnalysis { + /// Sorted unique values. + values: Vec, + /// First original index for each unique value. + first_indices: Vec, + /// Inverse index for each input element. + inverse_indices: Vec, + /// Occurrence count for each unique value. + counts: Vec, +} + +/// Shared implementation for NumPy's Array API `unique_*` helpers. +fn call_unique_result( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + kind: UniqueResultKind, +) -> RunResult { + let name = match kind { + UniqueResultKind::Values => "numpy.unique_values", + UniqueResultKind::Counts => "numpy.unique_counts", + UniqueResultKind::Inverse => "numpy.unique_inverse", + UniqueResultKind::All => "numpy.unique_all", + }; + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, name, vm)?; + let analysis = unique_analysis(&arr); + + match kind { + UniqueResultKind::Values => allocate_unique_values_array(&analysis, arr.dtype(), vm.heap), + UniqueResultKind::Counts => { + let values = allocate_unique_values_array(&analysis, arr.dtype(), vm.heap)?; + let counts = allocate_usize_array(&analysis.counts, vec![analysis.counts.len()], vm.heap)?; + allocate_namedtuple_result("UniqueCountsResult", &["values", "counts"], vec![values, counts], vm) + } + UniqueResultKind::Inverse => { + let values = allocate_unique_values_array(&analysis, arr.dtype(), vm.heap)?; + let inverse_indices = allocate_usize_array(&analysis.inverse_indices, arr.shape().to_vec(), vm.heap)?; + allocate_namedtuple_result( + "UniqueInverseResult", + &["values", "inverse_indices"], + vec![values, inverse_indices], + vm, + ) + } + UniqueResultKind::All => { + let values = allocate_unique_values_array(&analysis, arr.dtype(), vm.heap)?; + let indices = allocate_usize_array(&analysis.first_indices, vec![analysis.first_indices.len()], vm.heap)?; + let inverse_indices = allocate_usize_array(&analysis.inverse_indices, arr.shape().to_vec(), vm.heap)?; + let counts = allocate_usize_array(&analysis.counts, vec![analysis.counts.len()], vm.heap)?; + allocate_namedtuple_result( + "UniqueAllResult", + &["values", "indices", "inverse_indices", "counts"], + vec![values, indices, inverse_indices, counts], + vm, + ) + } + } +} + +/// Computes sorted unique values, first indices, inverse indices, and counts. +fn unique_analysis(arr: &NdArray) -> UniqueAnalysis { + let mut pairs: Vec<(f64, usize)> = arr + .data() + .iter() + .copied() + .enumerate() + .map(|(index, value)| (value, index)) + .collect(); + pairs.sort_by(|(left, _), (right, _)| nan_last_cmp(left, right)); + + let mut values = Vec::new(); + let mut first_indices = Vec::new(); + let mut inverse_indices = vec![0; arr.len()]; + let mut counts = Vec::new(); + + for (value, original_index) in pairs { + let group = if values.last().is_some_and(|last| f64_exact_equal(*last, value)) { + values.len() - 1 + } else { + values.push(value); + first_indices.push(original_index); + counts.push(0); + values.len() - 1 + }; + first_indices[group] = first_indices[group].min(original_index); + counts[group] += 1; + inverse_indices[original_index] = group; + } + + UniqueAnalysis { + values, + first_indices, + inverse_indices, + counts, + } +} + +/// Allocates the unique values ndarray. +fn allocate_unique_values_array( + analysis: &UniqueAnalysis, + dtype: NdArrayDtype, + heap: &Heap, +) -> RunResult { + let values = analysis.values.clone(); + let result = NdArray::new(values, vec![analysis.values.len()], dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) +} + +/// Allocates an int64 ndarray from usize values. +fn allocate_usize_array(values: &[usize], shape: Vec, heap: &Heap) -> RunResult { + let data = values.iter().copied().map(usize_to_f64).collect(); + let result = NdArray::new(data, shape, NdArrayDtype::Int64); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) +} + +/// Allocates a namedtuple-style result object for `unique_*` helpers. +fn allocate_namedtuple_result( + type_name: &str, + fields: &[&str], + values: Vec, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let field_names = fields.iter().map(|field| (*field).to_owned().into()).collect(); + let result = NamedTuple::new(type_name.to_owned(), field_names, values); + Ok(Value::Ref(vm.heap.allocate(HeapData::NamedTuple(result))?)) +} + +/// `numpy.concatenate(arrays)` — join a sequence of arrays along the first axis. +fn call_concatenate(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.concatenate", vm.heap)?; + defer_drop!(arg, vm); + + let Value::Ref(list_id) = arg else { + return Err(ExcType::type_error("numpy.concatenate() requires a list of arrays")); + }; + let arr_ids: Vec = { + let HeapData::List(list) = vm.heap.get(*list_id) else { + return Err(ExcType::type_error("numpy.concatenate() requires a list of arrays")); + }; + let mut ids = Vec::new(); + for v in list.as_slice() { + let Value::Ref(id) = v else { + return Err(ExcType::type_error( + "numpy.concatenate() requires all elements to be ndarrays", + )); + }; + ids.push(*id); + } + ids + }; + + let mut total_len: usize = 0; + let mut result_dtype = NdArrayDtype::Int64; + + for arr_id in &arr_ids { + let HeapData::NdArray(arr) = vm.heap.get(*arr_id) else { + return Err(ExcType::type_error( + "numpy.concatenate() requires all elements to be ndarrays", + )); + }; + total_len = total_len.saturating_add(arr.data().len()); + result_dtype = promote_dtype(result_dtype, arr.dtype()); + } + + check_array_alloc_size(total_len, vm.heap.tracker())?; + + let mut combined_data = Vec::with_capacity(total_len); + for arr_id in &arr_ids { + let HeapData::NdArray(arr) = vm.heap.get(*arr_id) else { + unreachable!("already validated above"); + }; + combined_data.extend_from_slice(arr.data()); + } + + let len = combined_data.len(); + let new_arr = NdArray::new(combined_data, vec![len], result_dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(new_arr))?)) +} + +/// `numpy.cumsum(a)` — return the cumulative sum of array elements. +fn call_cumsum(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.cumsum", vm.heap)?; + defer_drop!(arg, vm); + let Value::Ref(heap_id) = arg else { + return Err(ExcType::type_error("numpy.cumsum() requires an ndarray argument")); + }; + let HeapData::NdArray(arr) = vm.heap.get(*heap_id) else { + return Err(ExcType::type_error("numpy.cumsum() requires an ndarray argument")); + }; + + let src = arr.data(); + let dtype = arr.dtype(); + let mut data = Vec::with_capacity(src.len()); + let mut running = 0.0; + for &v in src { + running += v; + data.push(running); + } + let len = data.len(); + let new_arr = NdArray::new(data, vec![len], dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(new_arr))?)) +} + +/// `numpy.dot(a, b)` — dot product of two 1D arrays. +fn call_dot(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, b_val) = args.get_two_args("numpy.dot", vm.heap)?; + defer_drop!(a_val, vm); + + let Value::Ref(a_id) = a_val else { + b_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.dot() requires ndarray arguments")); + }; + let Value::Ref(b_id) = &b_val else { + b_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.dot() requires ndarray arguments")); + }; + let b_id = *b_id; + + let HeapData::NdArray(a_arr) = vm.heap.get(*a_id) else { + b_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.dot() requires ndarray arguments")); + }; + let a_data: Vec = a_arr.data().to_vec(); + let a_dtype = a_arr.dtype(); + + let HeapData::NdArray(b_arr) = vm.heap.get(b_id) else { + b_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.dot() requires ndarray arguments")); + }; + + if a_data.len() != b_arr.data().len() { + b_val.drop_with_heap(vm); + return Err(SimpleException::new_msg(ExcType::ValueError, "shapes are not aligned for dot product").into()); + } + + let result: f64 = a_data.iter().zip(b_arr.data().iter()).map(|(&a, &b)| a * b).sum(); + let b_dtype = b_arr.dtype(); + b_val.drop_with_heap(vm); + + let value = if a_dtype == NdArrayDtype::Int64 && b_dtype == NdArrayDtype::Int64 { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is intended for int dot product" + )] + Value::Int(result as i64) + } else { + Value::Float(result) + }; + Ok(value) +} + +/// `numpy.matmul(a, b)` — matrix multiplication (like `a @ b`). +/// +/// Supports 1D-1D (dot product), 2D-2D (matrix multiply), 2D-1D and 1D-2D products. +fn call_matmul(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, b_val) = args.get_two_args("numpy.matmul", vm.heap)?; + defer_drop!(a_val, vm); + + let Value::Ref(a_id) = a_val else { + b_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.matmul() requires ndarray arguments")); + }; + let Value::Ref(b_id) = &b_val else { + b_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.matmul() requires ndarray arguments")); + }; + let b_id = *b_id; + + let HeapData::NdArray(a_arr) = vm.heap.get(*a_id) else { + b_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.matmul() requires ndarray arguments")); + }; + let HeapData::NdArray(b_arr) = vm.heap.get(b_id) else { + b_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.matmul() requires ndarray arguments")); + }; + + let result = a_arr.matmul(b_arr, vm.heap); + b_val.drop_with_heap(vm); + result +} + +// =========================== +// Element-wise math, array creation, testing, aggregation, manipulation, +// search, and utility functions +// =========================== + +/// `numpy.power(a, b)` — element-wise power (like `a ** b`). +/// +/// Supports scalar, list, and ndarray inputs with NumPy-style broadcasting. +fn call_power(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + call_numeric_binop(vm, args, f64::powf, "numpy.power", BinopResult::Promoted) +} + +/// `numpy.diff(a)` — first-order discrete difference: `a[1:] - a[:-1]`. +fn call_diff(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.diff", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.diff", vm)?; + if arr.len() <= 1 { + let result = NdArray::new(Vec::new(), vec![0], arr.dtype()); + return Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + let data: Vec = arr.data().windows(2).map(|w| w[1] - w[0]).collect(); + let len = data.len(); + let arr = NdArray::new(data, vec![len], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.ediff1d(a)` — flattened first-order difference. +fn call_ediff1d(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.ediff1d", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.ediff1d", vm)?; + if arr.len() <= 1 { + let result = NdArray::new(Vec::new(), vec![0], arr.dtype()); + return Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + let data: Vec = arr.data().windows(2).map(|w| w[1] - w[0]).collect(); + let len = data.len(); + let arr = NdArray::new(data, vec![len], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.full(shape, fill_value)` — create an array filled with a constant. +fn call_full(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (shape_val, fill_val) = args.get_two_args("numpy.full", vm.heap)?; + defer_drop!(shape_val, vm); + let shape = extract_shape(shape_val.clone_immediate(), "numpy.full", vm)?; + let (fill, dtype) = match fill_val { + Value::Int(n) => (n as f64, NdArrayDtype::Int64), + Value::Float(f) => (f, NdArrayDtype::Float64), + Value::Bool(b) => (if b { 1.0 } else { 0.0 }, NdArrayDtype::Bool), + other => { + other.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.full() fill_value must be numeric")); + } + }; + let total: usize = shape.iter().product(); + check_array_alloc_size(total, vm.heap.tracker())?; + let arr = NdArray::new(vec![fill; total], shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.eye(n)` — create an n×n identity matrix (Float64). +fn call_eye(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.eye", vm.heap)?; + let n = extract_size(arg, "numpy.eye", vm)?; + check_array_alloc_size(n * n, vm.heap.tracker())?; + let mut data = vec![0.0; n * n]; + for i in 0..n { + data[i * n + i] = 1.0; + } + let arr = NdArray::new(data, vec![n, n], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.copy(a)` — return a copy of the array, also accepts plain lists. +fn call_copy(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.copy", vm.heap)?; + defer_drop!(arg, vm); + let Value::Ref(heap_id) = arg else { + return Err(ExcType::type_error("numpy.copy() requires an array or list")); + }; + let result = match vm.heap.get(*heap_id) { + HeapData::NdArray(arr) => NdArray::new(arr.data().to_vec(), arr.shape().to_vec(), arr.dtype()), + HeapData::List(_) => { + // Use ndarray_from_list which handles proper dtype tracking + ndarray_from_list(arg, vm.heap)? + } + _ => return Err(ExcType::type_error("numpy.copy() requires an array or list")), + }; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.empty(shape)` — create an uninitialized array (returns zeros in Monty). +fn call_empty(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.empty", vm.heap)?; + let shape = extract_shape(arg, "numpy.empty", vm)?; + let total: usize = shape.iter().product(); + check_array_alloc_size(total, vm.heap.tracker())?; + let arr = NdArray::new(vec![0.0; total], shape, NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// Helper for `numpy.zeros_like(a)` and `numpy.ones_like(a)`. +fn call_like(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues, fill: f64, name: &str) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, name, vm)?; + let total = arr.len(); + let result = NdArray::new(vec![fill; total], arr.shape().to_vec(), arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Helper for element-wise boolean test functions like `numpy.isnan`, `numpy.isinf`, etc. +/// +/// Applies the predicate to each element and returns a Bool dtype array. +fn call_bool_test( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + pred: fn(f64) -> bool, + name: &str, +) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, name, vm)?; + let data: Vec = arr.data().iter().map(|&v| if pred(v) { 1.0 } else { 0.0 }).collect(); + let result = NdArray::new(data, arr.shape().to_vec(), NdArrayDtype::Bool); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Predicate for positive infinity. +fn f64_is_pos_inf(value: f64) -> bool { + value.is_infinite() && value.is_sign_positive() +} + +/// Predicate for negative infinity. +fn f64_is_neg_inf(value: f64) -> bool { + value.is_infinite() && value.is_sign_negative() +} + +/// `numpy.array_equal(a, b)` — true if two arrays have same shape and elements. +/// +/// Uses direct f64 equality, so `NaN != NaN` — matching NumPy's behavior. +fn call_array_equal(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, b_val) = args.get_two_args("numpy.array_equal", vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + + let a_arr = ndarray_from_value(a_val, "numpy.array_equal", vm)?; + let b_arr = ndarray_from_value(b_val, "numpy.array_equal", vm)?; + + let equal = a_arr.shape() == b_arr.shape() && a_arr.data() == b_arr.data(); + Ok(Value::Bool(equal)) +} + +/// `numpy.array_equiv(a, b)` — equality with NumPy-style broadcasting. +fn call_array_equiv(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, b_val) = args.get_two_args("numpy.array_equiv", vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + let a = ndarray_or_scalar_from_value(a_val, "numpy.array_equiv", vm)?; + let b = ndarray_or_scalar_from_value(b_val, "numpy.array_equiv", vm)?; + match broadcast_pair_data( + a.data(), + a.shape(), + b.data(), + b.shape(), + "numpy.array_equiv", + vm.heap.tracker(), + ) { + Ok((left, right, _)) => Ok(Value::Bool( + left.iter().zip(right.iter()).all(|(&a, &b)| f64_exact_equal(a, b)), + )), + Err(_) => Ok(Value::Bool(false)), + } +} + +/// Exact float equality with NumPy's NaN-is-not-equal behavior and clippy-friendly spelling. +fn f64_exact_equal(a: f64, b: f64) -> bool { + a.partial_cmp(&b) == Some(Ordering::Equal) +} + +/// `numpy.count_nonzero(a)` — count non-zero elements. +fn call_count_nonzero(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.count_nonzero", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.count_nonzero", vm)?; + #[expect(clippy::cast_possible_wrap, reason = "count won't exceed i64::MAX")] + let count = arr.data().iter().filter(|&&v| v != 0.0).count() as i64; + Ok(Value::Int(count)) +} + +/// `numpy.all(a)` — true if all elements are truthy (module-level wrapper). +fn call_all(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.all", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.all", vm)?; + Ok(Value::Bool(arr.all())) +} + +/// `numpy.any(a)` — true if any element is truthy (module-level wrapper). +fn call_any(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.any", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.any", vm)?; + Ok(Value::Bool(arr.any())) +} + +/// `numpy.prod(a)` — product of array elements. +fn call_prod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.prod", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.prod", vm)?; + let product = arr.prod(); + match arr.dtype() { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is intended for int prod" + )] + NdArrayDtype::Int64 => Ok(Value::Int(product as i64)), + NdArrayDtype::Float64 | NdArrayDtype::Bool => Ok(Value::Float(product)), + } +} + +/// `numpy.median(a)` — median of array elements. +fn call_median(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.median", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.median", vm)?; + if arr.len() == 0 { + return Err(SimpleException::new_msg(ExcType::ValueError, "zero-size array has no median").into()); + } + let mut sorted = arr.data().to_vec(); + sorted.sort_by(nan_last_cmp); + let mid = sorted.len() / 2; + let median = if sorted.len() % 2 == 0 { + f64::midpoint(sorted[mid - 1], sorted[mid]) + } else { + sorted[mid] + }; + Ok(Value::Float(median)) +} + +/// `numpy.argmin(a)` — index of minimum element (module-level wrapper). +fn call_argmin_mod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.argmin", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.argmin", vm)?; + #[expect(clippy::cast_possible_wrap, reason = "array index won't exceed i64::MAX")] + Ok(Value::Int(arr.argmin()? as i64)) +} + +/// `numpy.argmax(a)` — index of maximum element (module-level wrapper). +fn call_argmax_mod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.argmax", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.argmax", vm)?; + #[expect(clippy::cast_possible_wrap, reason = "array index won't exceed i64::MAX")] + Ok(Value::Int(arr.argmax()? as i64)) +} + +/// `numpy.reshape(a, shape)` — reshape an array (module-level wrapper). +fn call_reshape_mod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.reshape", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.reshape() requires 2 arguments"))?; + defer_drop!(arr_val, vm); + let shape_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.reshape() requires 2 arguments"))?; + defer_drop!(shape_val, vm); + + for extra in pos { + extra.drop_with_heap(vm); + } + + let shape = extract_shape_from_value(shape_val, "numpy.reshape", vm)?; + + let arr = ndarray_from_value(arr_val, "numpy.reshape", vm)?; + arr.reshape(shape, vm.heap) +} + +/// `numpy.resize(a, new_shape)` — repeat flattened input data into a new shape. +fn call_resize(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.resize", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.resize", 2, 0))?; + defer_drop!(arr_val, vm); + let shape_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.resize", 2, 1))?; + defer_drop!(shape_val, vm); + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.resize", 2, 3)); + } + + let arr = ndarray_from_value(arr_val, "numpy.resize", vm)?; + let shape = extract_shape_from_value(shape_val, "numpy.resize", vm)?; + let total = shape.iter().product::(); + check_array_alloc_size(total, vm.heap.tracker())?; + let data = if total == 0 { + Vec::new() + } else if arr.data().is_empty() { + return Err(SimpleException::new_msg(ExcType::ValueError, "cannot resize an empty array").into()); + } else { + (0..total) + .map(|index| arr.data()[index % arr.len()]) + .collect::>() + }; + let result = NdArray::new(data, shape, arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.transpose(a, axes=None)` — transpose an array (module-level wrapper). +fn call_transpose_mod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + call_permute_dims_named(vm, args, "numpy.transpose") +} + +/// `numpy.take(a, indices)` — gather flattened elements at integer indices. +/// +/// Monty supports the default flattened mode. The optional `axis`, `out`, and +/// `mode` arguments are outside the current ndarray subset and must be omitted +/// or passed as `None` for `axis`. +fn call_take_mod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.take", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.take", 2, 0))?; + defer_drop!(arr_val, vm); + let indices_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.take", 2, 1))?; + defer_drop!(indices_val, vm); + + if let Some(axis_val) = pos.next() { + defer_drop!(axis_val, vm); + if !matches!(axis_val, Value::None) { + return Err(ExcType::type_error("numpy.take() axis is not supported yet")); + } + } + for extra in pos { + extra.drop_with_heap(vm); + } + + let arr = ndarray_from_value(arr_val, "numpy.take", vm)?; + if let Value::Int(index) = indices_val { + let resolved = resolve_flat_index(*index, arr.len())?; + Ok(ndarray_element_to_value(&arr, arr.data()[resolved])) + } else { + let indices = ndarray_from_value(indices_val, "numpy.take", vm)?; + take_flat_indices(&arr, &indices, vm.heap) + } +} + +/// `numpy.take_along_axis(a, indices, axis)` — gather per-axis positions from an array. +fn call_take_along_axis(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (pos, kwargs) = args.into_parts(); + defer_drop_mut!(pos, vm); + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.take_along_axis", 3, 0))?; + defer_drop!(arr_val, vm); + let indices_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.take_along_axis", 3, 1))?; + defer_drop!(indices_val, vm); + let axis_value = pos.next(); + defer_drop_mut!(axis_value, vm); + if pos.len() != 0 { + return Err(ExcType::type_error_at_most("numpy.take_along_axis", 3, 3 + pos.len())); + } + parse_axis_keyword(kwargs_iter, axis_value, "numpy.take_along_axis", vm)?; + let Some(axis_value) = axis_value.as_ref() else { + return Err(ExcType::type_error_at_least("numpy.take_along_axis", 3, 2)); + }; + + let arr = ndarray_from_value(arr_val, "numpy.take_along_axis", vm)?; + let indices = ndarray_from_value(indices_val, "numpy.take_along_axis", vm)?; + let axis = normalize_axis( + value_to_i64_arg(axis_value, "numpy.take_along_axis", "axis")?, + arr.ndim(), + "numpy.take_along_axis", + )?; + let result = take_along_axis_array(&arr, &indices, axis, "numpy.take_along_axis")?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.compress(condition, a)` — select flattened elements where condition is true. +/// +/// The optional `axis` and `out` arguments are not modeled yet. Omitting `axis` +/// matches the flattened behavior of NumPy and the existing ndarray method. +fn call_compress_mod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.compress", vm.heap)?; + defer_drop_mut!(pos, vm); + + let condition_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.compress", 2, 0))?; + defer_drop!(condition_val, vm); + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.compress", 2, 1))?; + defer_drop!(arr_val, vm); + + if let Some(axis_val) = pos.next() { + defer_drop!(axis_val, vm); + if !matches!(axis_val, Value::None) { + return Err(ExcType::type_error("numpy.compress() axis is not supported yet")); + } + } + for extra in pos { + extra.drop_with_heap(vm); + } + + let condition = ndarray_from_value(condition_val, "numpy.compress", vm)?; + let arr = ndarray_from_value(arr_val, "numpy.compress", vm)?; + arr.compress(&condition, vm.heap) +} + +/// `numpy.swapaxes(a, axis1, axis2)` — swap two axes of an array. +fn call_swapaxes_mod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.swapaxes", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.swapaxes", 3, 0))?; + defer_drop!(arr_val, vm); + let axis1_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.swapaxes", 3, 1))?; + defer_drop!(axis1_val, vm); + let axis2_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.swapaxes", 3, 2))?; + defer_drop!(axis2_val, vm); + for extra in pos { + extra.drop_with_heap(vm); + } + + let arr = ndarray_from_value(arr_val, "numpy.swapaxes", vm)?; + let axis1 = normalize_axis( + value_to_i64_arg(axis1_val, "numpy.swapaxes", "axis1")?, + arr.ndim(), + "numpy.swapaxes", + )?; + let axis2 = normalize_axis( + value_to_i64_arg(axis2_val, "numpy.swapaxes", "axis2")?, + arr.ndim(), + "numpy.swapaxes", + )?; + let mut axes: Vec = (0..arr.ndim()).collect(); + axes.swap(axis1, axis2); + permute_ndarray_axes(&arr, &axes, vm.heap, "numpy.swapaxes") +} + +/// `numpy.permute_dims(a, axes=None)` — permute ndarray axes. +fn call_permute_dims(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + call_permute_dims_named(vm, args, "numpy.permute_dims") +} + +/// Shared implementation for `transpose` and `permute_dims`. +fn call_permute_dims_named( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + name: &'static str, +) -> RunResult { + let pos = args.into_pos_only(name, vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 1, 0))?; + defer_drop!(arr_val, vm); + let axes_val = pos.next(); + for extra in pos { + extra.drop_with_heap(vm); + } + + let arr = ndarray_from_value(arr_val, name, vm)?; + let axes = if let Some(axes_val) = axes_val { + defer_drop!(axes_val, vm); + axes_permutation_from_value(axes_val, arr.ndim(), name, vm)? + } else { + default_transpose_axes(arr.ndim()) + }; + permute_ndarray_axes(&arr, &axes, vm.heap, name) +} + +/// `numpy.matrix_transpose(a)` — swap the last two axes of an array. +fn call_matrix_transpose(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.matrix_transpose", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.matrix_transpose", vm)?; + let ndim = arr.ndim(); + if ndim < 2 { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("Input array must be at least 2-dimensional, but it is {ndim}"), + ) + .into()) + } else { + let mut axes: Vec = (0..ndim).collect(); + axes.swap(ndim - 2, ndim - 1); + permute_ndarray_axes(&arr, &axes, vm.heap, "numpy.matrix_transpose") + } +} + +/// `numpy.moveaxis(a, source, destination)` — move axes to new positions. +fn call_moveaxis(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.moveaxis", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.moveaxis", 3, 0))?; + defer_drop!(arr_val, vm); + let source_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.moveaxis", 3, 1))?; + defer_drop!(source_val, vm); + let destination_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.moveaxis", 3, 2))?; + defer_drop!(destination_val, vm); + for extra in pos { + extra.drop_with_heap(vm); + } + + let arr = ndarray_from_value(arr_val, "numpy.moveaxis", vm)?; + let source = axis_list_from_value(source_val, arr.ndim(), "numpy.moveaxis", "source", vm)?; + let destination = axis_list_from_value(destination_val, arr.ndim(), "numpy.moveaxis", "destination", vm)?; + let axes = moveaxis_permutation(arr.ndim(), &source, &destination)?; + permute_ndarray_axes(&arr, &axes, vm.heap, "numpy.moveaxis") +} + +/// `numpy.rollaxis(a, axis, start=0)` — roll an axis backward to a target position. +fn call_rollaxis(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.rollaxis", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.rollaxis", 2, 0))?; + defer_drop!(arr_val, vm); + let axis_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.rollaxis", 2, 1))?; + defer_drop!(axis_val, vm); + let start_val = pos.next(); + for extra in pos { + extra.drop_with_heap(vm); + } + + let arr = ndarray_from_value(arr_val, "numpy.rollaxis", vm)?; + let axis = normalize_axis( + value_to_i64_arg(axis_val, "numpy.rollaxis", "axis")?, + arr.ndim(), + "numpy.rollaxis", + )?; + let start = if let Some(start_val) = start_val { + defer_drop!(start_val, vm); + normalize_rollaxis_start(value_to_i64_arg(start_val, "numpy.rollaxis", "start")?, arr.ndim())? + } else { + 0 + }; + let axes = rollaxis_permutation(arr.ndim(), axis, start); + permute_ndarray_axes(&arr, &axes, vm.heap, "numpy.rollaxis") +} + +/// `numpy.rot90(a, k=1)` — rotate a 2-D array by 90-degree increments. +fn call_rot90(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.rot90", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.rot90", 1, 0))?; + defer_drop!(arr_val, vm); + let k_val = pos.next(); + let axes_val = pos.next(); + for extra in pos { + extra.drop_with_heap(vm); + } + + let arr = ndarray_from_value(arr_val, "numpy.rot90", vm)?; + let k = if let Some(k_val) = k_val { + defer_drop!(k_val, vm); + value_to_i64_arg(k_val, "numpy.rot90", "k")? + } else { + 1 + }; + let axes = if let Some(axes_val) = axes_val { + defer_drop!(axes_val, vm); + axis_pair_from_value(axes_val, arr.ndim(), "numpy.rot90", "axes", vm)? + } else { + default_axis_pair(arr.ndim(), "numpy.rot90")? + }; + rot90_ndarray(&arr, k, axes, vm.heap) +} + +/// `numpy.choose(a, choices)` — choose values from a sequence by integer index array. +fn call_choose(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.choose", vm.heap)?; + defer_drop_mut!(pos, vm); + + let index_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.choose", 2, 0))?; + defer_drop!(index_val, vm); + let choices_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.choose", 2, 1))?; + defer_drop!(choices_val, vm); + for extra in pos { + extra.drop_with_heap(vm); + } + + let indices = ndarray_from_value(index_val, "numpy.choose", vm)?; + let choice_items = sequence_items(choices_val, "numpy.choose", vm)?; + defer_drop!(choice_items, vm); + let result = choose_from_arrays(&indices, choice_items, "numpy.choose", vm)?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Choice buffer used by `numpy.choose`. +struct ChoiceData { + /// Flat scalar or array values for one choice branch. + values: Vec, + /// Compact dtype for the choice branch. + dtype: NdArrayDtype, +} + +/// Builds the output array for `numpy.choose` from validated choice branches. +fn choose_from_arrays( + indices: &NdArray, + choice_items: &[Value], + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + if choice_items.is_empty() { + return Err(SimpleException::new_msg(ExcType::ValueError, "invalid entry in choice array").into()); + } + + let mut choices = Vec::with_capacity(choice_items.len()); + let mut dtype = NdArrayDtype::Bool; + for choice in choice_items { + let choice_data = choice_data_from_value(choice, indices.len(), name, vm)?; + dtype = if choices.is_empty() { + choice_data.dtype + } else { + promote_dtype(dtype, choice_data.dtype) + }; + choices.push(choice_data); + } + + let mut data = Vec::with_capacity(indices.len()); + for (offset, raw_index) in indices.data().iter().copied().enumerate() { + let choice_index = choice_index_from_f64(raw_index, choices.len())?; + data.push(broadcast_value_at(&choices[choice_index].values, offset)); + } + Ok(NdArray::new(data, indices.shape().to_vec(), dtype)) +} + +/// Converts one `choose` branch into a scalar or index-shaped value buffer. +fn choice_data_from_value( + value: &Value, + output_len: usize, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + if let Ok((scalar, dtype)) = numeric_scalar_info(value, name, vm) { + Ok(ChoiceData { + values: vec![scalar], + dtype, + }) + } else { + let arr = ndarray_from_value(value, name, vm)?; + validate_broadcast_values(arr.data(), output_len, name)?; + Ok(ChoiceData { + values: arr.data().to_vec(), + dtype: arr.dtype(), + }) + } +} + +/// Converts a numeric `choose` selector into a branch index. +fn choice_index_from_f64(value: f64, choice_count: usize) -> RunResult { + #[expect(clippy::cast_possible_truncation, reason = "choice index from numeric ndarray")] + let index = value as i64; + if index < 0 || usize::try_from(index).map_or(true, |index| index >= choice_count) { + Err(SimpleException::new_msg(ExcType::ValueError, "invalid entry in choice array").into()) + } else { + usize::try_from(index) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, "choice index is too large").into()) + } +} + +/// Returns the default transpose permutation, which reverses axis order. +fn default_transpose_axes(ndim: usize) -> Vec { + (0..ndim).rev().collect() +} + +/// Parses a full axis permutation for `transpose`-style calls. +fn axes_permutation_from_value( + value: &Value, + ndim: usize, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult> { + if matches!(value, Value::None) { + Ok(default_transpose_axes(ndim)) + } else { + let axes = axis_sequence_from_value(value, ndim, name, "axes", vm)?; + ensure_axes_are_permutation(&axes, ndim, name)?; + Ok(axes) + } +} + +/// Parses a list or tuple of axes without accepting scalar shorthand. +fn axis_sequence_from_value( + value: &Value, + ndim: usize, + name: &str, + arg_name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult> { + match value { + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(list) => axis_sequence_from_items(list.as_slice(), ndim, name, arg_name), + HeapData::Tuple(tuple) => axis_sequence_from_items(tuple.as_slice(), ndim, name, arg_name), + _ => Err(ExcType::type_error(format!( + "{name}() {arg_name} must be a tuple or list of integers" + ))), + }, + _ => Err(ExcType::type_error(format!( + "{name}() {arg_name} must be a tuple or list of integers" + ))), + } +} + +/// Parses either a scalar axis or a list/tuple of axes. +fn axis_list_from_value( + value: &Value, + ndim: usize, + name: &str, + arg_name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult> { + match value { + Value::Int(axis) => Ok(vec![normalize_axis(*axis, ndim, name)?]), + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(list) => { + let axes = axis_sequence_from_items(list.as_slice(), ndim, name, arg_name)?; + ensure_unique_axes(&axes, name)?; + Ok(axes) + } + HeapData::Tuple(tuple) => { + let axes = axis_sequence_from_items(tuple.as_slice(), ndim, name, arg_name)?; + ensure_unique_axes(&axes, name)?; + Ok(axes) + } + _ => Err(ExcType::type_error(format!( + "{name}() {arg_name} must be an integer or tuple of integers" + ))), + }, + _ => Err(ExcType::type_error(format!( + "{name}() {arg_name} must be an integer or tuple of integers" + ))), + } +} + +/// Converts a sequence of axis values into normalized axis indices. +fn axis_sequence_from_items(items: &[Value], ndim: usize, name: &str, arg_name: &str) -> RunResult> { + items + .iter() + .map(|item| value_to_i64_arg(item, name, arg_name).and_then(|axis| normalize_axis(axis, ndim, name))) + .collect() +} + +/// Parses the two-axis tuple used by `rot90`. +fn axis_pair_from_value( + value: &Value, + ndim: usize, + name: &str, + arg_name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult<[usize; 2]> { + let axes = axis_sequence_from_value(value, ndim, name, arg_name, vm)?; + if axes.len() != 2 { + Err(ExcType::type_error(format!( + "{name}() {arg_name} must contain exactly two axes" + ))) + } else if axes[0] == axes[1] { + Err(SimpleException::new_msg(ExcType::ValueError, "Axes must be different.").into()) + } else { + Ok([axes[0], axes[1]]) + } +} + +/// Returns the default `rot90` axes, validating that the array is at least 2-D. +fn default_axis_pair(ndim: usize, name: &str) -> RunResult<[usize; 2]> { + if ndim < 2 { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("{name}() requires an array of at least two dimensions"), + ) + .into()) + } else { + Ok([0, 1]) + } +} + +/// Normalizes a possibly negative axis into a valid dimension index. +fn normalize_axis(axis: i64, ndim: usize, name: &str) -> RunResult { + let ndim_i64 = i64::try_from(ndim) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, format!("{name}() ndim is too large")))?; + let normalized = if axis < 0 { axis + ndim_i64 } else { axis }; + if normalized < 0 || normalized >= ndim_i64 { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("bad axis for array with {ndim} dimensions"), + ) + .into()) + } else { + usize::try_from(normalized) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, format!("{name}() axis is too large")).into()) + } +} + +/// Normalizes `rollaxis(start)`, whose insertion point may be equal to `ndim`. +fn normalize_rollaxis_start(start: i64, ndim: usize) -> RunResult { + let ndim_i64 = i64::try_from(ndim) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, "numpy.rollaxis() ndim is too large"))?; + let normalized = if start < 0 { start + ndim_i64 } else { start }; + if normalized < 0 || normalized > ndim_i64 { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("bad axis for array with {ndim} dimensions"), + ) + .into()) + } else { + usize::try_from(normalized) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, "numpy.rollaxis() start is too large").into()) + } +} + +/// Validates that an axis list has no duplicates. +fn ensure_unique_axes(axes: &[usize], name: &str) -> RunResult<()> { + for (index, axis) in axes.iter().enumerate() { + if axes[..index].contains(axis) { + return Err(SimpleException::new_msg(ExcType::ValueError, format!("{name}() repeated axis")).into()); + } + } + Ok(()) +} + +/// Validates that an axis list contains each axis exactly once. +fn ensure_axes_are_permutation(axes: &[usize], ndim: usize, name: &str) -> RunResult<()> { + if axes.len() == ndim { + ensure_unique_axes(axes, name) + } else { + Err(SimpleException::new_msg(ExcType::ValueError, format!("{name}() axes don't match array")).into()) + } +} + +/// Builds the axis order used by `moveaxis`. +fn moveaxis_permutation(ndim: usize, source: &[usize], destination: &[usize]) -> RunResult> { + if source.len() == destination.len() { + let mut axes: Vec = (0..ndim).filter(|axis| !source.contains(axis)).collect(); + let mut moves: Vec<(usize, usize)> = destination.iter().copied().zip(source.iter().copied()).collect(); + moves.sort_by_key(|(dest, _)| *dest); + for (dest, src) in moves { + axes.insert(dest, src); + } + Ok(axes) + } else { + Err(SimpleException::new_msg( + ExcType::ValueError, + "numpy.moveaxis() source and destination arguments must have the same number of elements", + ) + .into()) + } +} + +/// Builds the axis order used by `rollaxis`. +fn rollaxis_permutation(ndim: usize, axis: usize, start: usize) -> Vec { + let mut insert_at = start; + if axis < insert_at { + insert_at -= 1; + } + let mut axes: Vec = (0..ndim).collect(); + axes.remove(axis); + axes.insert(insert_at, axis); + axes +} + +/// Allocates an ndarray with axes permuted according to NumPy row-major order. +fn permute_ndarray_axes( + arr: &NdArray, + axes: &[usize], + heap: &Heap, + name: &str, +) -> RunResult { + ensure_axes_are_permutation(axes, arr.ndim(), name)?; + let new_shape: Vec = axes.iter().map(|&axis| arr.shape()[axis]).collect(); + if axes.iter().copied().eq(0..arr.ndim()) || arr.ndim() <= 1 { + let result = NdArray::new(arr.data().to_vec(), new_shape, arr.dtype()); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) + } else { + let old_strides = row_major_strides(arr.shape()); + let new_strides = row_major_strides(&new_shape); + let mut data = vec![0.0; arr.len()]; + for (old_flat, value) in arr.data().iter().copied().enumerate() { + let old_coords = coords_from_flat_index(old_flat, arr.shape(), &old_strides); + let new_flat = axes + .iter() + .enumerate() + .map(|(new_axis, &old_axis)| old_coords[old_axis] * new_strides[new_axis]) + .sum::(); + data[new_flat] = value; + } + let result = NdArray::new(data, new_shape, arr.dtype()); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) + } +} + +/// Computes row-major strides for a shape. +fn row_major_strides(shape: &[usize]) -> Vec { + let mut strides = vec![1; shape.len()]; + let mut stride = 1usize; + for axis in (0..shape.len()).rev() { + strides[axis] = stride; + stride = stride.saturating_mul(shape[axis]); + } + strides +} + +/// Converts a flat row-major index into coordinate components. +fn coords_from_flat_index(flat: usize, shape: &[usize], strides: &[usize]) -> Vec { + shape + .iter() + .zip(strides.iter()) + .map(|(&dim, &stride)| if dim == 0 { 0 } else { (flat / stride) % dim }) + .collect() +} + +/// Parses an optional `axis` keyword into the shared axis value slot. +fn parse_axis_keyword( + kwargs_iter: &mut impl Iterator, + axis_value: &mut Option, + name: &str, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult<()> { + for (key, value) in kwargs_iter { + defer_drop!(key, vm); + let Some(keyword_name) = key.as_either_str(vm.heap) else { + value.drop_with_heap(vm); + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + let key_str = keyword_name.as_str(vm.interns); + if key_str == "axis" { + if axis_value.is_some() { + value.drop_with_heap(vm); + return Err(ExcType::type_error_duplicate_arg(name, key_str)); + } + *axis_value = Some(value); + } else { + value.drop_with_heap(vm); + return Err(ExcType::type_error_unexpected_keyword(name, key_str)); + } + } + Ok(()) +} + +/// Implements flattened `take` while preserving the shape of the indices array. +fn take_flat_indices(arr: &NdArray, indices: &NdArray, heap: &Heap) -> RunResult { + let mut data = Vec::with_capacity(indices.len()); + for index in indices.data().iter().copied() { + #[expect(clippy::cast_possible_truncation, reason = "index from numeric ndarray")] + let resolved = resolve_flat_index(index as i64, arr.len())?; + data.push(arr.data()[resolved]); + } + let result = NdArray::new(data, indices.shape().to_vec(), arr.dtype()); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) +} + +/// Implements `take_along_axis` by resolving every indexed output coordinate. +fn take_along_axis_array(arr: &NdArray, indices: &NdArray, axis: usize, name: &str) -> RunResult { + let targets = along_axis_flat_indices(arr.shape(), indices, axis, name)?; + let data = targets.into_iter().map(|target| arr.data()[target]).collect::>(); + Ok(NdArray::new(data, indices.shape().to_vec(), arr.dtype())) +} + +/// Resolves every `indices` entry into a flat row-major index for an array shape. +fn along_axis_flat_indices(arr_shape: &[usize], indices: &NdArray, axis: usize, name: &str) -> RunResult> { + validate_along_axis_shapes(arr_shape, indices.shape(), axis, name)?; + let arr_strides = row_major_strides(arr_shape); + let index_strides = row_major_strides(indices.shape()); + let mut targets = Vec::with_capacity(indices.len()); + for (flat, raw_index) in indices.data().iter().copied().enumerate() { + let mut coords = coords_from_flat_index(flat, indices.shape(), &index_strides); + #[expect(clippy::cast_possible_truncation, reason = "axis index from numeric ndarray")] + { + coords[axis] = resolve_flat_index(raw_index as i64, arr_shape[axis])?; + } + let target = coords + .iter() + .zip(arr_strides.iter()) + .map(|(coord, stride)| coord * stride) + .sum::(); + targets.push(target); + } + Ok(targets) +} + +/// Validates the shared-dimensional shape rule used by NumPy's along-axis helpers. +fn validate_along_axis_shapes(arr_shape: &[usize], index_shape: &[usize], axis: usize, name: &str) -> RunResult<()> { + if arr_shape.len() != index_shape.len() { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("{name}() indices and arr must have the same number of dimensions"), + ) + .into()) + } else if arr_shape + .iter() + .zip(index_shape.iter()) + .enumerate() + .any(|(dim, (arr_dim, index_dim))| dim != axis && arr_dim != index_dim) + { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("{name}() shape mismatch outside the indexed axis"), + ) + .into()) + } else { + Ok(()) + } +} + +/// Resolves a possibly negative flattened index. +fn resolve_flat_index(index: i64, len: usize) -> RunResult { + let len_i64 = + i64::try_from(len).map_err(|_| SimpleException::new_msg(ExcType::ValueError, "array is too large"))?; + let resolved = if index < 0 { index + len_i64 } else { index }; + if resolved < 0 || resolved >= len_i64 { + Err(SimpleException::new_msg(ExcType::IndexError, "index out of range").into()) + } else { + usize::try_from(resolved) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, "index is too large").into()) + } +} + +/// Converts a raw ndarray element back to the public scalar Value for its dtype. +fn ndarray_element_to_value(arr: &NdArray, value: f64) -> Value { + match arr.dtype() { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is the intended int conversion" + )] + NdArrayDtype::Int64 => Value::Int(value as i64), + NdArrayDtype::Float64 => Value::Float(value), + NdArrayDtype::Bool => Value::Bool(value != 0.0), + } +} + +/// Rotates a 2-D ndarray by `k` quarter turns across the requested axis pair. +fn rot90_ndarray(arr: &NdArray, k: i64, axes: [usize; 2], heap: &Heap) -> RunResult { + if arr.ndim() != 2 { + Err(SimpleException::new_msg(ExcType::ValueError, "numpy.rot90() only supports 2-D arrays").into()) + } else if axes != [0, 1] && axes != [1, 0] { + Err(SimpleException::new_msg(ExcType::ValueError, "numpy.rot90() only supports axes (0, 1)").into()) + } else { + let adjusted_k = if axes == [1, 0] { -k } else { k }; + let k = adjusted_k.rem_euclid(4); + let rows = arr.shape()[0]; + let cols = arr.shape()[1]; + let (data, shape) = match k { + 0 => (arr.data().to_vec(), arr.shape().to_vec()), + 1 => { + let mut data = Vec::with_capacity(arr.len()); + for col in (0..cols).rev() { + for row in 0..rows { + data.push(arr.data()[row * cols + col]); + } + } + (data, vec![cols, rows]) + } + 2 => { + let mut data = arr.data().to_vec(); + data.reverse(); + (data, arr.shape().to_vec()) + } + _ => { + let mut data = Vec::with_capacity(arr.len()); + for col in 0..cols { + for row in (0..rows).rev() { + data.push(arr.data()[row * cols + col]); + } + } + (data, vec![cols, rows]) + } + }; + let result = NdArray::new(data, shape, arr.dtype()); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) + } +} + +/// `numpy.fill_diagonal(a, val)` — fill an ndarray diagonal in place. +fn call_fill_diagonal(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.fill_diagonal", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.fill_diagonal", 2, 0))?; + defer_drop!(arr_val, vm); + let fill_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.fill_diagonal", 2, 1))?; + defer_drop!(fill_val, vm); + let wrap = if let Some(wrap_val) = pos.next() { + defer_drop!(wrap_val, vm); + bool_scalar_from_value(wrap_val, "numpy.fill_diagonal", "wrap")? + } else { + false + }; + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.fill_diagonal", 3, 4)); + } + + let arr_id = mutable_ndarray_id(arr_val, "numpy.fill_diagonal", vm)?; + let values = mutation_values_from_value(fill_val, "numpy.fill_diagonal", vm)?; + let HeapReadOutput::NdArray(mut arr_read) = vm.heap.read(arr_id) else { + unreachable!() + }; + let arr = arr_read.get_mut(vm.heap); + let indices = fill_diagonal_flat_indices(arr.shape(), wrap)?; + assign_cycled_values(&mut arr.data, &indices, &values)?; + drop(arr_read); + Ok(Value::None) +} + +/// `numpy.put(a, ind, v)` — assign flattened positions in place. +fn call_put(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.put", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.put", 3, 0))?; + defer_drop!(arr_val, vm); + let indices_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.put", 3, 1))?; + defer_drop!(indices_val, vm); + let values_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.put", 3, 2))?; + defer_drop!(values_val, vm); + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.put", 3, 4)); + } + + let arr_id = mutable_ndarray_id(arr_val, "numpy.put", vm)?; + let len = ndarray_len_by_id(arr_id, "numpy.put", vm)?; + let indices = flat_indices_from_value(indices_val, len, "numpy.put", vm)?; + let values = mutation_values_from_value(values_val, "numpy.put", vm)?; + let HeapReadOutput::NdArray(mut arr_read) = vm.heap.read(arr_id) else { + unreachable!() + }; + assign_cycled_values(&mut arr_read.get_mut(vm.heap).data, &indices, &values)?; + drop(arr_read); + Ok(Value::None) +} + +/// `numpy.put_along_axis(a, indices, values, axis)` — assign values along an axis in place. +fn call_put_along_axis(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (pos, kwargs) = args.into_parts(); + defer_drop_mut!(pos, vm); + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.put_along_axis", 4, 0))?; + defer_drop!(arr_val, vm); + let indices_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.put_along_axis", 4, 1))?; + defer_drop!(indices_val, vm); + let values_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.put_along_axis", 4, 2))?; + defer_drop!(values_val, vm); + let axis_value = pos.next(); + defer_drop_mut!(axis_value, vm); + if pos.len() != 0 { + return Err(ExcType::type_error_at_most("numpy.put_along_axis", 4, 4 + pos.len())); + } + parse_axis_keyword(kwargs_iter, axis_value, "numpy.put_along_axis", vm)?; + let Some(axis_value) = axis_value.as_ref() else { + return Err(ExcType::type_error_at_least("numpy.put_along_axis", 4, 3)); + }; + + let arr_id = mutable_ndarray_id(arr_val, "numpy.put_along_axis", vm)?; + let indices = ndarray_from_value(indices_val, "numpy.put_along_axis", vm)?; + let values = mutation_values_from_value(values_val, "numpy.put_along_axis", vm)?; + validate_broadcast_values(&values, indices.len(), "numpy.put_along_axis")?; + let targets = { + let HeapData::NdArray(arr) = vm.heap.get(arr_id) else { + unreachable!() + }; + let axis = normalize_axis( + value_to_i64_arg(axis_value, "numpy.put_along_axis", "axis")?, + arr.ndim(), + "numpy.put_along_axis", + )?; + along_axis_flat_indices(arr.shape(), &indices, axis, "numpy.put_along_axis")? + }; + let HeapReadOutput::NdArray(mut arr_read) = vm.heap.read(arr_id) else { + unreachable!() + }; + let arr = arr_read.get_mut(vm.heap); + for (index, target) in targets.into_iter().enumerate() { + arr.data[target] = broadcast_value_at(&values, index); + } + drop(arr_read); + Ok(Value::None) +} + +/// `numpy.copyto(dst, src, where=True)` — copy values into an ndarray in place. +fn call_copyto(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (pos, kwargs) = args.into_parts(); + defer_drop_mut!(pos, vm); + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + + let dst_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.copyto", 2, 0))?; + defer_drop!(dst_val, vm); + let src_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.copyto", 2, 1))?; + defer_drop!(src_val, vm); + if pos.len() != 0 { + return Err(ExcType::type_error_at_most("numpy.copyto", 2, 3)); + } + + let where_value = None; + defer_drop_mut!(where_value, vm); + for (key, value) in kwargs_iter { + defer_drop!(key, vm); + let Some(keyword_name) = key.as_either_str(vm.heap) else { + value.drop_with_heap(vm); + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + if let Some(StaticStrings::NpWhere) = keyword_name.static_string() { + if where_value.is_some() { + value.drop_with_heap(vm); + return Err(ExcType::type_error_duplicate_arg( + "copyto", + keyword_name.as_str(vm.interns), + )); + } + *where_value = Some(value); + } else { + value.drop_with_heap(vm); + return Err(ExcType::type_error_unexpected_keyword( + "copyto", + keyword_name.as_str(vm.interns), + )); + } + } + + let arr_id = mutable_ndarray_id(dst_val, "numpy.copyto", vm)?; + let len = ndarray_len_by_id(arr_id, "numpy.copyto", vm)?; + let source = mutation_values_from_value(src_val, "numpy.copyto", vm)?; + validate_broadcast_values(&source, len, "numpy.copyto")?; + let where_mask = if let Some(value) = where_value.as_ref() { + Some(bool_mask_from_value(value, len, "numpy.copyto", true, vm)?) + } else { + None + }; + + let HeapReadOutput::NdArray(mut arr_read) = vm.heap.read(arr_id) else { + unreachable!() + }; + let arr = arr_read.get_mut(vm.heap); + for (index, slot) in arr.data.iter_mut().enumerate() { + if where_mask.as_ref().is_none_or(|mask| mask[index]) { + *slot = broadcast_value_at(&source, index); + } + } + drop(arr_read); + Ok(Value::None) +} + +/// `numpy.putmask(a, mask, values)` — assign by flat mask positions in place. +fn call_putmask(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_id, mask, values) = masked_mutation_args(args, "numpy.putmask", false, vm)?; + let HeapReadOutput::NdArray(mut arr_read) = vm.heap.read(arr_id) else { + unreachable!() + }; + let arr = arr_read.get_mut(vm.heap); + for (index, slot) in arr.data.iter_mut().enumerate() { + if mask[index] { + *slot = values[index % values.len()]; + } + } + drop(arr_read); + Ok(Value::None) +} + +/// `numpy.place(a, mask, values)` — place values sequentially where a mask is true. +fn call_place(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_id, mask, values) = masked_mutation_args(args, "numpy.place", false, vm)?; + let HeapReadOutput::NdArray(mut arr_read) = vm.heap.read(arr_id) else { + unreachable!() + }; + let arr = arr_read.get_mut(vm.heap); + let mut value_index = 0usize; + for (index, slot) in arr.data.iter_mut().enumerate() { + if mask[index] { + *slot = values[value_index % values.len()]; + value_index += 1; + } + } + drop(arr_read); + Ok(Value::None) +} + +/// Parses the shared `(a, mask, values)` arguments for masked mutation helpers. +fn masked_mutation_args( + args: ArgValues, + name: &str, + allow_scalar_mask: bool, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult<(HeapId, Vec, Vec)> { + let pos = args.into_pos_only(name, vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 3, 0))?; + defer_drop!(arr_val, vm); + let mask_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 3, 1))?; + defer_drop!(mask_val, vm); + let values_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 3, 2))?; + defer_drop!(values_val, vm); + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most(name, 3, 4)); + } + + let arr_id = mutable_ndarray_id(arr_val, name, vm)?; + let len = ndarray_len_by_id(arr_id, name, vm)?; + let mask = bool_mask_from_value(mask_val, len, name, allow_scalar_mask, vm)?; + let values = mutation_values_from_value(values_val, name, vm)?; + ensure_nonempty_values(&values, name)?; + Ok((arr_id, mask, values)) +} + +/// Returns a mutable ndarray heap id after validating that the target is an ndarray. +fn mutable_ndarray_id(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + match value { + Value::Ref(id) if matches!(vm.heap.get(*id), HeapData::NdArray(_)) => Ok(*id), + _ => Err(ExcType::type_error(format!("{name}() target must be an ndarray"))), + } +} + +/// Returns the flat length for a validated ndarray heap id. +fn ndarray_len_by_id(id: HeapId, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + match vm.heap.get(id) { + HeapData::NdArray(arr) => Ok(arr.len()), + _ => Err(ExcType::type_error(format!("{name}() target must be an ndarray"))), + } +} + +/// Converts one scalar or array-like mutation value into flat f64 storage. +fn mutation_values_from_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult> { + if let Ok((scalar, _)) = numeric_scalar_info(value, name, vm) { + Ok(vec![scalar]) + } else { + let arr = ndarray_from_value(value, name, vm)?; + ensure_nonempty_values(arr.data(), name)?; + Ok(arr.data().to_vec()) + } +} + +/// Rejects empty mutation value arrays, which cannot supply cycled assignments. +fn ensure_nonempty_values(values: &[f64], name: &str) -> RunResult<()> { + if values.is_empty() { + Err(SimpleException::new_msg(ExcType::ValueError, format!("{name}() values must not be empty")).into()) + } else { + Ok(()) + } +} + +/// Converts an integer or array-like value into resolved flattened indices. +fn flat_indices_from_value( + value: &Value, + len: usize, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult> { + if let Value::Int(index) = value { + Ok(vec![resolve_flat_index(*index, len)?]) + } else { + let indices = ndarray_from_value(value, name, vm)?; + indices + .data() + .iter() + .map(|&index| { + #[expect(clippy::cast_possible_truncation, reason = "index from numeric ndarray")] + { + resolve_flat_index(index as i64, len) + } + }) + .collect() + } +} + +/// Converts a bool-like scalar or array-like value into a flat mask. +fn bool_mask_from_value( + value: &Value, + len: usize, + name: &str, + allow_scalar: bool, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult> { + match value { + Value::Bool(value) if allow_scalar => Ok(vec![*value; len]), + Value::Int(value) if allow_scalar => Ok(vec![*value != 0; len]), + Value::Float(value) if allow_scalar => Ok(vec![*value != 0.0; len]), + _ => { + let mask = ndarray_from_value(value, name, vm)?; + if mask.len() == len { + Ok(mask.data().iter().map(|&value| value != 0.0).collect()) + } else { + Err( + SimpleException::new_msg(ExcType::ValueError, format!("{name}() mask must match array size")) + .into(), + ) + } + } + } +} + +/// Converts a bool-like scalar argument. +fn bool_scalar_from_value(value: &Value, name: &str, arg_name: &str) -> RunResult { + match value { + Value::Bool(value) => Ok(*value), + Value::Int(value) => Ok(*value != 0), + Value::Float(value) => Ok(*value != 0.0), + _ => Err(ExcType::type_error(format!("{name}() {arg_name} must be a boolean"))), + } +} + +/// Computes the flat row-major offsets that participate in `fill_diagonal`. +fn fill_diagonal_flat_indices(shape: &[usize], wrap: bool) -> RunResult> { + if shape.len() < 2 { + return Err(SimpleException::new_msg(ExcType::ValueError, "array must be at least 2-d").into()); + } + if shape.len() == 2 { + let rows = shape[0]; + let cols = shape[1]; + if wrap { + let total = rows.saturating_mul(cols); + let step = cols.saturating_add(1); + Ok((0..total).step_by(step.max(1)).collect()) + } else { + Ok((0..rows.min(cols)).map(|index| index * cols + index).collect()) + } + } else if shape.iter().all(|&dim| dim == shape[0]) { + let diagonal_stride = row_major_strides(shape).iter().sum::(); + Ok((0..shape[0]).map(|index| index * diagonal_stride).collect()) + } else { + Err(SimpleException::new_msg(ExcType::ValueError, "All dimensions of input must be of equal length").into()) + } +} + +/// Assigns cycled values into pre-resolved flat positions. +fn assign_cycled_values(target: &mut [f64], indices: &[usize], values: &[f64]) -> RunResult<()> { + ensure_nonempty_values(values, "numpy assignment")?; + for (value_index, &target_index) in indices.iter().enumerate() { + target[target_index] = values[value_index % values.len()]; + } + Ok(()) +} + +/// Validates source values for `copyto` scalar or equal-size broadcasting. +fn validate_broadcast_values(values: &[f64], len: usize, name: &str) -> RunResult<()> { + ensure_nonempty_values(values, name)?; + if values.len() == 1 || values.len() == len { + Ok(()) + } else { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("{name}() source must be scalar or same size"), + ) + .into()) + } +} + +/// Returns the value for a scalar-broadcast or same-size source buffer. +fn broadcast_value_at(values: &[f64], index: usize) -> f64 { + values[if values.len() == 1 { 0 } else { index }] +} + +/// `numpy.append(a, values)` — append values to end of array (flattened). +fn call_append(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, b_val) = args.get_two_args("numpy.append", vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + + let a_arr = ndarray_from_value(a_val, "numpy.append", vm)?; + let b_arr = ndarray_from_value(b_val, "numpy.append", vm)?; + + let mut combined = a_arr.data().to_vec(); + combined.extend_from_slice(b_arr.data()); + let len = combined.len(); + check_array_alloc_size(len, vm.heap.tracker())?; + let result_dtype = promote_dtype(a_arr.dtype(), b_arr.dtype()); + let arr = NdArray::new(combined, vec![len], result_dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.vstack(arrays)` / `numpy.stack(arrays)` — stack 1D arrays as rows of a 2D array. +fn call_vstack(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.vstack", vm.heap)?; + defer_drop!(arg, vm); + + let Value::Ref(list_id) = arg else { + return Err(ExcType::type_error("numpy.vstack() requires a list of arrays")); + }; + let arr_ids: Vec = { + let HeapData::List(list) = vm.heap.get(*list_id) else { + return Err(ExcType::type_error("numpy.vstack() requires a list of arrays")); + }; + list.as_slice() + .iter() + .map(|v| match v { + Value::Ref(id) => Ok(*id), + _ => Err(ExcType::type_error( + "numpy.vstack() requires all elements to be ndarrays", + )), + }) + .collect::>>()? + }; + + if arr_ids.is_empty() { + return Err(SimpleException::new_msg(ExcType::ValueError, "need at least one array to stack").into()); + } + + // Get the column count from the first array. + let HeapData::NdArray(first) = vm.heap.get(arr_ids[0]) else { + return Err(ExcType::type_error("numpy.vstack() requires ndarrays")); + }; + let cols = if first.ndim() == 1 { + first.len() + } else { + first.shape()[1] + }; + let mut result_dtype = first.dtype(); + + let mut combined = Vec::new(); + for &arr_id in &arr_ids { + let HeapData::NdArray(arr) = vm.heap.get(arr_id) else { + return Err(ExcType::type_error("numpy.vstack() requires ndarrays")); + }; + let arr_cols = if arr.ndim() == 1 { arr.len() } else { arr.shape()[1] }; + if arr_cols != cols { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "all input arrays must have the same number of columns", + ) + .into()); + } + combined.extend_from_slice(arr.data()); + result_dtype = promote_dtype(result_dtype, arr.dtype()); + } + + let rows = combined.len() / cols; + check_array_alloc_size(combined.len(), vm.heap.tracker())?; + let arr = NdArray::new(combined, vec![rows, cols], result_dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.hstack(arrays)` — concatenate arrays horizontally. +/// +/// For 1D arrays, hstack is equivalent to concatenate (the LLM-common case). +/// For 2D+ arrays, hstack should concatenate along axis=1 — this is not yet +/// implemented and will incorrectly concatenate along axis=0 instead. +fn call_hstack(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + call_concatenate(vm, args) +} + +/// `numpy.dstack(arrays)` — stack arrays along the third axis. +fn call_dstack(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.dstack", vm.heap)?; + defer_drop!(arg, vm); + let items = sequence_items(arg, "numpy.dstack", vm)?; + defer_drop_mut!(items, vm); + if items.is_empty() { + return Err(SimpleException::new_msg(ExcType::ValueError, "need at least one array to stack").into()); + } + + let mut arrays = Vec::with_capacity(items.len()); + for item in items.iter() { + let arr = ndarray_from_value(item, "numpy.dstack", vm)?; + arrays.push(dstack_promoted_array(arr)); + } + let result = concatenate_ndarrays_along_axis(&arrays, 2, "numpy.dstack", vm.heap.tracker())?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.block(arrays)` — assemble nested numeric arrays and scalars. +fn call_block(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.block", vm.heap)?; + defer_drop!(arg, vm); + + let analysis = block_layout_analysis(arg, "arrays".to_string(), vm)?; + let target_ndim = analysis.list_depth.max(analysis.max_leaf_ndim); + let result = block_array_from_value(arg, &analysis, target_ndim, 0, vm)?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Shape summary for NumPy's recursive `block()` layout. +struct BlockAnalysis { + /// Number of list levels used to arrange blocks. + list_depth: usize, + /// Largest ndarray rank among leaf values. + max_leaf_ndim: usize, + /// Path to the first leaf, used in mismatch diagnostics. + first_leaf_path: String, +} + +/// Analyzes `block()` list nesting and leaf ranks before allocating output data. +fn block_layout_analysis(value: &Value, path: String, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + match value { + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(items) => block_list_analysis(items.as_slice(), &path, vm), + HeapData::Tuple(_) => Err(block_tuple_type_error(&path)), + HeapData::NdArray(arr) => Ok(BlockAnalysis { + list_depth: 0, + max_leaf_ndim: arr.ndim(), + first_leaf_path: path, + }), + _ => block_scalar_analysis(value, path, vm), + }, + _ => block_scalar_analysis(value, path, vm), + } +} + +/// Analyzes one list level and rejects ragged nesting before assembly starts. +fn block_list_analysis(items: &[Value], path: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + if items.is_empty() { + return Err(SimpleException::new_msg(ExcType::ValueError, format!("List at {path} cannot be empty")).into()); + } + + let first = block_layout_analysis(&items[0], block_child_path(path, 0), vm)?; + let expected_depth = first.list_depth + 1; + let mut max_leaf_ndim = first.max_leaf_ndim; + let first_leaf_path = first.first_leaf_path; + for (index, item) in items.iter().enumerate().skip(1) { + let child = block_layout_analysis(item, block_child_path(path, index), vm)?; + let found_depth = child.list_depth + 1; + if found_depth != expected_depth { + return Err(SimpleException::new_msg( + ExcType::ValueError, + format!( + "List depths are mismatched. First element was at depth {expected_depth}, but there is an element at depth {found_depth} ({})", + child.first_leaf_path + ), + ) + .into()); + } + max_leaf_ndim = max_leaf_ndim.max(child.max_leaf_ndim); + } + + Ok(BlockAnalysis { + list_depth: expected_depth, + max_leaf_ndim, + first_leaf_path, + }) +} + +/// Treats a non-list `block()` leaf as a numeric scalar. +fn block_scalar_analysis(value: &Value, path: String, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + numeric_scalar_info(value, "numpy.block", vm)?; + Ok(BlockAnalysis { + list_depth: 0, + max_leaf_ndim: 0, + first_leaf_path: path, + }) +} + +/// Recursively assembles a `block()` value once its nesting has been validated. +fn block_array_from_value( + value: &Value, + analysis: &BlockAnalysis, + target_ndim: usize, + current_depth: usize, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + match value { + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(items) => block_array_from_items(items.as_slice(), analysis, target_ndim, current_depth, vm), + HeapData::Tuple(_) => Err(block_tuple_type_error("arrays")), + _ => block_leaf_array(value, target_ndim, vm), + }, + _ => block_leaf_array(value, target_ndim, vm), + } +} + +/// Assembles one list level by concatenating child blocks along NumPy's axis for that depth. +fn block_array_from_items( + items: &[Value], + analysis: &BlockAnalysis, + target_ndim: usize, + current_depth: usize, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + let mut arrays = Vec::with_capacity(items.len()); + for item in items { + arrays.push(block_array_from_value( + item, + analysis, + target_ndim, + current_depth + 1, + vm, + )?); + } + let axis = target_ndim - (analysis.list_depth - current_depth); + concatenate_ndarrays_along_axis(&arrays, axis, "numpy.block", vm.heap.tracker()) +} + +/// Converts one `block()` leaf and prepends singleton axes when list depth requires it. +fn block_leaf_array(value: &Value, target_ndim: usize, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + let arr = ndarray_or_scalar_from_value(value, "numpy.block", vm)?; + if arr.ndim() >= target_ndim { + Ok(arr) + } else { + let NdArray { data, shape, dtype, .. } = arr; + let mut promoted_shape = vec![1; target_ndim - shape.len()]; + promoted_shape.extend(shape); + Ok(NdArray::new(data, promoted_shape, dtype)) + } +} + +/// Builds the path NumPy displays in `block()` nesting diagnostics. +fn block_child_path(path: &str, index: usize) -> String { + format!("{path}[{index}]") +} + +/// Builds NumPy's tuple-specific `block()` error for unsupported implicit conversion. +fn block_tuple_type_error(path: &str) -> RunError { + ExcType::type_error(format!( + "{path} is a tuple. Only lists can be used to arrange blocks, and np.block does not allow implicit conversion from tuple to ndarray." + )) +} + +/// `numpy.unstack(a, axis=0)` — split an array into a tuple with one axis removed. +fn call_unstack(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.unstack", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.unstack", 1, 0))?; + defer_drop!(arr_val, vm); + let axis = if let Some(axis_val) = pos.next() { + defer_drop!(axis_val, vm); + value_to_i64_arg(axis_val, "numpy.unstack", "axis")? + } else { + 0 + }; + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.unstack", 2, 3)); + } + + let arr = ndarray_from_value(arr_val, "numpy.unstack", vm)?; + let axis = normalize_axis(axis, arr.ndim(), "numpy.unstack")?; + let result_shape = shape_without_axis(arr.shape(), axis); + let mut values: SmallVec<[Value; 3]> = SmallVec::new(); + for index in 0..arr.shape()[axis] { + let data = slice_ndarray_along_axis(&arr, axis, index, index + 1); + if result_shape.is_empty() { + values.push(scalar_from_f64(data[0], arr.dtype())); + } else { + let result = NdArray::new(data, result_shape.clone(), arr.dtype()); + values.push(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + } + allocate_tuple(values, vm.heap).map_err(Into::into) +} + +/// Reshapes one dstack input according to NumPy's `atleast_3d` promotion rules. +fn dstack_promoted_array(arr: NdArray) -> NdArray { + let NdArray { data, shape, dtype, .. } = arr; + let shape = match shape.as_slice() { + [len] => vec![1, *len, 1], + [rows, cols] => vec![*rows, *cols, 1], + _ => shape, + }; + NdArray::new(data, shape, dtype) +} + +/// Concatenates arrays along one axis, preserving row-major layout. +fn concatenate_ndarrays_along_axis( + arrays: &[NdArray], + axis: usize, + name: &str, + tracker: &impl ResourceTracker, +) -> RunResult { + let first = arrays + .first() + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, format!("{name}() needs at least one array")))?; + let mut output_shape = first.shape().to_vec(); + if axis >= output_shape.len() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + format!("bad axis for array with {} dimensions", first.ndim()), + ) + .into()); + } + + output_shape[axis] = 0; + let mut dtype = first.dtype(); + for arr in arrays { + if arr.ndim() != first.ndim() || !same_shape_except_axis(arr.shape(), first.shape(), axis) { + return Err(SimpleException::new_msg( + ExcType::ValueError, + format!("{name}() input arrays must have matching dimensions except along the concatenation axis"), + ) + .into()); + } + output_shape[axis] = output_shape[axis] + .checked_add(arr.shape()[axis]) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, format!("{name}() dimensions overflow")))?; + dtype = promote_dtype(dtype, arr.dtype()); + } + + let output_len = checked_shape_product(&output_shape, name)?; + check_array_alloc_size(output_len, tracker)?; + let inner = shape_product(&first.shape()[axis + 1..]); + let outer = shape_product(&first.shape()[..axis]); + let mut data = Vec::with_capacity(output_len); + for outer_index in 0..outer { + for arr in arrays { + let axis_len = arr.shape()[axis]; + let start = outer_index * axis_len * inner; + let end = start + axis_len * inner; + data.extend_from_slice(&arr.data()[start..end]); + } + } + + Ok(NdArray::new(data, output_shape, dtype)) +} + +/// Returns true when two shapes are equal outside one concatenation axis. +fn same_shape_except_axis(lhs: &[usize], rhs: &[usize], axis: usize) -> bool { + lhs.len() == rhs.len() + && lhs + .iter() + .zip(rhs.iter()) + .enumerate() + .all(|(index, (&left, &right))| index == axis || left == right) +} + +/// Removes one axis from a shape, preserving the order of the remaining axes. +fn shape_without_axis(shape: &[usize], axis: usize) -> Vec { + shape + .iter() + .enumerate() + .filter_map(|(index, &dim)| (index != axis).then_some(dim)) + .collect() +} + +/// Copies one half-open slice along an axis out of a row-major ndarray. +fn slice_ndarray_along_axis(arr: &NdArray, axis: usize, start_axis: usize, end_axis: usize) -> Vec { + let axis_len = arr.shape()[axis]; + let inner = shape_product(&arr.shape()[axis + 1..]); + let outer = shape_product(&arr.shape()[..axis]); + let chunk_axis_len = end_axis.saturating_sub(start_axis); + let mut data = Vec::with_capacity(outer * chunk_axis_len * inner); + for outer_index in 0..outer { + let block_start = outer_index * axis_len * inner + start_axis * inner; + let block_end = block_start + chunk_axis_len * inner; + data.extend_from_slice(&arr.data()[block_start..block_end]); + } + data +} + +/// `numpy.apply_along_axis(func1d, axis, arr, *args, **kwargs)` over materialized 1-D slices. +fn call_apply_along_axis(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (mut pos, kwargs) = args.into_parts(); + let Some(function) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.apply_along_axis", 3, 0)); + }; + defer_drop!(function, vm); + let Some(axis_val) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.apply_along_axis", 3, 1)); + }; + defer_drop!(axis_val, vm); + let Some(arr_val) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.apply_along_axis", 3, 2)); + }; + defer_drop!(arr_val, vm); + + let extra_args = pos.collect::>(); + defer_drop_mut!(extra_args, vm); + let kwargs_pairs = owned_kwargs_pairs(kwargs, vm)?; + defer_drop_mut!(kwargs_pairs, vm); + + let arr = ndarray_from_value(arr_val, "numpy.apply_along_axis", vm)?; + let axis = normalize_axis( + value_to_i64_arg(axis_val, "numpy.apply_along_axis", "axis")?, + arr.ndim(), + "numpy.apply_along_axis", + )?; + let result = apply_along_axis_array(function, extra_args, kwargs_pairs, &arr, axis, vm)?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Executes `apply_along_axis` after arguments have been normalized. +fn apply_along_axis_array( + function: &Value, + extra_args: &[Value], + kwargs_pairs: &[(Value, Value)], + arr: &NdArray, + axis: usize, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let iteration_shape = shape_without_axis(arr.shape(), axis); + if iteration_shape.contains(&0) { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "Cannot apply_along_axis when any iteration dimensions are 0", + ) + .into()); + } + + let iteration_count = checked_shape_product(&iteration_shape, "numpy.apply_along_axis")?; + let mut output_shape = None::>; + let mut output_dtype = None::; + let mut output_data = Vec::::new(); + + for iteration_flat in 0..iteration_count { + let slice = apply_along_axis_slice(arr, axis, iteration_flat, &iteration_shape); + let slice_value = Value::Ref(vm.heap.allocate(HeapData::NdArray(slice))?); + let result_value = call_user_function( + "numpy.apply_along_axis", + function, + vec![slice_value], + extra_args, + kwargs_pairs, + vm, + )?; + let result = value_to_owned_array_result(result_value, "numpy.apply_along_axis", vm)?; + + let expected_shape = + output_shape.get_or_insert_with(|| apply_along_axis_output_shape(arr.shape(), axis, result.shape())); + if expected_shape.as_slice() != apply_along_axis_output_shape(arr.shape(), axis, result.shape()).as_slice() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "numpy.apply_along_axis() function returned inconsistent shapes", + ) + .into()); + } + let current_dtype = output_dtype.unwrap_or(result.dtype()); + output_dtype = Some(promote_dtype(current_dtype, result.dtype())); + if output_data.is_empty() { + let output_len = checked_shape_product(expected_shape, "numpy.apply_along_axis")?; + check_array_alloc_size(output_len, vm.heap.tracker())?; + output_data.resize(output_len, 0.0); + } + fill_apply_along_axis_output( + &mut output_data, + expected_shape, + axis, + iteration_flat, + &iteration_shape, + &result, + ); + } + + Ok(NdArray::new( + output_data, + output_shape.unwrap_or_default(), + output_dtype.unwrap_or(arr.dtype()), + )) +} + +/// Builds one 1-D ndarray slice for a fixed coordinate outside the selected axis. +fn apply_along_axis_slice(arr: &NdArray, axis: usize, iteration_flat: usize, iteration_shape: &[usize]) -> NdArray { + let base_coords = flat_index_to_coords(iteration_flat, iteration_shape); + let axis_len = arr.shape()[axis]; + let mut data = Vec::with_capacity(axis_len); + for axis_coord in 0..axis_len { + let coords = full_coords_with_axis(axis, axis_coord, &base_coords, arr.ndim()); + let index = coords_to_flat_index(&coords, arr.shape()); + data.push(arr.data()[index]); + } + NdArray::new(data, vec![axis_len], arr.dtype()) +} + +/// Computes the result shape, replacing the iterated axis with the callable's result shape. +fn apply_along_axis_output_shape(input_shape: &[usize], axis: usize, result_shape: &[usize]) -> Vec { + let mut shape = Vec::with_capacity(input_shape.len().saturating_sub(1) + result_shape.len()); + shape.extend_from_slice(&input_shape[..axis]); + shape.extend_from_slice(result_shape); + shape.extend_from_slice(&input_shape[axis + 1..]); + shape +} + +/// Writes one callable result into the final `apply_along_axis` row-major output. +fn fill_apply_along_axis_output( + output_data: &mut [f64], + output_shape: &[usize], + axis: usize, + iteration_flat: usize, + iteration_shape: &[usize], + result: &NdArray, +) { + let base_coords = flat_index_to_coords(iteration_flat, iteration_shape); + let before = &base_coords[..axis]; + let after = &base_coords[axis..]; + for (result_flat, &value) in result.data().iter().enumerate() { + let result_coords = flat_index_to_coords(result_flat, result.shape()); + let mut output_coords = Vec::with_capacity(output_shape.len()); + output_coords.extend_from_slice(before); + output_coords.extend_from_slice(&result_coords); + output_coords.extend_from_slice(after); + let output_index = coords_to_flat_index(&output_coords, output_shape); + output_data[output_index] = value; + } +} + +/// Reconstructs full coordinates from coordinates with one axis removed. +fn full_coords_with_axis(axis: usize, axis_coord: usize, base_coords: &[usize], ndim: usize) -> Vec { + let mut coords = Vec::with_capacity(ndim); + coords.extend_from_slice(&base_coords[..axis]); + coords.push(axis_coord); + coords.extend_from_slice(&base_coords[axis..]); + coords +} + +/// `numpy.apply_over_axes(func, a, axes)` using kept-dimension reductions. +fn call_apply_over_axes(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.apply_over_axes", vm.heap)?; + defer_drop_mut!(pos, vm); + let function = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.apply_over_axes", 3, 0))?; + defer_drop!(function, vm); + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.apply_over_axes", 3, 1))?; + defer_drop!(arr_val, vm); + let axes_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.apply_over_axes", 3, 2))?; + defer_drop!(axes_val, vm); + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + for extra in pos { + extra.drop_with_heap(vm); + } + return Err(ExcType::type_error_at_most("numpy.apply_over_axes", 3, 4)); + } + + let mut current = ndarray_from_value(arr_val, "numpy.apply_over_axes", vm)?; + let axes = apply_over_axes_axis_list(axes_val, current.ndim(), vm)?; + for axis in axes { + current = apply_over_axes_one(function, ¤t, axis, vm)?; + } + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(current))?)) +} + +/// Parses the scalar-or-sequence axes accepted by `apply_over_axes`. +fn apply_over_axes_axis_list(value: &Value, ndim: usize, vm: &VM<'_, impl ResourceTracker>) -> RunResult> { + match value { + Value::Int(axis) => Ok(vec![normalize_axis(*axis, ndim, "numpy.apply_over_axes")?]), + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(items) => axis_sequence_from_items(items.as_slice(), ndim, "numpy.apply_over_axes", "axes"), + HeapData::Tuple(items) => axis_sequence_from_items(items.as_slice(), ndim, "numpy.apply_over_axes", "axes"), + _ => Err(ExcType::type_error( + "numpy.apply_over_axes() axes must be an integer or sequence", + )), + }, + _ => Err(ExcType::type_error( + "numpy.apply_over_axes() axes must be an integer or sequence", + )), + } +} + +/// Applies one kept-axis reduction, using native reductions for known NumPy callables. +fn apply_over_axes_one( + function: &Value, + arr: &NdArray, + axis: usize, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + if let Some(reduction) = axis_reduction_from_callable(function) { + reduce_ndarray_axis(arr, axis, reduction, "numpy.apply_over_axes", vm.heap.tracker()) + } else { + let arr_value = Value::Ref(vm.heap.allocate(HeapData::NdArray(arr.clone()))?); + let result_value = vm.evaluate_function( + "numpy.apply_over_axes", + function, + ArgValues::Two(arr_value, Value::Int(usize_to_i64(axis)?)), + )?; + let result = value_to_owned_array_result(result_value, "numpy.apply_over_axes", vm)?; + apply_over_axes_keep_axis(result, arr.ndim(), axis) + } +} + +/// Reductions that can be performed without routing back through a module call. +#[derive(Debug, Clone, Copy)] +enum AxisReduction { + Sum, + Prod, + Mean, + Min, + Max, + All, + Any, +} + +/// Maps public NumPy reduction functions to the kept-axis reduction core. +fn axis_reduction_from_callable(function: &Value) -> Option { + match function { + Value::ModuleFunction(ModuleFunctions::Numpy(NumpyFunctions::Sum)) => Some(AxisReduction::Sum), + Value::ModuleFunction(ModuleFunctions::Numpy(NumpyFunctions::Prod)) => Some(AxisReduction::Prod), + Value::ModuleFunction(ModuleFunctions::Numpy(NumpyFunctions::Mean)) => Some(AxisReduction::Mean), + Value::ModuleFunction(ModuleFunctions::Numpy(NumpyFunctions::Min)) => Some(AxisReduction::Min), + Value::ModuleFunction(ModuleFunctions::Numpy(NumpyFunctions::Max)) => Some(AxisReduction::Max), + Value::ModuleFunction(ModuleFunctions::Numpy(NumpyFunctions::All)) => Some(AxisReduction::All), + Value::ModuleFunction(ModuleFunctions::Numpy(NumpyFunctions::Any)) => Some(AxisReduction::Any), + _ => None, + } +} + +/// Reduces an ndarray along one axis while retaining that axis with length one. +fn reduce_ndarray_axis( + arr: &NdArray, + axis: usize, + reduction: AxisReduction, + name: &str, + tracker: &impl ResourceTracker, +) -> RunResult { + let mut output_shape = arr.shape().to_vec(); + let axis_len = output_shape[axis]; + output_shape[axis] = 1; + let output_len = checked_shape_product(&output_shape, name)?; + check_array_alloc_size(output_len, tracker)?; + let mut data = Vec::with_capacity(output_len); + for output_flat in 0..output_len { + let mut coords = flat_index_to_coords(output_flat, &output_shape); + data.push(reduce_axis_cell(arr, axis, axis_len, reduction, &mut coords)?); + } + let dtype = match reduction { + AxisReduction::Mean => NdArrayDtype::Float64, + AxisReduction::All | AxisReduction::Any => NdArrayDtype::Bool, + AxisReduction::Sum | AxisReduction::Prod | AxisReduction::Min | AxisReduction::Max => arr.dtype(), + }; + Ok(NdArray::new(data, output_shape, dtype)) +} + +/// Computes one kept-axis reduction cell by walking the removed coordinates. +fn reduce_axis_cell( + arr: &NdArray, + axis: usize, + axis_len: usize, + reduction: AxisReduction, + coords: &mut [usize], +) -> RunResult { + match reduction { + AxisReduction::Sum | AxisReduction::Mean => { + let mut total = 0.0; + for axis_coord in 0..axis_len { + coords[axis] = axis_coord; + total += arr.data()[coords_to_flat_index(coords, arr.shape())]; + } + if matches!(reduction, AxisReduction::Mean) { + Ok(total / axis_len as f64) + } else { + Ok(total) + } + } + AxisReduction::Prod => { + let mut product = 1.0; + for axis_coord in 0..axis_len { + coords[axis] = axis_coord; + product *= arr.data()[coords_to_flat_index(coords, arr.shape())]; + } + Ok(product) + } + AxisReduction::Min | AxisReduction::Max => reduce_axis_min_max(arr, axis, axis_len, reduction, coords), + AxisReduction::All => { + for axis_coord in 0..axis_len { + coords[axis] = axis_coord; + if arr.data()[coords_to_flat_index(coords, arr.shape())] == 0.0 { + return Ok(0.0); + } + } + Ok(1.0) + } + AxisReduction::Any => { + for axis_coord in 0..axis_len { + coords[axis] = axis_coord; + if arr.data()[coords_to_flat_index(coords, arr.shape())] != 0.0 { + return Ok(1.0); + } + } + Ok(0.0) + } + } +} + +/// Handles min/max reductions, which have no identity for empty axes. +fn reduce_axis_min_max( + arr: &NdArray, + axis: usize, + axis_len: usize, + reduction: AxisReduction, + coords: &mut [usize], +) -> RunResult { + if axis_len == 0 { + return Err(SimpleException::new_msg(ExcType::ValueError, "zero-size array to reduction operation").into()); + } + coords[axis] = 0; + let mut best = arr.data()[coords_to_flat_index(coords, arr.shape())]; + for axis_coord in 1..axis_len { + coords[axis] = axis_coord; + let value = arr.data()[coords_to_flat_index(coords, arr.shape())]; + best = if matches!(reduction, AxisReduction::Min) { + best.min(value) + } else { + best.max(value) + }; + } + Ok(best) +} + +/// Ensures a generic callable result preserves the reduced axis as NumPy expects. +fn apply_over_axes_keep_axis(result: NdArray, target_ndim: usize, axis: usize) -> RunResult { + if result.ndim() == target_ndim { + if result.shape()[axis] == 1 { + Ok(result) + } else { + Err(SimpleException::new_msg( + ExcType::ValueError, + "function is not returning an array of the correct shape", + ) + .into()) + } + } else if result.ndim() + 1 == target_ndim { + let mut shape = result.shape().to_vec(); + shape.insert(axis, 1); + Ok(NdArray::new(result.data().to_vec(), shape, result.dtype())) + } else { + Err(SimpleException::new_msg( + ExcType::ValueError, + "function is not returning an array of the correct shape", + ) + .into()) + } +} + +/// `numpy.piecewise(x, condlist, funclist, *args, **kwargs)` for numeric arrays. +fn call_piecewise(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (mut pos, kwargs) = args.into_parts(); + let Some(x_val) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.piecewise", 3, 0)); + }; + defer_drop!(x_val, vm); + let Some(condlist_val) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.piecewise", 3, 1)); + }; + defer_drop!(condlist_val, vm); + let Some(funclist_val) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.piecewise", 3, 2)); + }; + defer_drop!(funclist_val, vm); + let extra_args = pos.collect::>(); + defer_drop_mut!(extra_args, vm); + let kwargs_pairs = owned_kwargs_pairs(kwargs, vm)?; + defer_drop_mut!(kwargs_pairs, vm); + + let x = ndarray_or_scalar_from_value(x_val, "numpy.piecewise", vm)?; + let cond_values = piecewise_values(condlist_val, "numpy.piecewise", vm)?; + defer_drop_mut!(cond_values, vm); + let fun_values = piecewise_values(funclist_val, "numpy.piecewise", vm)?; + defer_drop_mut!(fun_values, vm); + + let result = piecewise_array(&x, cond_values, fun_values, extra_args, kwargs_pairs, vm)?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Extracts list/tuple values for `piecewise`, or wraps a single value as a one-item list. +fn piecewise_values(value: &Value, name: &str, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult> { + match value { + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(_) | HeapData::Tuple(_) => sequence_items(value, name, vm), + _ => Ok(vec![value.clone_with_heap(vm)]), + }, + _ => Ok(vec![value.clone_with_heap(vm)]), + } +} + +/// Replays the extra positional and keyword arguments passed to callable `piecewise` branches. +struct CallableReplay<'a> { + extra_args: &'a [Value], + kwargs_pairs: &'a [(Value, Value)], +} + +/// Applies condition/function pairs to produce a numeric `piecewise` result. +fn piecewise_array( + x: &NdArray, + cond_values: &[Value], + fun_values: &[Value], + extra_args: &[Value], + kwargs_pairs: &[(Value, Value)], + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + if fun_values.len() != cond_values.len() && fun_values.len() != cond_values.len() + 1 { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "with 1 condition(s), either 1 or 2 functions are expected", + ) + .into()); + } + + let callable = CallableReplay { + extra_args, + kwargs_pairs, + }; + let mut data = vec![0.0; x.len()]; + let mut matched = vec![false; x.len()]; + let mut dtype = x.dtype(); + for (cond_value, fun_value) in cond_values.iter().zip(fun_values.iter()) { + let cond = piecewise_condition(cond_value, x.shape(), vm)?; + let selected = cond + .data() + .iter() + .enumerate() + .filter_map(|(index, &condition)| (condition != 0.0).then_some(index)) + .collect::>(); + for &index in &selected { + matched[index] = true; + } + dtype = piecewise_write_selection(&mut data, dtype, x, &selected, fun_value, &callable, vm)?; + } + + if let Some(default_value) = fun_values.get(cond_values.len()) { + let selected = matched + .iter() + .enumerate() + .filter_map(|(index, &is_matched)| (!is_matched).then_some(index)) + .collect::>(); + dtype = piecewise_write_selection(&mut data, dtype, x, &selected, default_value, &callable, vm)?; + } + + Ok(NdArray::new(data, x.shape().to_vec(), dtype)) +} + +/// Converts and broadcasts a `piecewise` condition to the input shape. +fn piecewise_condition(value: &Value, shape: &[usize], vm: &VM<'_, impl ResourceTracker>) -> RunResult { + let condition = ndarray_or_scalar_from_value(value, "numpy.piecewise", vm)?; + let data = broadcast_array_data( + condition.data(), + condition.shape(), + shape, + "numpy.piecewise", + vm.heap.tracker(), + )?; + Ok(NdArray::new(data, shape.to_vec(), NdArrayDtype::Bool)) +} + +/// Writes scalar, array, or callable output into selected positions. +fn piecewise_write_selection( + output: &mut [f64], + current_dtype: NdArrayDtype, + x: &NdArray, + selected: &[usize], + fun_value: &Value, + callable: &CallableReplay<'_>, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + if selected.is_empty() { + Ok(current_dtype) + } else if is_callable_value(fun_value, vm) { + let values = selected.iter().map(|&index| x.data()[index]).collect::>(); + let selected_arr = NdArray::new(values, vec![selected.len()], x.dtype()); + let selected_value = Value::Ref(vm.heap.allocate(HeapData::NdArray(selected_arr))?); + let result_value = call_user_function( + "numpy.piecewise", + fun_value, + vec![selected_value], + callable.extra_args, + callable.kwargs_pairs, + vm, + )?; + let result = value_to_owned_array_result(result_value, "numpy.piecewise", vm)?; + let values = piecewise_result_values(&result, selected.len())?; + for (&index, value) in selected.iter().zip(values.iter()) { + output[index] = *value; + } + Ok(promote_dtype(current_dtype, result.dtype())) + } else { + let choice = ndarray_or_scalar_from_value(fun_value, "numpy.piecewise", vm)?; + let choice_data = broadcast_array_data( + choice.data(), + choice.shape(), + x.shape(), + "numpy.piecewise", + vm.heap.tracker(), + )?; + for &index in selected { + output[index] = choice_data[index]; + } + Ok(promote_dtype(current_dtype, choice.dtype())) + } +} + +/// Normalizes callable results for assignment to selected `piecewise` slots. +fn piecewise_result_values(result: &NdArray, selected_len: usize) -> RunResult> { + if result.shape().is_empty() { + Ok(vec![result.data()[0]; selected_len]) + } else if result.len() == selected_len { + Ok(result.data().to_vec()) + } else { + Err(SimpleException::new_msg( + ExcType::ValueError, + "NumPy boolean array indexing assignment cannot assign input values to selected output values", + ) + .into()) + } +} + +/// `numpy.pad(array, pad_width, mode='constant', **kwargs)` materialized for common safe modes. +fn call_pad(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (mut pos, kwargs) = args.into_parts(); + let Some(array_val) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.pad", 2, 0)); + }; + defer_drop!(array_val, vm); + let Some(pad_width_val) = pos.next() else { + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_least("numpy.pad", 2, 1)); + }; + defer_drop!(pad_width_val, vm); + let mode_pos = pos.next(); + defer_drop_mut!(mode_pos, vm); + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + pos.drop_with_heap(vm); + kwargs.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.pad", 3, 4)); + } + pos.drop_with_heap(vm); + + let arr = ndarray_or_scalar_from_value(array_val, "numpy.pad", vm)?; + let pad_width = parse_pad_width(pad_width_val, arr.ndim(), "numpy.pad", vm)?; + let options = parse_pad_options(mode_pos.as_ref(), kwargs, arr.ndim(), vm)?; + let result = pad_array(&arr, &pad_width, &options, vm.heap.tracker())?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// One axis of pad widths or constant values. +#[derive(Debug, Clone, Copy)] +struct PadPair { + /// Values inserted before the original axis. + before: T, + /// Values inserted after the original axis. + after: T, +} + +/// Supported pure padding modes that fit Monty's owned ndarray model. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum PadMode { + Constant, + Edge, + Reflect, + Symmetric, + Wrap, +} + +/// Parsed `pad()` options. +struct PadOptions { + mode: PadMode, + constant_values: Vec>, + constant_dtype: NdArrayDtype, +} + +/// Parses `pad()` keyword arguments and validates the supported option surface. +fn parse_pad_options( + mode_pos: Option<&Value>, + kwargs: KwargsValues, + ndim: usize, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let mut mode = None::; + if let Some(value) = mode_pos { + mode = Some(parse_pad_mode(value, vm)?); + } + let mut constant_values = None::<(Vec>, NdArrayDtype)>; + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + for (key, value) in kwargs_iter { + let Some(keyword_name) = key.as_either_str(vm.heap) else { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + let key_str = keyword_name.as_str(vm.interns).to_string(); + defer_drop!(key, vm); + let mut value = HeapGuard::new(value, vm); + if key_str == "mode" { + if mode.is_some() { + return Err(ExcType::type_error_multiple_values("numpy.pad", "mode")); + } + let (value, vm) = value.as_parts(); + mode = Some(parse_pad_mode(value, vm)?); + } else if key_str == "constant_values" { + if constant_values.is_some() { + return Err(ExcType::type_error_multiple_values("numpy.pad", "constant_values")); + } + let (value, vm) = value.as_parts(); + constant_values = Some(parse_pad_numeric_pairs(value, ndim, "numpy.pad", vm)?); + } else { + return Err(ExcType::type_error_unexpected_keyword("numpy.pad", &key_str)); + } + } + + let mode = mode.unwrap_or(PadMode::Constant); + let (constant_values, constant_dtype) = constant_values.unwrap_or_else(|| { + ( + vec![ + PadPair { + before: 0.0, + after: 0.0, + }; + ndim + ], + NdArrayDtype::Int64, + ) + }); + Ok(PadOptions { + mode, + constant_values, + constant_dtype, + }) +} + +/// Parses one supported `mode=` value. +fn parse_pad_mode(value: &Value, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + let Some(mode) = value.as_either_str(vm.heap) else { + return Err(ExcType::type_error("numpy.pad() mode must be a string")); + }; + match mode.as_str(vm.interns) { + "constant" => Ok(PadMode::Constant), + "edge" => Ok(PadMode::Edge), + "reflect" => Ok(PadMode::Reflect), + "symmetric" => Ok(PadMode::Symmetric), + "wrap" => Ok(PadMode::Wrap), + other => Err(SimpleException::new_msg(ExcType::ValueError, format!("mode '{other}' is not supported")).into()), + } +} + +/// Parses integer pad widths in NumPy's scalar, pair, or per-axis pair forms. +fn parse_pad_width( + value: &Value, + ndim: usize, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult>> { + match value { + Value::Int(width) => { + let width = i64_to_nonnegative_usize(*width, name, "pad_width")?; + Ok(vec![ + PadPair { + before: width, + after: width, + }; + ndim + ]) + } + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(items) => parse_pad_width_items(items.as_slice(), ndim, name, vm), + HeapData::Tuple(items) => parse_pad_width_items(items.as_slice(), ndim, name, vm), + _ => Err(ExcType::type_error( + "numpy.pad() pad_width must be an integer or sequence", + )), + }, + _ => Err(ExcType::type_error( + "numpy.pad() pad_width must be an integer or sequence", + )), + } +} + +/// Parses numeric constant values in NumPy's scalar, pair, or per-axis pair forms. +fn parse_pad_numeric_pairs( + value: &Value, + ndim: usize, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult<(Vec>, NdArrayDtype)> { + if let Ok((value, dtype)) = numeric_scalar_info(value, name, vm) { + return Ok(( + vec![ + PadPair { + before: value, + after: value, + }; + ndim + ], + dtype, + )); + } + let Value::Ref(heap_id) = value else { + return Err(ExcType::type_error("numpy.pad() constant_values must be numeric")); + }; + match vm.heap.get(*heap_id) { + HeapData::List(items) => parse_pad_numeric_items(items.as_slice(), ndim, name, vm), + HeapData::Tuple(items) => parse_pad_numeric_items(items.as_slice(), ndim, name, vm), + _ => Err(ExcType::type_error("numpy.pad() constant_values must be numeric")), + } +} + +/// Parses sequence pad-width forms. +fn parse_pad_width_items( + items: &[Value], + ndim: usize, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult>> { + if items.len() == 2 && items.iter().all(|item| matches!(item, Value::Int(_))) { + let pair = PadPair { + before: value_to_nonnegative_usize(&items[0], name, "pad_width")?, + after: value_to_nonnegative_usize(&items[1], name, "pad_width")?, + }; + Ok(vec![pair; ndim]) + } else if items.len() == ndim { + items + .iter() + .map(|item| parse_one_pad_width_item(item, name, vm)) + .collect() + } else { + Err(SimpleException::new_msg(ExcType::ValueError, "operands could not be broadcast together").into()) + } +} + +/// Parses one per-axis pad-width item. +fn parse_one_pad_width_item(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult> { + match value { + Value::Int(width) => { + let width = i64_to_nonnegative_usize(*width, name, "pad_width")?; + Ok(PadPair { + before: width, + after: width, + }) + } + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(items) => parse_one_pad_width_pair(items.as_slice(), name), + HeapData::Tuple(items) => parse_one_pad_width_pair(items.as_slice(), name), + _ => Err(ExcType::type_error("numpy.pad() pad_width must contain integers")), + }, + _ => Err(ExcType::type_error("numpy.pad() pad_width must contain integers")), + } +} + +/// Parses one two-item pad-width pair. +fn parse_one_pad_width_pair(items: &[Value], name: &str) -> RunResult> { + match items { + [before, after] => Ok(PadPair { + before: value_to_nonnegative_usize(before, name, "pad_width")?, + after: value_to_nonnegative_usize(after, name, "pad_width")?, + }), + _ => Err(SimpleException::new_msg(ExcType::ValueError, "operands could not be broadcast together").into()), + } +} + +/// Parses sequence constant-value forms and tracks their compact dtype. +fn parse_pad_numeric_items( + items: &[Value], + ndim: usize, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult<(Vec>, NdArrayDtype)> { + if items.len() == 2 && items.iter().all(|item| numeric_scalar_info(item, name, vm).is_ok()) { + let (before, before_dtype) = numeric_scalar_info(&items[0], name, vm)?; + let (after, after_dtype) = numeric_scalar_info(&items[1], name, vm)?; + let dtype = promote_dtype(before_dtype, after_dtype); + Ok((vec![PadPair { before, after }; ndim], dtype)) + } else if items.len() == ndim { + let mut dtype = NdArrayDtype::Int64; + let pairs = items + .iter() + .map(|item| { + let (pair, pair_dtype) = parse_one_pad_numeric_item(item, name, vm)?; + dtype = promote_dtype(dtype, pair_dtype); + Ok(pair) + }) + .collect::>>()?; + Ok((pairs, dtype)) + } else { + Err(SimpleException::new_msg(ExcType::ValueError, "operands could not be broadcast together").into()) + } +} + +/// Parses one per-axis constant-value item. +fn parse_one_pad_numeric_item( + value: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult<(PadPair, NdArrayDtype)> { + if let Ok((value, dtype)) = numeric_scalar_info(value, name, vm) { + Ok(( + PadPair { + before: value, + after: value, + }, + dtype, + )) + } else { + let Value::Ref(heap_id) = value else { + return Err(ExcType::type_error("numpy.pad() constant_values must contain numbers")); + }; + match vm.heap.get(*heap_id) { + HeapData::List(items) => parse_one_pad_numeric_pair(items.as_slice(), name, vm), + HeapData::Tuple(items) => parse_one_pad_numeric_pair(items.as_slice(), name, vm), + _ => Err(ExcType::type_error("numpy.pad() constant_values must contain numbers")), + } + } +} + +/// Parses one two-item constant-value pair. +fn parse_one_pad_numeric_pair( + items: &[Value], + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult<(PadPair, NdArrayDtype)> { + match items { + [before, after] => { + let (before, before_dtype) = numeric_scalar_info(before, name, vm)?; + let (after, after_dtype) = numeric_scalar_info(after, name, vm)?; + Ok((PadPair { before, after }, promote_dtype(before_dtype, after_dtype))) + } + _ => Err(SimpleException::new_msg(ExcType::ValueError, "operands could not be broadcast together").into()), + } +} + +/// Pads an ndarray by materializing the requested output array. +fn pad_array( + arr: &NdArray, + pad_width: &[PadPair], + options: &PadOptions, + tracker: &impl ResourceTracker, +) -> RunResult { + let output_shape = padded_shape(arr.shape(), pad_width)?; + let output_len = checked_shape_product(&output_shape, "numpy.pad")?; + check_array_alloc_size(output_len, tracker)?; + + if options.mode != PadMode::Constant && arr.shape().contains(&0) { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "can't extend empty axis using modes other than 'constant'", + ) + .into()); + } + + let mut data = Vec::with_capacity(output_len); + for output_flat in 0..output_len { + let output_coords = flat_index_to_coords(output_flat, &output_shape); + data.push(pad_value_at(arr, pad_width, options, &output_coords)?); + } + let dtype = if options.mode == PadMode::Constant { + promote_dtype(arr.dtype(), options.constant_dtype) + } else { + arr.dtype() + }; + Ok(NdArray::new(data, output_shape, dtype)) +} + +/// Computes a checked padded shape. +fn padded_shape(shape: &[usize], pad_width: &[PadPair]) -> RunResult> { + shape + .iter() + .zip(pad_width.iter()) + .map(|(&dimension, pad)| { + dimension + .checked_add(pad.before) + .and_then(|value| value.checked_add(pad.after)) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.pad() dimensions overflow").into()) + }) + .collect() +} + +/// Returns one padded value, either from the source array or from a constant edge. +fn pad_value_at( + arr: &NdArray, + pad_width: &[PadPair], + options: &PadOptions, + output_coords: &[usize], +) -> RunResult { + let mut input_coords = Vec::with_capacity(arr.ndim()); + let mut constant = None::; + for (axis, (&coord, (&dimension, pad))) in output_coords + .iter() + .zip(arr.shape().iter().zip(pad_width.iter())) + .enumerate() + { + let raw = usize_to_i64(coord)? - usize_to_i64(pad.before)?; + let dimension_i64 = usize_to_i64(dimension)?; + if options.mode == PadMode::Constant && raw < 0 { + constant = Some(options.constant_values[axis].before); + input_coords.push(0); + } else if options.mode == PadMode::Constant && raw >= dimension_i64 { + constant = Some(options.constant_values[axis].after); + input_coords.push(dimension.saturating_sub(1)); + } else { + input_coords.push(padded_source_index(raw, dimension, options.mode)?); + } + } + Ok(constant.unwrap_or_else(|| arr.data()[coords_to_flat_index(&input_coords, arr.shape())])) +} + +/// Maps one padded coordinate back into the source axis for non-constant modes. +fn padded_source_index(raw: i64, len: usize, mode: PadMode) -> RunResult { + if len <= 1 { + Ok(0) + } else { + let len_i64 = usize_to_i64(len)?; + match mode { + PadMode::Constant | PadMode::Edge => i64_to_pad_index(raw.clamp(0, len_i64 - 1)), + PadMode::Reflect => reflected_index(raw, len, false), + PadMode::Symmetric => reflected_index(raw, len, true), + PadMode::Wrap => i64_to_pad_index(raw.rem_euclid(len_i64)), + } + } +} + +/// Maps an integer coordinate through NumPy-style reflect or symmetric padding. +fn reflected_index(raw: i64, len: usize, symmetric: bool) -> RunResult { + let len = usize_to_i64(len)?; + let period = if symmetric { len * 2 } else { len * 2 - 2 }; + let coord = raw.rem_euclid(period); + if symmetric { + if coord >= len { + i64_to_pad_index(period - coord - 1) + } else { + i64_to_pad_index(coord) + } + } else if coord >= len { + i64_to_pad_index(period - coord) + } else { + i64_to_pad_index(coord) + } +} + +/// Converts an internal pad coordinate into a checked source-array index. +fn i64_to_pad_index(value: i64) -> RunResult { + usize::try_from(value) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, "numpy.pad() index is too large").into()) +} + +/// Collects owned keyword pairs so callable helpers can replay them for each call. +fn owned_kwargs_pairs(kwargs: KwargsValues, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult> { + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + let pairs = Vec::<(Value, Value)>::new(); + let mut pairs_guard = HeapGuard::new(pairs, vm); + let (pairs, vm) = pairs_guard.as_parts_mut(); + for (key, value) in kwargs_iter { + if key.as_either_str(vm.heap).is_none() { + key.drop_with_heap(vm); + value.drop_with_heap(vm); + return Err(ExcType::type_error_kwargs_nonstring_key()); + } + pairs.push((key, value)); + } + Ok(pairs_guard.into_inner()) +} + +/// Calls a user function with a freshly cloned argument and keyword list. +fn call_user_function( + ctx: &'static str, + function: &Value, + mut args: Vec, + extra_args: &[Value], + kwargs_pairs: &[(Value, Value)], + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + for arg in extra_args { + args.push(arg.clone_with_heap(vm)); + } + let kwargs = cloned_kwargs_from_pairs(kwargs_pairs, vm)?; + vm.evaluate_function(ctx, function, args_from_vec_and_kwargs(args, kwargs)) +} + +/// Recreates keyword arguments for one callable invocation. +fn cloned_kwargs_from_pairs( + pairs: &[(Value, Value)], + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let cloned = pairs + .iter() + .map(|(key, value)| (key.clone_with_heap(vm), value.clone_with_heap(vm))) + .collect::>(); + kwargs_from_pairs(cloned, vm) +} + +/// Converts a callable return value into an owned ndarray and drops the original value. +fn value_to_owned_array_result(value: Value, name: &str, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let value = Some(value); + defer_drop_mut!(value, vm); + ndarray_or_scalar_from_value(value.as_ref().expect("call result is present"), name, vm) +} + +/// Returns whether a value can be called by Monty's function-call machinery. +fn is_callable_value(value: &Value, vm: &VM<'_, impl ResourceTracker>) -> bool { + match value { + Value::DefFunction(_) | Value::Builtin(_) | Value::ExtFunction(_) | Value::ModuleFunction(_) => true, + Value::Ref(heap_id) => matches!( + vm.heap.get(*heap_id), + HeapData::Closure(_) | HeapData::FunctionDefaults(_) | HeapData::ExtFunction(_) + ), + _ => false, + } +} + +/// Computes a small shape product for already-validated ndarray dimensions. +fn shape_product(shape: &[usize]) -> usize { + shape.iter().product() +} + +/// `numpy.nonzero(a)` — indices of non-zero elements, returned as a tuple of arrays. +fn call_nonzero(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.nonzero", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.nonzero", vm)?; + + let indices: Vec = arr + .data() + .iter() + .enumerate() + .filter(|&(_, v)| *v != 0.0) + .map(|(i, _)| i as f64) + .collect(); + + let len = indices.len(); + let idx_arr = NdArray::new(indices, vec![len], NdArrayDtype::Int64); + let idx_val = Value::Ref(vm.heap.allocate(HeapData::NdArray(idx_arr))?); + + // NumPy returns a tuple with one array per dimension. For 1D input, it's a 1-tuple. + // Note: if allocate_tuple fails (resource limit), idx_val may be leaked. This is + // acceptable per project convention — resource exhaustion is a terminal error. + let values: SmallVec<[Value; 3]> = smallvec::smallvec![idx_val]; + allocate_tuple(values, vm.heap).map_err(Into::into) +} + +/// `numpy.argwhere(a)` — indices where elements are non-zero, as 2D array. +fn call_argwhere(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.argwhere", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.argwhere", vm)?; + + let indices: Vec = arr + .data() + .iter() + .enumerate() + .filter(|&(_, v)| *v != 0.0) + .map(|(i, _)| i as f64) + .collect(); + + let rows = indices.len(); + // For 1D input, argwhere returns shape (n_nonzero, 1) + let result = NdArray::new(indices, vec![rows, 1], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.tile(a, reps)` — construct array by repeating `a` `reps` times. +fn call_tile(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, reps_val) = args.get_two_args("numpy.tile", vm.heap)?; + defer_drop!(arr_val, vm); + + let arr = ndarray_from_value(arr_val, "numpy.tile", vm)?; + defer_drop!(reps_val, vm); + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "reps checked non-negative" + )] + let reps = if let Value::Int(n) = reps_val { + if *n < 0 { + return Err(SimpleException::new_msg(ExcType::ValueError, "negative number of repetitions").into()); + } + *n as usize + } else { + return Err(ExcType::type_error("numpy.tile() reps must be an integer")); + }; + + if reps == 0 || arr.len() == 0 { + let result = NdArray::new(Vec::new(), vec![0], arr.dtype()); + return Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + + let total = arr.len() * reps; + check_array_alloc_size(total, vm.heap.tracker())?; + let data: Vec = arr.data().iter().copied().cycle().take(total).collect(); + let result = NdArray::new(data, vec![total], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.repeat(a, repeats)` — repeat each element `repeats` times. +fn call_repeat(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, reps_val) = args.get_two_args("numpy.repeat", vm.heap)?; + defer_drop!(arr_val, vm); + + let arr = ndarray_from_value(arr_val, "numpy.repeat", vm)?; + defer_drop!(reps_val, vm); + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "repeats checked non-negative" + )] + let reps = if let Value::Int(n) = reps_val { + if *n < 0 { + return Err(SimpleException::new_msg(ExcType::ValueError, "negative number of repetitions").into()); + } + *n as usize + } else { + return Err(ExcType::type_error("numpy.repeat() repeats must be an integer")); + }; + + if arr.len() == 0 { + let result = NdArray::new(Vec::new(), vec![0], arr.dtype()); + return Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + + let total = arr.len() * reps; + check_array_alloc_size(total, vm.heap.tracker())?; + let mut data = Vec::with_capacity(total); + for &v in arr.data() { + for _ in 0..reps { + data.push(v); + } + } + let result = NdArray::new(data, vec![total], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.split(a, indices_or_sections)` — split array into sub-arrays. +/// +/// If the second argument is an integer, splits into that many equal parts. +/// If it's a list/array, splits at the given indices. +fn call_split(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, idx_val) = args.get_two_args("numpy.split", vm.heap)?; + defer_drop!(arr_val, vm); + + let arr = ndarray_from_value(arr_val, "numpy.split", vm)?; + let data = arr.data(); + let dtype = arr.dtype(); + + // Determine split points + let split_indices: Vec = match &idx_val { + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "sections checked > 0" + )] + Value::Int(n) => { + if *n <= 0 { + idx_val.drop_with_heap(vm); + return Err( + SimpleException::new_msg(ExcType::ValueError, "number sections must be larger than 0").into(), + ); + } + let sections = *n as usize; + if data.len() % sections != 0 { + idx_val.drop_with_heap(vm); + return Err(SimpleException::new_msg( + ExcType::ValueError, + "array split does not result in an equal division", + ) + .into()); + } + let chunk_size = data.len() / sections; + (1..sections).map(|i| i * chunk_size).collect() + } + Value::Ref(id) => match vm.heap.get(*id) { + HeapData::List(list) => list + .as_slice() + .iter() + .map(|v| match v { + #[expect(clippy::cast_sign_loss, clippy::cast_possible_truncation, reason = "index from user")] + Value::Int(n) => Ok(*n as usize), + _ => Err(ExcType::type_error("split indices must be integers")), + }) + .collect::>>()?, + HeapData::NdArray(idx_arr) => + { + #[expect(clippy::cast_sign_loss, clippy::cast_possible_truncation, reason = "index from user")] + idx_arr.data().iter().map(|&v| v as usize).collect() + } + _ => { + idx_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.split() second arg must be int or list")); + } + }, + _ => { + idx_val.drop_with_heap(vm); + return Err(ExcType::type_error("numpy.split() second arg must be int or list")); + } + }; + idx_val.drop_with_heap(vm); + + // Build sub-arrays. Note: if allocation fails partway through, previously allocated + // sub-arrays in `parts` are leaked. This is acceptable — allocation failure is a + // terminal resource-limit error (see CLAUDE.md reference counting docs). + let mut parts = Vec::new(); + let mut prev = 0; + for &idx in &split_indices { + let end = idx.min(data.len()); + let chunk = data[prev..end].to_vec(); + let len = chunk.len(); + parts.push(Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + chunk, + vec![len], + dtype, + )))?)); + prev = end; + } + // Last chunk + let chunk = data[prev..].to_vec(); + let len = chunk.len(); + parts.push(Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + chunk, + vec![len], + dtype, + )))?)); + + let list = List::new(parts); + Ok(Value::Ref(vm.heap.allocate(HeapData::List(list))?)) +} + +/// `numpy.shape(a)` — return the dimensions of an array-like value. +fn call_shape(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.shape", vm.heap)?; + defer_drop!(arg, vm); + let shape = array_like_shape(arg, "numpy.shape", vm)?; + #[expect(clippy::cast_possible_wrap, reason = "shape dimensions won't exceed i64::MAX")] + let values: SmallVec<[Value; 3]> = shape.iter().map(|&d| Value::Int(d as i64)).collect(); + allocate_tuple(values, vm.heap).map_err(Into::into) +} + +/// `numpy.size(a)` — return the total number of elements in an array-like value. +fn call_size(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.size", vm.heap)?; + defer_drop!(arg, vm); + let shape = array_like_shape(arg, "numpy.size", vm)?; + let size = shape.iter().product::(); + #[expect(clippy::cast_possible_wrap, reason = "array sizes are resource-limited")] + Ok(Value::Int(size as i64)) +} + +/// `numpy.ndim(a)` — return the number of dimensions in an array-like value. +fn call_ndim(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.ndim", vm.heap)?; + defer_drop!(arg, vm); + let shape = array_like_shape(arg, "numpy.ndim", vm)?; + #[expect(clippy::cast_possible_wrap, reason = "ndim is always small")] + Ok(Value::Int(shape.len() as i64)) +} + +/// `numpy.broadcast_shapes(*shapes)` — return the common NumPy broadcast shape. +fn call_broadcast_shapes(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.broadcast_shapes", vm.heap)?; + defer_drop_mut!(pos, vm); + + let mut shapes = Vec::new(); + for shape_value in pos.by_ref() { + defer_drop!(shape_value, vm); + shapes.push(extract_shape_from_value(shape_value, "numpy.broadcast_shapes", vm)?); + } + + let shape_refs: Vec<&[usize]> = shapes.iter().map(Vec::as_slice).collect(); + let shape = broadcast_shape(&shape_refs, "numpy.broadcast_shapes").map_err(|error| { + if shapes.len() == 2 { + broadcast_shapes_value_error(&shapes) + } else { + error + } + })?; + shape_to_tuple(&shape, vm.heap) +} + +/// `numpy.broadcast_to(array, shape)` — materialize an array in a broadcast shape. +fn call_broadcast_to(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (array_value, shape_value) = args.get_two_args("numpy.broadcast_to", vm.heap)?; + defer_drop!(array_value, vm); + defer_drop!(shape_value, vm); + + let array = ndarray_or_scalar_from_value(array_value, "numpy.broadcast_to", vm)?; + let shape = extract_shape_from_value(shape_value, "numpy.broadcast_to", vm)?; + let data = broadcast_array_data( + array.data(), + array.shape(), + &shape, + "numpy.broadcast_to", + vm.heap.tracker(), + )?; + let result = NdArray::new(data, shape, array.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.broadcast_arrays(*arrays)` — return arrays materialized in a shared shape. +fn call_broadcast_arrays(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arrays, shape) = broadcast_positional_arrays(vm, args, "numpy.broadcast_arrays")?; + let total = checked_shape_product(&shape, "numpy.broadcast_arrays")?; + check_array_alloc_size(total.saturating_mul(arrays.len()), vm.heap.tracker())?; + + let mut values = Vec::with_capacity(arrays.len()); + for array in arrays { + let data = broadcast_array_data( + array.data(), + array.shape(), + &shape, + "numpy.broadcast_arrays", + vm.heap.tracker(), + )?; + let result = NdArray::new(data, shape.clone(), array.dtype()); + values.push(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + + let values: SmallVec<[Value; 3]> = values.into_iter().collect(); + Ok(allocate_tuple(values, vm.heap)?) +} + +/// `numpy.broadcast(*arrays)` — materialized iterable subset of NumPy's broadcast object. +/// +/// Real NumPy returns a `numpy.broadcast` object with shape metadata and lazy +/// iteration. Monty does not yet have a dedicated broadcast object type, so this +/// sandbox-safe subset returns the same iteration payload as a list of tuples. +fn call_broadcast(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arrays, shape) = broadcast_positional_arrays(vm, args, "numpy.broadcast")?; + let total = checked_shape_product(&shape, "numpy.broadcast")?; + check_array_alloc_size(total.saturating_mul(arrays.len()), vm.heap.tracker())?; + + let mut materialized = Vec::with_capacity(arrays.len()); + for array in &arrays { + materialized.push(broadcast_array_data( + array.data(), + array.shape(), + &shape, + "numpy.broadcast", + vm.heap.tracker(), + )?); + } + + let mut rows = Vec::with_capacity(total); + for index in 0..total { + let mut items = SmallVec::new(); + for (array, data) in arrays.iter().zip(materialized.iter()) { + items.push(scalar_from_f64(data[index], array.dtype())); + } + rows.push(allocate_tuple(items, vm.heap)?); + } + + Ok(Value::Ref(vm.heap.allocate(HeapData::List(List::new(rows)))?)) +} + +/// Converts a shape into a Python tuple of integers. +fn shape_to_tuple(shape: &[usize], heap: &Heap) -> RunResult { + let mut values = SmallVec::new(); + for &dimension in shape { + values.push(Value::Int(usize_to_i64(dimension)?)); + } + Ok(allocate_tuple(values, heap)?) +} + +/// Parses positional array-like inputs and computes their common broadcast shape. +fn broadcast_positional_arrays( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + name: &str, +) -> RunResult<(Vec, Vec)> { + let pos = args.into_pos_only(name, vm.heap)?; + defer_drop_mut!(pos, vm); + + let mut arrays = Vec::new(); + for value in pos.by_ref() { + defer_drop!(value, vm); + arrays.push(ndarray_or_scalar_from_value(value, name, vm)?); + } + + let shape_refs: Vec<&[usize]> = arrays.iter().map(NdArray::shape).collect(); + let shape = broadcast_shape(&shape_refs, name)?; + Ok((arrays, shape)) +} + +/// Builds NumPy's detailed public `broadcast_shapes()` mismatch message. +fn broadcast_shapes_value_error(shapes: &[Vec]) -> RunError { + SimpleException::new_msg( + ExcType::ValueError, + format!( + "shape mismatch: objects cannot be broadcast to a single shape. Mismatch is between arg 0 with shape {} and arg 1 with shape {}.", + format_public_shape(&shapes[0]), + format_public_shape(&shapes[1]) + ), + ) + .into() +} + +/// Formats a shape using Python tuple display spacing. +fn format_public_shape(shape: &[usize]) -> String { + match shape { + [] => "()".to_string(), + [dim] => format!("({dim},)"), + _ => { + let mut formatted = String::from("("); + for (index, dim) in shape.iter().enumerate() { + if index > 0 { + formatted.push_str(", "); + } + formatted.push_str(&dim.to_string()); + } + formatted.push(')'); + formatted + } + } +} + +/// Returns the shape for ndarray/list inputs and the scalar shape for numbers. +fn array_like_shape(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult> { + if let Ok((_, shape, _)) = extract_ndarray_info(value, name, vm) { + Ok(shape) + } else { + numeric_scalar_info(value, name, vm)?; + Ok(Vec::new()) + } +} + +// =========================== +// Utility helpers +// =========================== + +/// Extracts ndarray data from a Value, auto-converting lists. +/// +/// Returns (data, shape, dtype) tuple — copies data out to avoid lifetime issues. +/// Uses `ndarray_from_list` for lists so dtype tracking (int vs float vs bool) is correct. +fn extract_ndarray_info( + value: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult<(Vec, Vec, NdArrayDtype)> { + match value { + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::NdArray(arr) => Ok((arr.data().to_vec(), arr.shape().to_vec(), arr.dtype())), + HeapData::List(_) => { + let tmp = ndarray_from_list(value, vm.heap)?; + Ok((tmp.data().to_vec(), tmp.shape().to_vec(), tmp.dtype())) + } + _ => Err(ExcType::type_error(format!( + "{name}() requires an array or list argument" + ))), + }, + _ => Err(ExcType::type_error(format!( + "{name}() requires an array or list argument" + ))), + } +} + +/// Convenience wrapper that returns an NdArray (owned). +fn ndarray_from_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + let (data, shape, dtype) = extract_ndarray_info(value, name, vm)?; + Ok(NdArray::new(data, shape, dtype)) +} + +/// Converts an ndarray/list input or a numeric scalar into an owned ndarray. +/// +/// NumPy treats scalar inputs as zero-dimensional arrays for iterator-style +/// helpers. Keeping this conversion local avoids broadening `ndarray_from_value`, +/// whose stricter array/list contract is relied on by many existing functions. +fn ndarray_or_scalar_from_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + if let Ok((data, shape, dtype)) = extract_ndarray_info(value, name, vm) { + Ok(NdArray::new(data, shape, dtype)) + } else { + let (scalar, dtype) = numeric_scalar_info(value, name, vm)?; + Ok(NdArray::new(vec![scalar], Vec::new(), dtype)) + } +} + +/// Extracts a shape from a Value — supports int (1D), list, or tuple. +fn extract_shape(value: Value, func_name: &str, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult> { + match &value { + Value::Int(_) => { + let n = extract_size(value, func_name, vm)?; + Ok(vec![n]) + } + Value::Ref(heap_id) => { + let shape = match vm.heap.get(*heap_id) { + HeapData::List(list) => extract_shape_from_items(list.as_slice(), func_name)?, + HeapData::Tuple(tuple) => extract_shape_from_items(tuple.as_slice(), func_name)?, + _ => { + value.drop_with_heap(vm); + return Err(ExcType::type_error(format!( + "{func_name}() requires an integer or tuple of integers" + ))); + } + }; + value.drop_with_heap(vm); + Ok(shape) + } + _ => { + value.drop_with_heap(vm); + Err(ExcType::type_error(format!( + "{func_name}() requires an integer or tuple of integers" + ))) + } + } +} + +/// Extracts shape from a Value without consuming it (for reshape where we borrow). +fn extract_shape_from_value( + value: &Value, + func_name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult> { + match value { + #[expect(clippy::cast_sign_loss, clippy::cast_possible_truncation, reason = "shape from user")] + Value::Int(n) if *n >= 0 => Ok(vec![*n as usize]), + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(list) => extract_shape_from_items(list.as_slice(), func_name), + HeapData::Tuple(tuple) => extract_shape_from_items(tuple.as_slice(), func_name), + _ => Err(ExcType::type_error(format!( + "{func_name}() requires an integer or tuple of integers" + ))), + }, + _ => Err(ExcType::type_error(format!( + "{func_name}() requires an integer or tuple of integers" + ))), + } +} + +/// Extracts a shape vector from a slice of Values (list or tuple items). +fn extract_shape_from_items(items: &[Value], func_name: &str) -> RunResult> { + items + .iter() + .map(|v| match v { + #[expect(clippy::cast_sign_loss, clippy::cast_possible_truncation, reason = "shape from user")] + Value::Int(n) if *n >= 0 => Ok(*n as usize), + _ => Err(ExcType::type_error(format!( + "{func_name}() shape must contain non-negative integers" + ))), + }) + .collect() +} + +/// Extracts an integer size from a Value for array creation functions. +fn extract_size(value: Value, func_name: &str, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + match value { + #[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "n is guaranteed non-negative" + )] + Value::Int(n) if n >= 0 => Ok(n as usize), + Value::Int(_) => Err(SimpleException::new_msg( + ExcType::ValueError, + format!("{func_name}(): negative dimensions are not allowed"), + ) + .into()), + _ => { + value.drop_with_heap(vm); + Err(ExcType::type_error(format!( + "{func_name}() requires an integer argument" + ))) + } + } +} + +/// Converts a Value to f64 for numeric operations. +fn to_f64(value: &Value, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + match value { + Value::Int(n) => Ok(*n as f64), + Value::Float(f) => Ok(*f), + Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }), + _ => Err(ExcType::type_error(format!( + "a number is required, not '{}'", + value.py_type(vm) + ))), + } +} + +/// Converts a Python numeric scalar to the internal f64 value plus NumPy dtype. +/// +/// This is used by scalar-compatible ufunc-style helpers, where real NumPy +/// accepts both arrays and scalars. Non-numeric values still raise the same +/// Monty type error style as the array path. +fn numeric_scalar_info(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult<(f64, NdArrayDtype)> { + match value { + Value::Int(n) => Ok((*n as f64, NdArrayDtype::Int64)), + Value::Float(f) => Ok((*f, NdArrayDtype::Float64)), + Value::Bool(b) => Ok((if *b { 1.0 } else { 0.0 }, NdArrayDtype::Bool)), + _ => Err(ExcType::type_error(format!( + "{name}() requires an array, list, or scalar argument, not '{}'", + value.py_type(vm) + ))), + } +} + +/// Converts an internal f64 result back to the best scalar value for a dtype. +/// +/// Integer and boolean scalar results mirror Monty's existing ndarray display +/// conversion: the f64 backing value is truncated for integer dtypes and +/// non-zero values are truthy for boolean dtypes. +fn scalar_from_f64(value: f64, dtype: NdArrayDtype) -> Value { + match dtype { + #[expect( + clippy::cast_possible_truncation, + reason = "scalar conversion follows ndarray integer element conversion" + )] + NdArrayDtype::Int64 => Value::Int(value as i64), + NdArrayDtype::Float64 => Value::Float(value), + NdArrayDtype::Bool => Value::Bool(value != 0.0), + } +} + +/// `numpy.angle(z, deg=False)` for Monty's real-valued numeric subset. +fn call_angle(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arg, deg_val) = args.get_one_two_args("numpy.angle", vm.heap)?; + defer_drop!(arg, vm); + let deg = if let Some(deg_val) = deg_val { + defer_drop!(deg_val, vm); + value_to_bool_arg(deg_val, "numpy.angle", "deg")? + } else { + false + }; + + if let Ok((data, shape, _)) = extract_ndarray_info(arg, "numpy.angle", vm) { + let data = data.into_iter().map(|value| real_phase_angle(value, deg)).collect(); + let arr = NdArray::new(data, shape, NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } else { + let (value, _) = numeric_scalar_info(arg, "numpy.angle", vm)?; + Ok(Value::Float(real_phase_angle(value, deg))) + } +} + +/// Computes the phase angle of a real number, preserving NumPy's `-0.0 -> pi` behavior. +fn real_phase_angle(value: f64, deg: bool) -> f64 { + let angle = if value.is_sign_negative() { PI } else { 0.0 }; + if deg { angle.to_degrees() } else { angle } +} + +/// `numpy.conj(a)` / `numpy.real(a)` for Monty's real-valued ndarray subset. +/// +/// Monty does not currently model complex numbers, so the conjugate and real +/// component are identical to the input. Lists are converted to ndarrays, while +/// numeric scalars keep their scalar shape and dtype. +fn call_real_identity(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues, name: &str) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + + if let Ok((data, shape, dtype)) = extract_ndarray_info(arg, name, vm) { + let arr = NdArray::new(data, shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } else { + let (value, dtype) = numeric_scalar_info(arg, name, vm)?; + Ok(scalar_from_f64(value, dtype)) + } +} + +/// `numpy.real_if_close(a, tol=100)` — identity for Monty's real-valued subset. +fn call_real_if_close(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arg, tol) = args.get_one_two_args("numpy.real_if_close", vm.heap)?; + defer_drop!(arg, vm); + if let Some(tol) = tol { + tol.drop_with_heap(vm); + } + + if let Ok((data, shape, dtype)) = extract_ndarray_info(arg, "numpy.real_if_close", vm) { + let arr = NdArray::new(data, shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } else { + let (value, dtype) = numeric_scalar_info(arg, "numpy.real_if_close", vm)?; + Ok(scalar_from_f64(value, dtype)) + } +} + +/// `numpy.imag(a)` for Monty's real-valued ndarray subset. +/// +/// Since complex dtypes are unsupported, every supported numeric input has a +/// zero imaginary component. The result preserves array shape and scalar-vs-array +/// form so common NumPy introspection snippets continue to work. +fn call_imag(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.imag", vm.heap)?; + defer_drop!(arg, vm); + + if let Ok((data, shape, dtype)) = extract_ndarray_info(arg, "numpy.imag", vm) { + let arr = NdArray::new(vec![0.0; data.len()], shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } else { + let (_, dtype) = numeric_scalar_info(arg, "numpy.imag", vm)?; + Ok(scalar_from_f64(0.0, dtype)) + } +} + +/// Element-wise `numpy.isreal()` / `numpy.iscomplex()` over real-only inputs. +/// +/// The safe ndarray model has no complex dtype, so every numeric element is real +/// and no numeric element is complex. Non-numeric object arrays remain outside +/// this module's supported surface. +fn call_realness_elementwise( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + is_real: bool, + name: &str, +) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + + if let Ok((data, shape, _)) = extract_ndarray_info(arg, name, vm) { + let fill = bool_to_f64(is_real); + let arr = NdArray::new(vec![fill; data.len()], shape, NdArrayDtype::Bool); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } else { + numeric_scalar_info(arg, name, vm)?; + Ok(Value::Bool(is_real)) + } +} + +/// Object-level `numpy.isrealobj()` / `numpy.iscomplexobj()`. +/// +/// Monty cannot construct complex arrays or scalars, so these predicates are +/// constant for the current runtime surface. The argument is still consumed and +/// dropped normally to preserve reference-count behavior. +fn call_realness_object( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + is_real: bool, + name: &str, +) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + arg.drop_with_heap(vm); + Ok(Value::Bool(is_real)) +} + +/// `numpy.isscalar(a)` — report whether a value is scalar in Monty's runtime. +/// +/// Numeric values, strings/bytes, dates, timedeltas, and long integers are +/// scalar-like; containers, arrays, modules, functions, and sentinel values are +/// not. This intentionally avoids invoking user-visible iteration or attribute +/// lookup, so it remains a pure shape/type predicate. +fn call_isscalar(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.isscalar", vm.heap)?; + defer_drop!(arg, vm); + Ok(Value::Bool(is_numpy_scalar(arg, vm))) +} + +/// `numpy.iterable(a)` — report whether Monty's iterator protocol accepts a value. +fn call_iterable(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.iterable", vm.heap)?; + defer_drop!(arg, vm); + Ok(Value::Bool(is_numpy_iterable(arg, vm))) +} + +/// `numpy.dtype(dtype)` — normalize a supported dtype-like value to Monty's dtype marker. +fn call_dtype(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.dtype", vm.heap)?; + defer_drop!(arg, vm); + Ok(dtype_token_value(dtype_token_from_optional_dtype_value( + arg, + "numpy.dtype", + vm, + )?)) +} + +/// `numpy.astype(a, dtype)` — module-level ndarray dtype conversion. +/// +/// This mirrors NumPy's pure helper for Monty's compact ndarray subset. The +/// dtype argument is normalized through the same dtype parser as constructors, +/// so Python type objects, dtype marker attributes, and supported dtype strings +/// all reach the same casting implementation used by `ndarray.astype()`. +fn call_astype(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (array_val, dtype_val) = args.get_two_args("numpy.astype", vm.heap)?; + defer_drop!(array_val, vm); + defer_drop!(dtype_val, vm); + let arr = ndarray_from_value(array_val, "numpy.astype", vm)?; + let dtype = dtype_meta_from_optional_dtype_value(dtype_val, "numpy.astype", vm)?; + arr.astype(compact_dtype_name(dtype), vm.heap) +} + +/// `numpy.can_cast(from_, to)` — safe cast predicate for Monty's compact dtype set. +fn call_can_cast(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (from_val, to_val) = args.get_two_args("numpy.can_cast", vm.heap)?; + defer_drop!(from_val, vm); + defer_drop!(to_val, vm); + let from = dtype_token_from_dtype_value(from_val, "numpy.can_cast", vm)?; + let to = dtype_token_from_dtype_value(to_val, "numpy.can_cast", vm)?; + Ok(Value::Bool(can_cast_dtype_token(from, to))) +} + +/// `numpy.promote_types(type1, type2)` — promoted dtype marker. +fn call_promote_types(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (first_val, second_val) = args.get_two_args("numpy.promote_types", vm.heap)?; + defer_drop!(first_val, vm); + defer_drop!(second_val, vm); + let first = dtype_token_from_dtype_value(first_val, "numpy.promote_types", vm)?; + let second = dtype_token_from_dtype_value(second_val, "numpy.promote_types", vm)?; + Ok(dtype_token_value(promote_dtype_token( + first, + second, + "numpy.promote_types", + )?)) +} + +/// `numpy.result_type(*arrays_and_dtypes)` — result dtype marker for real numeric inputs. +fn call_result_type(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.result_type", vm.heap)?; + defer_drop_mut!(pos, vm); + if pos.len() == 0 { + return Err(ExcType::type_error_at_least("numpy.result_type", 1, 0)); + } + + let mut result = DtypeToken::Compact(CompactDtype::Bool); + for arg in pos.by_ref() { + defer_drop!(arg, vm); + let token = dtype_token_from_value(arg, "numpy.result_type", vm)?; + result = promote_dtype_token(result, token, "numpy.result_type")?; + } + Ok(dtype_token_value(result)) +} + +/// `numpy.common_type(*arrays)` — common real dtype marker, with float64 as minimum. +fn call_common_type(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.common_type", vm.heap)?; + defer_drop_mut!(pos, vm); + if pos.len() == 0 { + return Err(ExcType::type_error_at_least("numpy.common_type", 1, 0)); + } + + for arg in pos.by_ref() { + defer_drop!(arg, vm); + dtype_meta_from_value(arg, "numpy.common_type", vm)?; + } + Ok(dtype_meta_value(CompactDtype::Float64)) +} + +/// `numpy.min_scalar_type(a)` — smallest compatible marker in Monty's compact dtype set. +fn call_min_scalar_type(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.min_scalar_type", vm.heap)?; + defer_drop!(arg, vm); + let value = if let Value::Int(value) = arg { + Value::InternString(min_integer_scalar_marker(*value).into()) + } else { + dtype_meta_value(dtype_meta_from_value(arg, "numpy.min_scalar_type", vm)?) + }; + Ok(value) +} + +/// Returns the narrow integer marker NumPy would choose for a Python int scalar. +fn min_integer_scalar_marker(value: i64) -> StaticStrings { + if value < 0 { + if value >= i64::from(i8::MIN) { + StaticStrings::NpInt8 + } else if value >= i64::from(i16::MIN) { + StaticStrings::NpInt16 + } else if value >= i64::from(i32::MIN) { + StaticStrings::NpInt32 + } else { + StaticStrings::NpInt64 + } + } else if value <= i64::from(u8::MAX) { + StaticStrings::NpUint8 + } else if value <= i64::from(u16::MAX) { + StaticStrings::NpUint16 + } else if value <= i64::from(u32::MAX) { + StaticStrings::NpUint32 + } else { + StaticStrings::NpUint64 + } +} + +/// `numpy.mintypecode(typechars, typeset='GDFgdf', default='d')` — legacy dtype code helper. +fn call_mintypecode(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.mintypecode", vm.heap)?; + defer_drop_mut!(pos, vm); + let typechars_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.mintypecode", 1, 0))?; + defer_drop!(typechars_val, vm); + let typeset = if let Some(typeset_val) = pos.next() { + defer_drop!(typeset_val, vm); + string_from_value(typeset_val, "numpy.mintypecode", vm)? + } else { + "GDFgdf".to_string() + }; + let default = if let Some(default_val) = pos.next() { + defer_drop!(default_val, vm); + string_from_value(default_val, "numpy.mintypecode", vm)? + } else { + "d".to_string() + }; + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.mintypecode", 3, 4)); + } + + let chars = typechars_from_value(typechars_val, "numpy.mintypecode", vm)?; + let result = mintypecode_result(&chars, &typeset, &default); + allocate_string(result.to_string(), vm.heap) +} + +/// `numpy.typename(char)` — human-readable legacy dtype character name. +fn call_typename(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.typename", vm.heap)?; + defer_drop!(arg, vm); + let text = string_from_value(arg, "numpy.typename", vm)?; + let name = match text.as_str() { + "?" => "bool", + "b" => "signed char", + "B" => "unsigned char", + "h" => "short", + "H" => "unsigned short", + "i" => "integer", + "I" => "unsigned integer", + "l" => "long integer", + "L" => "unsigned long integer", + "q" => "long integer", + "Q" => "unsigned long integer", + "e" => "half precision", + "f" => "single precision", + "d" => "double precision", + "g" => "long precision", + "F" => "complex single precision", + "D" => "complex double precision", + "G" => "complex long double precision", + "c" => "character", + _ => { + return Err(SimpleException::new_msg(ExcType::KeyError, format!("'{}'", text.escape_debug())).into()); + } + }; + allocate_string(name.to_string(), vm.heap) +} + +/// `numpy.info(object=None, ...)` — accept documentation queries without host introspection. +fn call_info(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> Value { + args.drop_with_heap(vm); + Value::None +} + +/// `numpy.issubdtype(arg1, arg2)` — check Monty's compact dtype hierarchy. +fn call_issubdtype(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arg1, arg2) = args.get_two_args("numpy.issubdtype", vm.heap)?; + defer_drop!(arg1, vm); + defer_drop!(arg2, vm); + let first = dtype_kind_from_value(arg1, "numpy.issubdtype", vm)?; + let second = dtype_kind_from_value(arg2, "numpy.issubdtype", vm)?; + Ok(Value::Bool(is_subdtype_kind(first, second))) +} + +/// `numpy.isdtype(dtype, kind)` — check a dtype against supported Array API kind names. +fn call_isdtype(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (dtype_val, kind_val) = args.get_two_args("numpy.isdtype", vm.heap)?; + defer_drop!(dtype_val, vm); + defer_drop!(kind_val, vm); + let dtype = dtype_token_from_dtype_value(dtype_val, "numpy.isdtype", vm)?; + let kind = isdtype_kind_from_value(kind_val, "numpy.isdtype", vm)?; + Ok(Value::Bool(is_dtype_token_kind(dtype, kind))) +} + +/// `numpy.finfo(dtype)` — floating dtype machine-limit metadata. +fn call_finfo(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.finfo", vm.heap)?; + defer_drop!(arg, vm); + let info = float_info_from_value(arg, vm)?; + allocate_namedtuple_result( + "finfo", + &[ + "bits", + "eps", + "epsneg", + "max", + "min", + "tiny", + "smallest_normal", + "smallest_subnormal", + "resolution", + "precision", + "nmant", + "iexp", + "machep", + "negep", + "dtype", + ], + vec![ + Value::Int(info.bits), + Value::Float(info.eps), + Value::Float(info.epsneg), + Value::Float(info.max), + Value::Float(info.min), + Value::Float(info.tiny), + Value::Float(info.tiny), + Value::Float(info.smallest_subnormal), + Value::Float(info.resolution), + Value::Int(info.precision), + Value::Int(info.nmant), + Value::Int(info.iexp), + Value::Int(info.machep), + Value::Int(info.negep), + Value::InternString(info.dtype_marker.into()), + ], + vm, + ) +} + +/// `numpy.iinfo(dtype)` — integer dtype machine-limit metadata. +fn call_iinfo(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.iinfo", vm.heap)?; + defer_drop!(arg, vm); + let info = integer_info_from_value(arg, vm)?; + allocate_namedtuple_result( + "iinfo", + &["min", "max", "bits", "dtype"], + vec![ + integer_limit_to_value(info.min, vm.heap)?, + integer_limit_to_value(info.max, vm.heap)?, + Value::Int(info.bits), + Value::InternString(info.dtype_marker.into()), + ], + vm, + ) +} + +/// `numpy.geterr()` — return Monty's fixed floating-point error policy. +fn call_geterr(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + args.check_zero_args("numpy.geterr", vm.heap)?; + numpy_error_policy_dict(vm) +} + +/// `numpy.seterr(...)` — accept error-policy options and return the previous policy. +fn call_seterr(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + args.drop_with_heap(vm); + numpy_error_policy_dict(vm) +} + +/// `numpy.geterrcall()` — return the fixed absence of an error callback. +fn call_geterrcall(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + args.check_zero_args("numpy.geterrcall", vm.heap)?; + Ok(Value::None) +} + +/// `numpy.seterrcall(callback)` — accept callback configuration as a no-op. +fn call_seterrcall(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let callback = args.get_one_arg("numpy.seterrcall", vm.heap)?; + callback.drop_with_heap(vm); + Ok(Value::None) +} + +/// `numpy.errstate(...)` — lightweight placeholder for context-manager-style code. +fn call_errstate(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + args.drop_with_heap(vm); + numpy_error_policy_dict(vm) +} + +/// `numpy.get_printoptions()` — return Monty's fixed print-option snapshot. +fn call_get_printoptions(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + args.check_zero_args("numpy.get_printoptions", vm.heap)?; + numpy_print_options_dict(vm) +} + +/// `numpy.set_printoptions(...)` — accept print options as a no-op. +fn call_set_printoptions(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> Value { + args.drop_with_heap(vm); + Value::None +} + +/// `numpy.printoptions(...)` — lightweight placeholder for context-manager-style code. +fn call_printoptions(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + args.drop_with_heap(vm); + numpy_print_options_dict(vm) +} + +/// `numpy.getbufsize()` — return NumPy's legacy default buffer size. +fn call_getbufsize(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + args.check_zero_args("numpy.getbufsize", vm.heap)?; + Ok(Value::Int(8192)) +} + +/// `numpy.setbufsize(size)` — accept a buffer size as a no-op and return the previous size. +fn call_setbufsize(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.setbufsize", vm.heap)?; + arg.drop_with_heap(vm); + Ok(Value::Int(8192)) +} + +/// `numpy.show_runtime()` — no-op placeholder that avoids host runtime introspection. +fn call_show_runtime(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> Value { + args.drop_with_heap(vm); + Value::None +} + +/// `numpy.test()` — no-op placeholder that avoids launching NumPy's external test suite. +fn call_test(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> Value { + args.drop_with_heap(vm); + Value::None +} + +/// Builds the fixed floating-point error-policy dictionary. +fn numpy_error_policy_dict(vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + string_dict_from_pairs( + &[ + ("divide", "warn"), + ("over", "warn"), + ("under", "ignore"), + ("invalid", "warn"), + ], + vm, + ) +} + +/// Builds the fixed print-options dictionary for Monty's ndarray representation. +fn numpy_print_options_dict(vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let pairs = vec![ + (allocate_string("edgeitems".to_string(), vm.heap)?, Value::Int(3)), + (allocate_string("threshold".to_string(), vm.heap)?, Value::Int(1000)), + (allocate_string("linewidth".to_string(), vm.heap)?, Value::Int(75)), + (allocate_string("precision".to_string(), vm.heap)?, Value::Int(8)), + (allocate_string("suppress".to_string(), vm.heap)?, Value::Bool(false)), + ( + allocate_string("nanstr".to_string(), vm.heap)?, + allocate_string("nan".to_string(), vm.heap)?, + ), + ( + allocate_string("infstr".to_string(), vm.heap)?, + allocate_string("inf".to_string(), vm.heap)?, + ), + ( + allocate_string("sign".to_string(), vm.heap)?, + allocate_string("-".to_string(), vm.heap)?, + ), + ( + allocate_string("floatmode".to_string(), vm.heap)?, + allocate_string("maxprec".to_string(), vm.heap)?, + ), + (allocate_string("legacy".to_string(), vm.heap)?, Value::None), + ]; + dict_from_pairs(pairs, vm) +} + +/// Allocates a Python dict from string key/value pairs. +fn string_dict_from_pairs(pairs: &[(&str, &str)], vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let mut values = Vec::with_capacity(pairs.len()); + for (key, value) in pairs { + values.push(( + allocate_string((*key).to_string(), vm.heap)?, + allocate_string((*value).to_string(), vm.heap)?, + )); + } + dict_from_pairs(values, vm) +} + +/// Allocates a Python dict from already-owned key/value pairs. +fn dict_from_pairs(pairs: Vec<(Value, Value)>, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let dict = Dict::from_pairs(pairs, vm)?; + Ok(Value::Ref(vm.heap.allocate(HeapData::Dict(dict))?)) +} + +/// Compact dtype categories that fit Monty's bool/int/float ndarray storage model. +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +enum CompactDtype { + /// Boolean arrays and scalar markers. + Bool, + /// Integer arrays and scalar markers. + Int, + /// Single-precision dtype marker accepted as a float storage alias. + Float32, + /// Double-precision dtype marker used by Monty's float arrays and scalars. + Float64, +} + +/// Dtype marker metadata recognized without necessarily supporting array storage. +#[derive(Clone, Copy, PartialEq, Eq)] +enum DtypeToken { + /// A dtype backed by Monty's compact bool/int/float ndarray storage. + Compact(CompactDtype), + /// Complex single-precision metadata marker. + Complex64, + /// Complex double-precision metadata marker. + Complex128, + /// Unicode string metadata marker. + Str, + /// Byte-string metadata marker. + Bytes, + /// Void/flexible record metadata marker. + Void, + /// Python object metadata marker. + Object, + /// Datetime metadata marker. + DateTime64, + /// Timedelta metadata marker. + Timedelta64, +} + +/// Dtype category markers exposed for hierarchy-style predicates. +#[derive(Clone, Copy, PartialEq, Eq)] +enum DtypeCategory { + /// Generic category containing every dtype Monty can currently recognize. + Generic, + /// Boolean dtype category. + Bool, + /// Signed or unsigned integer dtype category. + Integer, + /// Signed integer dtype category. + SignedInteger, + /// Unsigned integer dtype category. + UnsignedInteger, + /// Floating dtype category. + Floating, + /// Inexact numeric category; complex is unsupported, so this means floating. + Inexact, + /// Numeric category for integer and floating dtypes. + Number, + /// Complex category marker, currently empty for Monty's concrete dtypes. + ComplexFloating, + /// Flexible-width category marker, currently empty for Monty's concrete dtypes. + Flexible, + /// Character category marker, currently empty for Monty's concrete dtypes. + Character, +} + +/// Concrete dtype family used by `issubdtype`. +#[derive(Clone, Copy, PartialEq, Eq)] +enum DtypeConcrete { + /// Boolean dtype. + Bool, + /// Signed integer dtype. + SignedInteger, + /// Unsigned integer dtype. + UnsignedInteger, + /// Real floating dtype. + Floating, + /// Complex floating dtype. + ComplexFloating, + /// String/bytes dtype. + Character, + /// Flexible non-character dtype such as `void`. + FlexibleVoid, + /// Object dtype. + Object, + /// Datetime dtype. + DateTime64, + /// Timedelta dtype. + Timedelta64, +} + +/// Parsed dtype hierarchy operand for `issubdtype`. +#[derive(Clone, Copy)] +enum DtypeKind { + /// Concrete compact dtype such as `np.int64`. + Concrete(DtypeConcrete), + /// Category marker such as `np.integer`. + Category(DtypeCategory), +} + +/// Parsed second operand for `isdtype`. +#[derive(Clone, Copy)] +enum IsdtypeKind { + /// Concrete dtype marker requiring exact equality. + Concrete(CompactDtype), + /// Named kind string such as `"integral"` or `"real floating"`. + Category(DtypeCategory), + /// Recognized unsupported kind that is always false for Monty's dtype set. + Never, +} + +/// Machine-limit metadata returned by Monty's `numpy.finfo()` subset. +struct FloatInfo { + /// Interned dtype name marker, such as `float32` or `float64`. + dtype_marker: StaticStrings, + /// Number of storage bits for the dtype. + bits: i64, + /// Difference between 1.0 and the next representable value above it. + eps: f64, + /// Difference between 1.0 and the next representable value below it. + epsneg: f64, + /// Largest finite representable value. + max: f64, + /// Most negative finite representable value. + min: f64, + /// Smallest positive normal representable value. + tiny: f64, + /// Smallest positive subnormal representable value. + smallest_subnormal: f64, + /// Decimal resolution reported by NumPy. + resolution: f64, + /// Approximate decimal precision. + precision: i64, + /// Number of mantissa bits reported by NumPy. + nmant: i64, + /// Number of exponent bits reported by NumPy. + iexp: i64, + /// Exponent of `eps`. + machep: i64, + /// Exponent of `epsneg`. + negep: i64, +} + +/// Machine-limit metadata returned by Monty's `numpy.iinfo()` subset. +struct IntegerInfo { + /// Interned dtype name marker, such as `int16` or `uint32`. + dtype_marker: StaticStrings, + /// Number of storage bits for the dtype. + bits: i64, + /// Minimum representable integer. + min: IntegerLimit, + /// Maximum representable integer. + max: IntegerLimit, +} + +/// Signed or unsigned integer boundary that can exceed Monty's fast `i64` path. +#[derive(Clone, Copy)] +enum IntegerLimit { + /// Signed integer bound, including all supported signed minima. + Signed(i128), + /// Unsigned integer bound, used for `uint64::max`. + Unsigned(u128), +} + +/// Parses a dtype argument such as `np.float64` or `'int64'`. +fn dtype_meta_from_dtype_value( + value: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + let text = string_from_value(value, name, vm)?; + dtype_meta_from_str(&text, name) +} + +/// Parses dtype metadata, including marker-only names that have no ndarray storage. +fn dtype_token_from_dtype_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + let text = string_from_value(value, name, vm)?; + dtype_token_from_str(&text, name) +} + +/// Parses dtype metadata from APIs that accept Python type constructors and `None`. +fn dtype_token_from_optional_dtype_value( + value: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + match value { + Value::None => Ok(DtypeToken::Compact(CompactDtype::Float64)), + Value::Builtin(Builtins::Type(Type::Bool)) => Ok(DtypeToken::Compact(CompactDtype::Bool)), + Value::Builtin(Builtins::Type(Type::Int)) => Ok(DtypeToken::Compact(CompactDtype::Int)), + Value::Builtin(Builtins::Type(Type::Float)) => Ok(DtypeToken::Compact(CompactDtype::Float64)), + _ => dtype_token_from_dtype_value(value, name, vm), + } +} + +/// Parses a concrete dtype or Monty category marker for `issubdtype`. +fn dtype_kind_from_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + let text = string_from_value(value, name, vm)?; + if let Some(category) = dtype_category_from_marker_text(&text) { + Ok(DtypeKind::Category(category)) + } else { + Ok(DtypeKind::Concrete(dtype_concrete_from_str(&text, name)?)) + } +} + +/// Parses the `kind` operand accepted by Monty's `isdtype` subset. +fn isdtype_kind_from_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + let text = string_from_value(value, name, vm)?; + if dtype_category_from_marker_text(&text).is_some() { + Ok(IsdtypeKind::Never) + } else if let Some(category) = isdtype_category_from_name(&text) { + Ok(IsdtypeKind::Category(category)) + } else if matches!(text.as_str(), "unsigned integer") { + Ok(IsdtypeKind::Never) + } else if let Ok(dtype) = dtype_meta_from_str(&text, name) { + Ok(IsdtypeKind::Concrete(dtype)) + } else { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("kind argument is a string, but '{text}' is not a known kind name."), + ) + .into()) + } +} + +/// Infers the compact dtype for a dtype marker, scalar, ndarray, or list. +fn dtype_meta_from_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + if let Ok(text) = string_from_value(value, name, vm) { + dtype_meta_from_str(&text, name) + } else if let Ok((_, dtype)) = numeric_scalar_info(value, name, vm) { + Ok(dtype_meta_from_ndarray_dtype(dtype)) + } else { + let arr = ndarray_from_value(value, name, vm)?; + Ok(dtype_meta_from_ndarray_dtype(arr.dtype())) + } +} + +/// Infers dtype metadata for dtype markers, scalar values, ndarrays, or lists. +fn dtype_token_from_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + if let Ok(text) = string_from_value(value, name, vm) { + dtype_token_from_str(&text, name) + } else if let Ok((_, dtype)) = numeric_scalar_info(value, name, vm) { + Ok(DtypeToken::Compact(dtype_meta_from_ndarray_dtype(dtype))) + } else { + let arr = ndarray_from_value(value, name, vm)?; + Ok(DtypeToken::Compact(dtype_meta_from_ndarray_dtype(arr.dtype()))) + } +} + +/// Resolves the dtype-like input accepted by Monty's `numpy.finfo()` subset. +fn float_info_from_value(value: &Value, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + match value { + Value::Float(_) | Value::Builtin(Builtins::Type(Type::Float)) => { + Ok(float_info_from_kind(FloatInfoKind::Float64)) + } + Value::Int(_) | Value::Bool(_) | Value::Builtin(Builtins::Type(Type::Int | Type::Bool)) => { + Err(finfo_not_inexact_error("int64")) + } + _ => { + let text = string_from_value(value, "numpy.finfo", vm)?; + float_info_from_str(&text).ok_or_else(|| finfo_not_inexact_error(&text)) + } + } +} + +/// Resolves the dtype-like input accepted by Monty's `numpy.iinfo()` subset. +fn integer_info_from_value(value: &Value, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + match value { + Value::Int(_) | Value::Builtin(Builtins::Type(Type::Int)) => { + Ok(integer_info_signed(64, StaticStrings::NpInt64)) + } + Value::Bool(_) | Value::Builtin(Builtins::Type(Type::Bool)) => Err(invalid_integer_dtype_error("b")), + Value::Float(_) | Value::Builtin(Builtins::Type(Type::Float)) => Err(invalid_integer_dtype_error("f")), + _ => { + let text = string_from_value(value, "numpy.iinfo", vm)?; + integer_info_from_str(&text).ok_or_else(|| invalid_integer_dtype_from_text(&text)) + } + } +} + +/// Supported floating metadata widths for `finfo()`. +#[derive(Clone, Copy)] +enum FloatInfoKind { + /// IEEE-754 binary16 metadata. + Float16, + /// IEEE-754 binary32 metadata. + Float32, + /// IEEE-754 binary64 metadata, also used for Monty's `longdouble` alias. + Float64, +} + +/// Parses supported floating dtype names for `finfo()`. +fn float_info_from_str(text: &str) -> Option { + let kind = match text { + "float16" | "half" | "e" => FloatInfoKind::Float16, + "float32" | "single" | "f" => FloatInfoKind::Float32, + "float64" | "double" | "longdouble" | "float" | "d" | "g" => FloatInfoKind::Float64, + _ => return None, + }; + Some(float_info_from_kind(kind)) +} + +/// Returns static `finfo()` metadata for the selected supported float width. +fn float_info_from_kind(kind: FloatInfoKind) -> FloatInfo { + match kind { + FloatInfoKind::Float16 => FloatInfo { + dtype_marker: StaticStrings::NpFloat16, + bits: 16, + eps: 0.000_976_562_5, + epsneg: 0.000_488_281_25, + max: 65_504.0, + min: -65_504.0, + tiny: 0.000_061_035_156_25, + smallest_subnormal: 0.000_000_059_604_644_775_390_63, + resolution: 0.001, + precision: 3, + nmant: 10, + iexp: 5, + machep: -10, + negep: -11, + }, + FloatInfoKind::Float32 => FloatInfo { + dtype_marker: StaticStrings::NpFloat32, + bits: 32, + eps: 0.000_000_119_209_289_550_781_25, + epsneg: 0.000_000_059_604_644_775_390_63, + max: 3.402_823_466_385_288_6e38, + min: -3.402_823_466_385_288_6e38, + tiny: 1.175_494_350_822_287_5e-38, + smallest_subnormal: 1.401_298_464_324_817e-45, + resolution: 0.000_001, + precision: 6, + nmant: 23, + iexp: 8, + machep: -23, + negep: -24, + }, + FloatInfoKind::Float64 => FloatInfo { + dtype_marker: StaticStrings::NpFloat64, + bits: 64, + eps: f64::EPSILON, + epsneg: f64::EPSILON / 2.0, + max: f64::MAX, + min: -f64::MAX, + tiny: f64::MIN_POSITIVE, + smallest_subnormal: 5e-324, + resolution: 1e-15, + precision: 15, + nmant: 52, + iexp: 11, + machep: -52, + negep: -53, + }, + } +} + +/// Parses supported integer dtype names for `iinfo()`. +fn integer_info_from_str(text: &str) -> Option { + match text { + "int8" | "byte" | "b" => Some(integer_info_signed(8, StaticStrings::NpInt8)), + "int16" | "short" | "h" => Some(integer_info_signed(16, StaticStrings::NpInt16)), + "int32" | "intc" | "i" => Some(integer_info_signed(32, StaticStrings::NpInt32)), + "int64" | "int_" | "intp" | "long" | "longlong" | "l" | "q" => { + Some(integer_info_signed(64, StaticStrings::NpInt64)) + } + "uint8" | "ubyte" | "B" => Some(integer_info_unsigned(8, StaticStrings::NpUint8)), + "uint16" | "ushort" | "H" => Some(integer_info_unsigned(16, StaticStrings::NpUint16)), + "uint32" | "uintc" | "I" => Some(integer_info_unsigned(32, StaticStrings::NpUint32)), + "uint64" | "uint" | "uintp" | "ulong" | "ulonglong" | "L" | "Q" => { + Some(integer_info_unsigned(64, StaticStrings::NpUint64)) + } + _ => None, + } +} + +/// Builds signed integer metadata for a fixed-width two's-complement dtype. +fn integer_info_signed(bits: u32, dtype_marker: StaticStrings) -> IntegerInfo { + let max = (1_i128 << (bits - 1)) - 1; + let min = -(1_i128 << (bits - 1)); + IntegerInfo { + dtype_marker, + bits: i64::from(bits), + min: IntegerLimit::Signed(min), + max: IntegerLimit::Signed(max), + } +} + +/// Builds unsigned integer metadata for a fixed-width dtype. +fn integer_info_unsigned(bits: u32, dtype_marker: StaticStrings) -> IntegerInfo { + let max = (1_u128 << bits) - 1; + IntegerInfo { + dtype_marker, + bits: i64::from(bits), + min: IntegerLimit::Unsigned(0), + max: IntegerLimit::Unsigned(max), + } +} + +/// Converts an integer metadata boundary into Monty's fast or arbitrary-precision int value. +fn integer_limit_to_value(limit: IntegerLimit, heap: &Heap) -> RunResult { + let value = match limit { + IntegerLimit::Signed(value) => BigInt::from(value), + IntegerLimit::Unsigned(value) => BigInt::from(value), + }; + Ok(LongInt::new(value).into_value(heap)?) +} + +/// Creates NumPy-style `finfo()` errors for non-floating dtype inputs. +fn finfo_not_inexact_error(text: &str) -> RunError { + SimpleException::new_msg( + ExcType::ValueError, + format!("data type not inexact"), + ) + .into() +} + +/// Creates NumPy-style `iinfo()` errors for non-integer dtype inputs. +fn invalid_integer_dtype_error(kind: &str) -> RunError { + SimpleException::new_msg(ExcType::ValueError, format!("Invalid integer data type '{kind}'.")).into() +} + +/// Maps unsupported dtype text onto a compact `iinfo()` error category. +fn invalid_integer_dtype_from_text(text: &str) -> RunError { + let kind = match text { + "bool" | "bool_" | "?" => "b", + "float16" | "float32" | "float64" | "half" | "single" | "double" | "longdouble" | "float" | "e" | "f" | "d" + | "g" => "f", + _ => text, + }; + invalid_integer_dtype_error(kind) +} + +/// Maps dtype text onto Monty's compact dtype categories. +fn dtype_meta_from_str(text: &str, name: &str) -> RunResult { + match text { + "bool" | "bool_" | "?" => Ok(CompactDtype::Bool), + "int8" | "int16" | "int32" | "int64" | "int_" | "intc" | "intp" | "long" | "longlong" | "byte" | "short" + | "uint8" | "uint16" | "uint32" | "uint64" | "uint" | "uintc" | "uintp" | "ubyte" | "ushort" | "ulong" + | "ulonglong" | "i" | "l" | "q" | "b" | "h" | "B" | "H" | "I" | "L" | "Q" => Ok(CompactDtype::Int), + "float16" | "float32" | "half" | "single" | "f" | "e" => Ok(CompactDtype::Float32), + "float64" | "double" | "longdouble" | "float" | "d" | "g" => Ok(CompactDtype::Float64), + _ => Err(ExcType::type_error(format!("{name}() unsupported dtype: {text}"))), + } +} + +/// Maps dtype text onto metadata tokens, including unsupported-storage markers. +fn dtype_token_from_str(text: &str, name: &str) -> RunResult { + if let Ok(dtype) = dtype_meta_from_str(text, name) { + Ok(DtypeToken::Compact(dtype)) + } else { + match text { + "complex64" | "csingle" | "F" => Ok(DtypeToken::Complex64), + "complex128" | "cdouble" | "complex" | "D" => Ok(DtypeToken::Complex128), + "clongdouble" | "G" => Ok(DtypeToken::Complex128), + "str" | "str_" | "U" => Ok(DtypeToken::Str), + "bytes" | "bytes_" | "S" => Ok(DtypeToken::Bytes), + "void" | "V" => Ok(DtypeToken::Void), + "object" | "object_" | "O" => Ok(DtypeToken::Object), + "datetime64" | "M" => Ok(DtypeToken::DateTime64), + "timedelta64" | "m" => Ok(DtypeToken::Timedelta64), + _ => Err(ExcType::type_error(format!("{name}() unsupported dtype: {text}"))), + } + } +} + +/// Maps dtype text onto the concrete family used by hierarchy predicates. +fn dtype_concrete_from_str(text: &str, name: &str) -> RunResult { + match text { + "bool" | "bool_" | "?" => Ok(DtypeConcrete::Bool), + "int8" | "int16" | "int32" | "int64" | "int_" | "intc" | "intp" | "long" | "longlong" | "byte" | "short" + | "i" | "l" | "q" | "b" | "h" => Ok(DtypeConcrete::SignedInteger), + "uint8" | "uint16" | "uint32" | "uint64" | "uint" | "uintc" | "uintp" | "ubyte" | "ushort" | "ulong" + | "ulonglong" | "B" | "H" | "I" | "L" | "Q" => Ok(DtypeConcrete::UnsignedInteger), + "float16" | "float32" | "float64" | "half" | "single" | "double" | "longdouble" | "float" | "f" | "e" | "d" + | "g" => Ok(DtypeConcrete::Floating), + "complex64" | "complex128" | "cdouble" | "csingle" | "clongdouble" | "complex" | "F" | "D" | "G" => { + Ok(DtypeConcrete::ComplexFloating) + } + "str" | "str_" | "bytes" | "bytes_" | "U" | "S" => Ok(DtypeConcrete::Character), + "void" | "V" => Ok(DtypeConcrete::FlexibleVoid), + "object" | "object_" | "O" => Ok(DtypeConcrete::Object), + "datetime64" | "M" => Ok(DtypeConcrete::DateTime64), + "timedelta64" | "m" => Ok(DtypeConcrete::Timedelta64), + _ => Err(ExcType::type_error(format!("{name}() unsupported dtype: {text}"))), + } +} + +/// Converts an ndarray dtype into a compact dtype category. +fn dtype_meta_from_ndarray_dtype(dtype: NdArrayDtype) -> CompactDtype { + match dtype { + NdArrayDtype::Bool => CompactDtype::Bool, + NdArrayDtype::Int64 => CompactDtype::Int, + NdArrayDtype::Float64 => CompactDtype::Float64, + } +} + +/// Returns an interned dtype marker for a compact dtype category. +fn dtype_meta_value(dtype: CompactDtype) -> Value { + let marker = match dtype { + CompactDtype::Bool => StaticStrings::NpBool_, + CompactDtype::Int => StaticStrings::NpInt64, + CompactDtype::Float32 => StaticStrings::NpFloat32, + CompactDtype::Float64 => StaticStrings::NpFloat64, + }; + Value::InternString(marker.into()) +} + +/// Returns an interned dtype marker for metadata-only dtype parsing. +fn dtype_token_value(dtype: DtypeToken) -> Value { + let marker = match dtype { + DtypeToken::Compact(dtype) => { + return dtype_meta_value(dtype); + } + DtypeToken::Complex64 => StaticStrings::NpComplex64, + DtypeToken::Complex128 => StaticStrings::NpComplex128, + DtypeToken::Str => StaticStrings::NpStr_, + DtypeToken::Bytes => StaticStrings::NpBytes_, + DtypeToken::Void => StaticStrings::NpVoid, + DtypeToken::Object => StaticStrings::NpObject_, + DtypeToken::DateTime64 => StaticStrings::NpDatetime64, + DtypeToken::Timedelta64 => StaticStrings::NpTimedelta64, + }; + Value::InternString(marker.into()) +} + +/// Returns the canonical dtype text accepted by `NdArray::astype()`. +fn compact_dtype_name(dtype: CompactDtype) -> &'static str { + match dtype { + CompactDtype::Bool => "bool", + CompactDtype::Int => "int64", + CompactDtype::Float32 => "float32", + CompactDtype::Float64 => "float64", + } +} + +/// Promotes two compact dtype categories using NumPy's real numeric ordering. +fn promote_dtype_meta(first: CompactDtype, second: CompactDtype) -> CompactDtype { + first.max(second) +} + +/// Returns whether a cast is safe in the compact bool -> int -> float ordering. +fn can_cast_dtype_meta(from: CompactDtype, to: CompactDtype) -> bool { + from <= to +} + +/// Returns whether NumPy metadata-only casting is safe for supported markers. +fn can_cast_dtype_token(from: DtypeToken, to: DtypeToken) -> bool { + match (from, to) { + (DtypeToken::Compact(from), DtypeToken::Compact(to)) => can_cast_dtype_meta(from, to), + (_, DtypeToken::Object) => true, + (DtypeToken::Object, _) => false, + (DtypeToken::Compact(_), DtypeToken::Complex64 | DtypeToken::Complex128) => true, + (DtypeToken::Compact(_), DtypeToken::Str | DtypeToken::Bytes) => true, + (DtypeToken::Complex64, DtypeToken::Complex64 | DtypeToken::Complex128 | DtypeToken::Str) => true, + (DtypeToken::Complex128, DtypeToken::Complex128 | DtypeToken::Str) => true, + (DtypeToken::Str, DtypeToken::Str) | (DtypeToken::Bytes, DtypeToken::Bytes | DtypeToken::Str) => true, + (DtypeToken::Void, DtypeToken::Void) + | (DtypeToken::DateTime64, DtypeToken::DateTime64) + | (DtypeToken::Timedelta64, DtypeToken::Timedelta64) => true, + _ => false, + } +} + +/// Promotes dtype metadata tokens without enabling unsupported ndarray storage. +fn promote_dtype_token(first: DtypeToken, second: DtypeToken, name: &str) -> RunResult { + let promoted = match (first, second) { + (DtypeToken::Compact(first), DtypeToken::Compact(second)) => { + DtypeToken::Compact(promote_dtype_meta(first, second)) + } + (DtypeToken::Object, _) | (_, DtypeToken::Object) => DtypeToken::Object, + (DtypeToken::Complex128, _) | (_, DtypeToken::Complex128) => DtypeToken::Complex128, + ( + DtypeToken::Complex64, + DtypeToken::Compact(CompactDtype::Float32 | CompactDtype::Bool) | DtypeToken::Complex64, + ) + | (DtypeToken::Compact(CompactDtype::Float32 | CompactDtype::Bool), DtypeToken::Complex64) => { + DtypeToken::Complex64 + } + (DtypeToken::Complex64, DtypeToken::Compact(_)) | (DtypeToken::Compact(_), DtypeToken::Complex64) => { + DtypeToken::Complex128 + } + (DtypeToken::Str, _) | (_, DtypeToken::Str) => DtypeToken::Str, + (DtypeToken::Bytes, _) | (_, DtypeToken::Bytes) => DtypeToken::Bytes, + ( + DtypeToken::Timedelta64, + DtypeToken::Compact(CompactDtype::Bool | CompactDtype::Int) | DtypeToken::Timedelta64, + ) + | (DtypeToken::Compact(CompactDtype::Bool | CompactDtype::Int), DtypeToken::Timedelta64) => { + DtypeToken::Timedelta64 + } + (left, right) if left == right => left, + _ => { + return Err(ExcType::type_error(format!( + "{name}() unsupported dtype promotion for metadata-only dtypes" + ))); + } + }; + Ok(promoted) +} + +/// Returns true when the first dtype/category is included in the second. +fn is_subdtype_kind(first: DtypeKind, second: DtypeKind) -> bool { + match (first, second) { + (DtypeKind::Concrete(lhs), DtypeKind::Concrete(rhs)) => lhs == rhs, + (DtypeKind::Concrete(dtype), DtypeKind::Category(category)) => dtype_category_contains_dtype(category, dtype), + (DtypeKind::Category(lhs), DtypeKind::Category(rhs)) => dtype_category_contains_category(rhs, lhs), + (DtypeKind::Category(_), DtypeKind::Concrete(_)) => false, + } +} + +/// Returns true when a concrete dtype matches an `isdtype` kind operand. +fn is_dtype_kind(dtype: CompactDtype, kind: IsdtypeKind) -> bool { + match kind { + IsdtypeKind::Concrete(kind_dtype) => dtype == kind_dtype, + IsdtypeKind::Category(category) => dtype_category_contains_compact(category, dtype), + IsdtypeKind::Never => false, + } +} + +/// Returns true when a dtype metadata marker matches an `isdtype` kind operand. +fn is_dtype_token_kind(dtype: DtypeToken, kind: IsdtypeKind) -> bool { + match dtype { + DtypeToken::Compact(dtype) => is_dtype_kind(dtype, kind), + DtypeToken::Complex64 | DtypeToken::Complex128 => { + matches!( + kind, + IsdtypeKind::Category(DtypeCategory::ComplexFloating | DtypeCategory::Number) + ) + } + DtypeToken::Str + | DtypeToken::Bytes + | DtypeToken::Void + | DtypeToken::Object + | DtypeToken::DateTime64 + | DtypeToken::Timedelta64 => false, + } +} + +/// Maps Monty's hidden category marker strings back to dtype categories. +fn dtype_category_from_marker_text(text: &str) -> Option { + match text { + "__monty_numpy_generic_category" => Some(DtypeCategory::Generic), + "__monty_numpy_integer_category" => Some(DtypeCategory::Integer), + "__monty_numpy_signedinteger_category" => Some(DtypeCategory::SignedInteger), + "__monty_numpy_unsignedinteger_category" => Some(DtypeCategory::UnsignedInteger), + "__monty_numpy_floating_category" => Some(DtypeCategory::Floating), + "__monty_numpy_inexact_category" => Some(DtypeCategory::Inexact), + "__monty_numpy_number_category" => Some(DtypeCategory::Number), + "__monty_numpy_complexfloating_category" => Some(DtypeCategory::ComplexFloating), + "__monty_numpy_flexible_category" => Some(DtypeCategory::Flexible), + "__monty_numpy_character_category" => Some(DtypeCategory::Character), + _ => None, + } +} + +/// Maps supported Array API dtype kind names to Monty's compact categories. +fn isdtype_category_from_name(text: &str) -> Option { + match text { + "bool" => Some(DtypeCategory::Bool), + "integral" | "signed integer" => Some(DtypeCategory::Integer), + "real floating" => Some(DtypeCategory::Floating), + "complex floating" => Some(DtypeCategory::ComplexFloating), + "numeric" => Some(DtypeCategory::Number), + _ => None, + } +} + +/// Returns whether a dtype category includes a concrete compact dtype. +fn dtype_category_contains_dtype(category: DtypeCategory, dtype: DtypeConcrete) -> bool { + match category { + DtypeCategory::Generic => true, + DtypeCategory::Bool => matches!(dtype, DtypeConcrete::Bool), + DtypeCategory::Integer => matches!(dtype, DtypeConcrete::SignedInteger | DtypeConcrete::UnsignedInteger), + DtypeCategory::SignedInteger => matches!(dtype, DtypeConcrete::SignedInteger), + DtypeCategory::UnsignedInteger => matches!(dtype, DtypeConcrete::UnsignedInteger), + DtypeCategory::Floating => matches!(dtype, DtypeConcrete::Floating), + DtypeCategory::Inexact => matches!(dtype, DtypeConcrete::Floating | DtypeConcrete::ComplexFloating), + DtypeCategory::ComplexFloating => matches!(dtype, DtypeConcrete::ComplexFloating), + DtypeCategory::Flexible => matches!(dtype, DtypeConcrete::Character | DtypeConcrete::FlexibleVoid), + DtypeCategory::Character => matches!(dtype, DtypeConcrete::Character), + DtypeCategory::Number => { + matches!( + dtype, + DtypeConcrete::SignedInteger + | DtypeConcrete::UnsignedInteger + | DtypeConcrete::Floating + | DtypeConcrete::ComplexFloating + | DtypeConcrete::Timedelta64 + ) + } + } +} + +/// Returns whether a dtype category includes one of Monty's compact storage dtypes. +fn dtype_category_contains_compact(category: DtypeCategory, dtype: CompactDtype) -> bool { + let dtype = match dtype { + CompactDtype::Bool => DtypeConcrete::Bool, + CompactDtype::Int => DtypeConcrete::SignedInteger, + CompactDtype::Float32 | CompactDtype::Float64 => DtypeConcrete::Floating, + }; + dtype_category_contains_dtype(category, dtype) +} + +/// Returns whether a dtype category includes another category. +fn dtype_category_contains_category(container: DtypeCategory, member: DtypeCategory) -> bool { + match container { + DtypeCategory::Generic => true, + DtypeCategory::Bool => member == DtypeCategory::Bool, + DtypeCategory::Integer => matches!( + member, + DtypeCategory::Integer | DtypeCategory::SignedInteger | DtypeCategory::UnsignedInteger + ), + DtypeCategory::SignedInteger => member == DtypeCategory::SignedInteger, + DtypeCategory::UnsignedInteger => member == DtypeCategory::UnsignedInteger, + DtypeCategory::Floating => member == DtypeCategory::Floating, + DtypeCategory::Inexact => matches!( + member, + DtypeCategory::Floating | DtypeCategory::ComplexFloating | DtypeCategory::Inexact + ), + DtypeCategory::Number => matches!( + member, + DtypeCategory::Integer + | DtypeCategory::SignedInteger + | DtypeCategory::UnsignedInteger + | DtypeCategory::Floating + | DtypeCategory::Inexact + | DtypeCategory::ComplexFloating + | DtypeCategory::Number + ), + DtypeCategory::ComplexFloating => member == DtypeCategory::ComplexFloating, + DtypeCategory::Flexible => matches!(member, DtypeCategory::Flexible | DtypeCategory::Character), + DtypeCategory::Character => member == DtypeCategory::Character, + } +} + +/// Extracts an owned Python string from interned or heap string values. +fn string_from_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + match value { + Value::InternString(id) => Ok(vm.interns.get_str(*id).to_string()), + Value::Ref(id) => match vm.heap.get(*id) { + HeapData::Str(text) => Ok(text.as_str().to_string()), + _ => Err(ExcType::type_error(format!("{name}() expected a string"))), + }, + _ => Err(ExcType::type_error(format!("{name}() expected a string"))), + } +} + +/// Extracts legacy dtype character codes from a string or sequence of strings. +fn typechars_from_value(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult> { + if let Ok(text) = string_from_value(value, name, vm) { + Ok(text.chars().collect()) + } else { + let items = match value { + Value::Ref(id) => match vm.heap.get(*id) { + HeapData::List(list) => list.as_slice(), + HeapData::Tuple(tuple) => tuple.as_slice(), + _ => return Err(ExcType::type_error(format!("{name}() expected a string or sequence"))), + }, + _ => return Err(ExcType::type_error(format!("{name}() expected a string or sequence"))), + }; + let mut chars = Vec::new(); + for item in items { + let text = string_from_value(item, name, vm)?; + chars.extend(text.chars()); + } + Ok(chars) + } +} + +/// Chooses the minimal type code present in `typeset`, falling back to `default`. +fn mintypecode_result(chars: &[char], typeset: &str, default: &str) -> char { + let priority = ['G', 'D', 'F', 'g', 'd', 'f']; + priority + .iter() + .copied() + .find(|code| chars.contains(code) && typeset.contains(*code)) + .or_else(|| default.chars().next()) + .unwrap_or('d') +} + +/// Returns whether a value should be treated as scalar by `numpy.isscalar()`. +fn is_numpy_scalar(value: &Value, vm: &VM<'_, impl ResourceTracker>) -> bool { + match value { + Value::Int(_) + | Value::Float(_) + | Value::Bool(_) + | Value::InternString(_) + | Value::InternBytes(_) + | Value::InternLongInt(_) => true, + Value::Ref(heap_id) => matches!( + vm.heap.get(*heap_id), + HeapData::LongInt(_) + | HeapData::Str(_) + | HeapData::Bytes(_) + | HeapData::Date(_) + | HeapData::DateTime(_) + | HeapData::TimeDelta(_) + | HeapData::TimeZone(_) + ), + _ => false, + } +} + +/// Returns whether a value can be iterated by Monty's iterator protocol. +fn is_numpy_iterable(value: &Value, vm: &VM<'_, impl ResourceTracker>) -> bool { + match value { + Value::InternString(_) | Value::InternBytes(_) => true, + Value::Ref(heap_id) => matches!( + vm.heap.get(*heap_id), + HeapData::List(_) + | HeapData::Tuple(_) + | HeapData::NamedTuple(_) + | HeapData::Dict(_) + | HeapData::DictKeysView(_) + | HeapData::DictItemsView(_) + | HeapData::DictValuesView(_) + | HeapData::Set(_) + | HeapData::FrozenSet(_) + | HeapData::Range(_) + | HeapData::Iter(_) + | HeapData::Str(_) + | HeapData::Bytes(_) + | HeapData::NdArray(_) + ), + _ => false, + } +} + +/// Triangle side selected by `tril_indices*` and `triu_indices*`. +#[derive(Clone, Copy)] +enum TriangleKind { + /// Include coordinates on and below the selected diagonal. + Lower, + /// Include coordinates on and above the selected diagonal. + Upper, +} + +/// Integer index input for ravel/unravel helpers. +/// +/// NumPy returns scalar coordinates for scalar index inputs and arrays for +/// vector inputs. This enum carries the copied integer data plus the shape +/// needed to rebuild that same result form. +enum IndexInput { + /// A single scalar index. + Scalar(i64), + /// A vector/array of indices and the shape to preserve for the output. + Array { data: Vec, shape: Vec }, +} + +/// Shared implementation for `numpy.atleast_1d`, `numpy.atleast_2d`, and `numpy.atleast_3d`. +/// +/// Each input is converted into Monty's numeric ndarray representation and then +/// reshaped by adding length-1 axes according to NumPy's common cases. Multiple +/// inputs return a tuple of arrays, matching NumPy's variadic API. +fn call_atleast_nd( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + min_ndim: usize, + name: &str, +) -> RunResult { + let pos = args.into_pos_only(name, vm.heap)?; + defer_drop_mut!(pos, vm); + + let mut outputs: SmallVec<[Value; 3]> = SmallVec::new(); + for arg in pos.by_ref() { + defer_drop!(arg, vm); + outputs.push(atleast_nd_value(arg, min_ndim, name, vm)?); + } + + if outputs.len() == 1 { + Ok(outputs.pop().expect("one output exists")) + } else { + allocate_tuple(outputs, vm.heap).map_err(Into::into) + } +} + +/// Converts one value for the `atleast_*d` family. +fn atleast_nd_value( + value: &Value, + min_ndim: usize, + name: &str, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let (data, shape, dtype) = if let Ok((data, shape, dtype)) = extract_ndarray_info(value, name, vm) { + (data, shape, dtype) + } else { + let (scalar, dtype) = numeric_scalar_info(value, name, vm)?; + (vec![scalar], Vec::new(), dtype) + }; + let shape = atleast_shape(shape, min_ndim); + let arr = NdArray::new(data, shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// Computes NumPy's shape expansion for the supported `atleast_*d` cases. +fn atleast_shape(shape: Vec, min_ndim: usize) -> Vec { + match (min_ndim, shape.as_slice()) { + (1, []) => vec![1], + (1, _) => shape, + (2, []) => vec![1, 1], + (2, [n]) => vec![1, *n], + (2, _) => shape, + (3, []) => vec![1, 1, 1], + (3, [n]) => vec![1, *n, 1], + (3, [rows, cols]) => vec![*rows, *cols, 1], + (3, _) => shape, + _ => shape, + } +} + +/// `numpy.diag_indices(n, ndim=2)` — return repeated diagonal index arrays. +fn call_diag_indices(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (n_val, ndim_val) = args.get_one_two_args("numpy.diag_indices", vm.heap)?; + defer_drop!(n_val, vm); + let n = value_to_nonnegative_usize(n_val, "numpy.diag_indices", "n")?; + let ndim = if let Some(ndim_val) = ndim_val { + defer_drop!(ndim_val, vm); + value_to_nonnegative_usize(ndim_val, "numpy.diag_indices", "ndim")? + } else { + 2 + }; + diag_indices_tuple(n, ndim, vm) +} + +/// `numpy.diag_indices_from(arr)` — diagonal index arrays for a square input. +fn call_diag_indices_from(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.diag_indices_from", vm.heap)?; + defer_drop!(arg, vm); + let (_, shape, _) = extract_ndarray_info(arg, "numpy.diag_indices_from", vm)?; + if shape.len() < 2 { + Err(SimpleException::new_msg(ExcType::ValueError, "input array must be at least 2-d").into()) + } else if !shape.iter().all(|&dim| dim == shape[0]) { + Err(SimpleException::new_msg(ExcType::ValueError, "all dimensions of input must be of equal length").into()) + } else { + diag_indices_tuple(shape[0], shape.len(), vm) + } +} + +/// Builds a tuple containing `ndim` copies of the diagonal range `0..n`. +fn diag_indices_tuple(n: usize, ndim: usize, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let data: Vec = (0..n).map(usize_to_f64).collect(); + let vectors = (0..ndim).map(|_| data.clone()).collect::>(); + tuple_from_index_vectors(vm, vectors, &[n]) +} + +/// `numpy.tril_indices()` / `numpy.triu_indices()` over the supported integer arguments. +fn call_triangle_indices( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + kind: TriangleKind, + name: &str, +) -> RunResult { + let (n, k, m) = triangle_args(args, name, vm)?; + triangle_indices_tuple(n, k, m, kind, vm) +} + +/// `numpy.tril_indices_from()` / `numpy.triu_indices_from()` for 2-D arrays. +fn call_triangle_indices_from( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + kind: TriangleKind, + name: &str, +) -> RunResult { + let (arr_val, k_val) = args.get_one_two_args(name, vm.heap)?; + defer_drop!(arr_val, vm); + let (_, shape, _) = extract_ndarray_info(arr_val, name, vm)?; + let k = if let Some(k_val) = k_val { + defer_drop!(k_val, vm); + value_to_i64_arg(k_val, name, "k")? + } else { + 0 + }; + if shape.len() == 2 { + triangle_indices_tuple(shape[0], k, shape[1], kind, vm) + } else { + Err(SimpleException::new_msg(ExcType::ValueError, "input array must be 2-d").into()) + } +} + +/// Parses `(n, k=0, m=None)` for triangle index helpers. +fn triangle_args(args: ArgValues, name: &str, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult<(usize, i64, usize)> { + let pos = args.into_pos_only(name, vm.heap)?; + defer_drop_mut!(pos, vm); + let n_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 1, 0))?; + defer_drop!(n_val, vm); + let n = value_to_nonnegative_usize(n_val, name, "n")?; + let k = if let Some(k_val) = pos.next() { + defer_drop!(k_val, vm); + value_to_i64_arg(k_val, name, "k")? + } else { + 0 + }; + let m = if let Some(m_val) = pos.next() { + defer_drop!(m_val, vm); + if matches!(m_val, Value::None) { + n + } else { + value_to_nonnegative_usize(m_val, name, "m")? + } + } else { + n + }; + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most(name, 3, 4)); + } + Ok((n, k, m)) +} + +/// Builds lower- or upper-triangle row and column index arrays. +fn triangle_indices_tuple( + n: usize, + k: i64, + m: usize, + kind: TriangleKind, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let mut rows = Vec::new(); + let mut cols = Vec::new(); + for row in 0..n { + let row_i64 = usize_to_i64(row)?; + for col in 0..m { + let col_i64 = usize_to_i64(col)?; + let include = match kind { + TriangleKind::Lower => col_i64 <= row_i64.saturating_add(k), + TriangleKind::Upper => col_i64 >= row_i64.saturating_add(k), + }; + if include { + rows.push(usize_to_f64(row)); + cols.push(usize_to_f64(col)); + } + } + } + let len = cols.len(); + tuple_from_index_vectors(vm, vec![rows, cols], &[len]) +} + +/// `numpy.indices(dimensions)` — build dense integer coordinate grids. +fn call_indices(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (dims_val, dtype_val) = args.get_one_two_args("numpy.indices", vm.heap)?; + defer_drop!(dims_val, vm); + if let Some(dtype_val) = dtype_val { + dtype_val.drop_with_heap(vm); + } + let dimensions = extract_shape_from_value(dims_val, "numpy.indices", vm)?; + let ndim = dimensions.len(); + let total = checked_shape_product(&dimensions, "numpy.indices")?; + check_array_alloc_size(total.saturating_mul(ndim), vm.heap.tracker())?; + + let mut data = Vec::with_capacity(total.saturating_mul(ndim)); + if total > 0 { + for axis in 0..ndim { + let stride = checked_shape_product(&dimensions[axis + 1..], "numpy.indices")?; + for flat in 0..total { + let coord = if dimensions[axis] == 0 { + 0 + } else { + (flat / stride) % dimensions[axis] + }; + data.push(usize_to_f64(coord)); + } + } + } + let mut shape = Vec::with_capacity(ndim + 1); + shape.push(ndim); + shape.extend(dimensions); + let arr = NdArray::new(data, shape, NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.unravel_index(indices, shape)` — convert flat indices to coordinates. +fn call_unravel_index(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (indices_val, shape_val) = args.get_two_args("numpy.unravel_index", vm.heap)?; + defer_drop!(indices_val, vm); + defer_drop!(shape_val, vm); + let dimensions = extract_shape_from_value(shape_val, "numpy.unravel_index", vm)?; + let index_input = index_input_info(indices_val, "numpy.unravel_index", vm)?; + let total = checked_shape_product(&dimensions, "numpy.unravel_index")?; + + match index_input { + IndexInput::Scalar(index) => { + let coords = unravel_one_index(index, &dimensions, total, "numpy.unravel_index")?; + let values: SmallVec<[Value; 3]> = coords.into_iter().map(Value::Int).collect(); + allocate_tuple(values, vm.heap).map_err(Into::into) + } + IndexInput::Array { data, shape } => { + let mut vectors = vec![Vec::with_capacity(data.len()); dimensions.len()]; + for index in data { + let coords = unravel_one_index(index, &dimensions, total, "numpy.unravel_index")?; + for (axis, coord) in coords.into_iter().enumerate() { + vectors[axis].push(i64_to_f64(coord)); + } + } + tuple_from_index_vectors(vm, vectors, &shape) + } + } +} + +/// `numpy.ravel_multi_index(multi_index, dims)` — convert coordinates to flat indices. +fn call_ravel_multi_index(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (multi_val, dims_val) = args.get_two_args("numpy.ravel_multi_index", vm.heap)?; + defer_drop!(multi_val, vm); + defer_drop!(dims_val, vm); + + let dimensions = extract_shape_from_value(dims_val, "numpy.ravel_multi_index", vm)?; + let coord_values = sequence_items(multi_val, "numpy.ravel_multi_index", vm)?; + defer_drop!(coord_values, vm); + if coord_values.len() != dimensions.len() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "parameter multi_index must be a sequence of length matching dims", + ) + .into()); + } + + let coords = coord_values + .iter() + .map(|value| index_input_info(value, "numpy.ravel_multi_index", vm)) + .collect::>>()?; + ravel_multi_index_result(&coords, &dimensions, vm) +} + +/// `numpy.ndindex(*shape)` — return row-major coordinate tuples for a shape. +/// +/// NumPy returns a dedicated iterator object, but Monty's subset materializes the +/// same finite coordinate sequence as a list. That preserves normal iteration and +/// `list(np.ndindex(...))` behavior without adding another heap iterator type. +fn call_ndindex(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.ndindex", vm.heap)?; + defer_drop_mut!(pos, vm); + + let shape = match pos.len() { + 0 => Vec::new(), + 1 => extract_shape_from_value(&pos.as_slice()[0], "numpy.ndindex", vm)?, + _ => extract_shape_from_items(pos.as_slice(), "numpy.ndindex")?, + }; + + coordinate_tuple_list(&shape, "numpy.ndindex", vm) +} + +/// `numpy.ndenumerate(a)` — return row-major `(index, value)` pairs for an array. +/// +/// The value side is converted back to the closest supported Monty scalar dtype, +/// matching the rest of the ndarray helpers that expose individual elements. +fn call_ndenumerate(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.ndenumerate", vm.heap)?; + defer_drop!(arg, vm); + + let arr = ndarray_or_scalar_from_value(arg, "numpy.ndenumerate", vm)?; + let total = checked_shape_product(arr.shape(), "numpy.ndenumerate")?; + check_array_alloc_size( + total.saturating_mul(arr.shape().len().saturating_add(1)), + vm.heap.tracker(), + )?; + + let strides = row_major_strides(arr.shape()); + let mut parts = Vec::with_capacity(arr.len()); + for (flat, value) in arr.data().iter().copied().enumerate() { + let coords = coords_from_flat_index(flat, arr.shape(), &strides); + let index = coordinate_tuple(&coords, vm)?; + let value = scalar_from_f64(value, arr.dtype()); + let pair = allocate_tuple(smallvec::smallvec![index, value], vm.heap)?; + parts.push(pair); + } + + Ok(Value::Ref(vm.heap.allocate(HeapData::List(List::new(parts)))?)) +} + +/// `numpy.nditer(a)` — return row-major scalar values for an array. +/// +/// This intentionally implements only the simple single-array form. Broader +/// `nditer` options depend on writable views, casting modes, and multi-operand +/// iteration that Monty's ndarray model does not currently expose. +fn call_nditer(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.nditer", vm.heap)?; + defer_drop!(arg, vm); + + let arr = ndarray_or_scalar_from_value(arg, "numpy.nditer", vm)?; + check_array_alloc_size(arr.len(), vm.heap.tracker())?; + let parts = arr + .data() + .iter() + .copied() + .map(|value| scalar_from_f64(value, arr.dtype())) + .collect(); + + Ok(Value::Ref(vm.heap.allocate(HeapData::List(List::new(parts)))?)) +} + +/// Builds a materialized list of coordinate tuples for a row-major shape walk. +fn coordinate_tuple_list(shape: &[usize], name: &str, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let total = checked_shape_product(shape, name)?; + check_array_alloc_size(total.saturating_mul(shape.len().max(1)), vm.heap.tracker())?; + + let strides = row_major_strides(shape); + let mut parts = Vec::with_capacity(total); + for flat in 0..total { + let coords = coords_from_flat_index(flat, shape, &strides); + parts.push(coordinate_tuple(&coords, vm)?); + } + + Ok(Value::Ref(vm.heap.allocate(HeapData::List(List::new(parts)))?)) +} + +/// Allocates one coordinate tuple from usize components as Python integers. +fn coordinate_tuple(coords: &[usize], vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let mut values: SmallVec<[Value; 3]> = SmallVec::new(); + for &coord in coords { + values.push(Value::Int(usize_to_i64(coord)?)); + } + allocate_tuple(values, vm.heap).map_err(Into::into) +} + +/// Computes the scalar or array output for `ravel_multi_index`. +fn ravel_multi_index_result( + coords: &[IndexInput], + dimensions: &[usize], + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let array_shape = shared_index_array_shape(coords)?; + if let Some(shape) = array_shape { + let len = shape.iter().product::(); + let mut data = Vec::with_capacity(len); + for offset in 0..len { + let mut coord_at_offset = Vec::with_capacity(coords.len()); + for coord in coords { + coord_at_offset.push(index_input_value_at(coord, offset)); + } + data.push(i64_to_f64(ravel_one_index( + &coord_at_offset, + dimensions, + "numpy.ravel_multi_index", + )?)); + } + let arr = NdArray::new(data, shape, NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } else { + let coord_at_offset = coords + .iter() + .map(|coord| index_input_value_at(coord, 0)) + .collect::>(); + Ok(Value::Int(ravel_one_index( + &coord_at_offset, + dimensions, + "numpy.ravel_multi_index", + )?)) + } +} + +/// Extracts a scalar or integer array from an index-like value. +fn index_input_info(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + if let Ok((data, shape)) = integer_array_info(value, name, vm) { + Ok(IndexInput::Array { + data: data.into_iter().map(f64_to_i64).collect(), + shape, + }) + } else { + integer_scalar_info(value, name).map(IndexInput::Scalar) + } +} + +/// Returns the common array shape among index inputs, if any input is array-shaped. +fn shared_index_array_shape(coords: &[IndexInput]) -> RunResult>> { + let mut shape = None; + for coord in coords { + if let IndexInput::Array { shape: coord_shape, .. } = coord { + if let Some(existing) = &shape { + if existing != coord_shape { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "operands could not be broadcast together", + ) + .into()); + } + } else { + shape = Some(coord_shape.clone()); + } + } + } + Ok(shape) +} + +/// Reads the scalar or per-offset array coordinate from an index input. +fn index_input_value_at(input: &IndexInput, offset: usize) -> i64 { + match input { + IndexInput::Scalar(value) => *value, + IndexInput::Array { data, .. } => data[offset], + } +} + +/// Converts one flat index into row-major coordinates for `unravel_index`. +fn unravel_one_index(index: i64, dimensions: &[usize], total: usize, name: &str) -> RunResult> { + let mut index = nonnegative_index_in_bounds(index, total, name)?; + let mut coords = vec![0; dimensions.len()]; + for axis in (0..dimensions.len()).rev() { + let dim = dimensions[axis]; + if dim == 0 { + return Err( + SimpleException::new_msg(ExcType::ValueError, "cannot unravel if shape has zero entries").into(), + ); + } + coords[axis] = usize_to_i64(index % dim)?; + index /= dim; + } + Ok(coords) +} + +/// Converts one coordinate tuple into a row-major flat index. +fn ravel_one_index(coords: &[i64], dimensions: &[usize], name: &str) -> RunResult { + let mut flat = 0usize; + for (&coord, &dim) in coords.iter().zip(dimensions.iter()) { + let coord = nonnegative_index_in_bounds(coord, dim, name)?; + flat = flat + .checked_mul(dim) + .and_then(|value| value.checked_add(coord)) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "index dimensions overflow"))?; + } + usize_to_i64(flat) +} + +/// Checks an index is non-negative and inside a dimension or total-size bound. +fn nonnegative_index_in_bounds(index: i64, upper: usize, name: &str) -> RunResult { + let index = i64_to_nonnegative_usize(index, name, "index")?; + if index >= upper { + Err(SimpleException::new_msg(ExcType::ValueError, "invalid entry in coordinates array").into()) + } else { + Ok(index) + } +} + +/// Extracts list/tuple items from a value by cloning references safely. +fn sequence_items(value: &Value, name: &str, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult> { + match value { + Value::Ref(id) => match vm.heap.get(*id) { + HeapData::List(list) => Ok(list.as_slice().iter().map(|value| value.clone_with_heap(vm)).collect()), + HeapData::Tuple(tuple) => Ok(tuple.as_slice().iter().map(|value| value.clone_with_heap(vm)).collect()), + _ => Err(ExcType::type_error(format!("{name}() requires a sequence argument"))), + }, + _ => Err(ExcType::type_error(format!("{name}() requires a sequence argument"))), + } +} + +/// Allocates a tuple of integer ndarrays using a shared result shape. +fn tuple_from_index_vectors( + vm: &mut VM<'_, impl ResourceTracker>, + vectors: Vec>, + shape: &[usize], +) -> RunResult { + let mut values: SmallVec<[Value; 3]> = SmallVec::new(); + for data in vectors { + let arr = NdArray::new(data, shape.to_vec(), NdArrayDtype::Int64); + values.push(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)); + } + allocate_tuple(values, vm.heap).map_err(Into::into) +} + +/// Computes a shape product with a NumPy-style overflow error. +fn checked_shape_product(shape: &[usize], name: &str) -> RunResult { + shape + .iter() + .try_fold(1usize, |acc, &dim| acc.checked_mul(dim)) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, format!("{name}() dimensions overflow")).into()) +} + +/// Converts an integer `Value` into a non-negative usize argument. +fn value_to_nonnegative_usize(value: &Value, name: &str, arg_name: &str) -> RunResult { + let value = value_to_i64_arg(value, name, arg_name)?; + i64_to_nonnegative_usize(value, name, arg_name) +} + +/// Converts an integer `Value` into an i64 argument. +fn value_to_i64_arg(value: &Value, name: &str, arg_name: &str) -> RunResult { + match value { + Value::Int(value) => Ok(*value), + _ => Err(ExcType::type_error(format!("{name}() {arg_name} must be an integer"))), + } +} + +/// Converts a non-negative i64 into usize with a targeted ValueError. +fn i64_to_nonnegative_usize(value: i64, name: &str, arg_name: &str) -> RunResult { + if value < 0 { + Err(SimpleException::new_msg(ExcType::ValueError, format!("{name}() {arg_name} must be non-negative")).into()) + } else { + usize::try_from(value).map_err(|_| { + SimpleException::new_msg(ExcType::ValueError, format!("{name}() {arg_name} is too large")).into() + }) + } +} + +/// Converts a usize index into i64 for Python integer outputs. +fn usize_to_i64(value: usize) -> RunResult { + i64::try_from(value).map_err(|_| SimpleException::new_msg(ExcType::ValueError, "index is too large").into()) +} + +/// Converts a usize index into ndarray f64 backing storage. +#[expect( + clippy::cast_precision_loss, + reason = "integer ndarray values are stored as f64 in Monty's current ndarray model" +)] +fn usize_to_f64(value: usize) -> f64 { + value as f64 +} + +/// Shared implementation for unary NumPy functions that return two results. +/// +/// NumPy's `frexp()` and `modf()` preserve the input's scalar-vs-array form but +/// package the two outputs in a tuple. This helper keeps that shape handling in +/// one place so both scalar broadcasting and list-to-array conversion match the +/// rest of Monty's ufunc subset. +fn call_unary_tuple_func( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + f: fn(f64) -> (f64, f64), + name: &str, + first_dtype: NdArrayDtype, + second_dtype: NdArrayDtype, +) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + + if let Ok((data, shape, _)) = extract_ndarray_info(arg, name, vm) { + let (first_data, second_data): (Vec, Vec) = data.iter().map(|&value| f(value)).unzip(); + tuple_from_arrays(vm, first_data, second_data, shape, first_dtype, second_dtype) + } else { + let (value, _) = numeric_scalar_info(arg, name, vm)?; + let (first, second) = f(value); + tuple_from_scalars(first, second, first_dtype, second_dtype, vm) + } +} + +/// `numpy.ldexp(x, exp)` over Monty's numeric scalar/list/ndarray subset. +/// +/// The exponent operand is intentionally restricted to integer and boolean +/// dtypes, matching NumPy's ufunc loop selection and preventing accidental +/// coercion of arbitrary floats into powers-of-two exponents. +fn call_ldexp(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (x_val, exp_val) = args.get_two_args("numpy.ldexp", vm.heap)?; + defer_drop!(x_val, vm); + defer_drop!(exp_val, vm); + + let x_arr = ndarray_or_scalar_from_value(x_val, "numpy.ldexp", vm)?; + let exp_arr = integer_ndarray_or_scalar_from_value(exp_val, "numpy.ldexp", vm)?; + if x_arr.shape().is_empty() && exp_arr.shape().is_empty() { + Ok(Value::Float(numpy_ldexp(x_arr.data()[0], exp_arr.data()[0]))) + } else { + let (x_data, exp_data, shape) = broadcast_pair_data( + x_arr.data(), + x_arr.shape(), + exp_arr.data(), + exp_arr.shape(), + "numpy.ldexp", + vm.heap.tracker(), + )?; + let data: Vec = x_data + .iter() + .zip(exp_data.iter()) + .map(|(&x, &exp)| numpy_ldexp(x, exp)) + .collect(); + let arr = NdArray::new(data, shape, NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } +} + +/// Shared implementation for integer-only binary ufuncs like `gcd()` and `lcm()`. +/// +/// Float dtypes are rejected instead of being truncated, because real NumPy has +/// no safe float loop for these ufuncs. Boolean inputs are accepted and promoted +/// to integer results. +fn call_integer_binop( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + f: fn(i64, i64) -> i64, + name: &str, +) -> RunResult { + let (a_val, b_val) = args.get_two_args(name, vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + + let a_arr = integer_ndarray_or_scalar_from_value(a_val, name, vm)?; + let b_arr = integer_ndarray_or_scalar_from_value(b_val, name, vm)?; + if a_arr.shape().is_empty() && b_arr.shape().is_empty() { + Ok(Value::Int(f(f64_to_i64(a_arr.data()[0]), f64_to_i64(b_arr.data()[0])))) + } else { + let (left, right, shape) = broadcast_pair_data( + a_arr.data(), + a_arr.shape(), + b_arr.data(), + b_arr.shape(), + name, + vm.heap.tracker(), + )?; + let data: Vec = left + .iter() + .zip(right.iter()) + .map(|(&a, &b)| i64_to_f64(f(f64_to_i64(a), f64_to_i64(b)))) + .collect(); + let arr = NdArray::new(data, shape, NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } +} + +/// Integer/boolean bitwise binary operation exposed as a NumPy ufunc. +#[derive(Clone, Copy)] +enum IntegerBitwiseOp { + /// Element-wise `a & b`. + And, + /// Element-wise `a | b`. + Or, + /// Element-wise `a ^ b`. + Xor, + /// Element-wise `a << b` using NumPy's fixed-width integer behavior. + LeftShift, + /// Element-wise `a >> b` using NumPy's fixed-width integer behavior. + RightShift, +} + +/// Shared implementation for NumPy's integer-only bitwise binary ufuncs. +/// +/// Float inputs are rejected, scalar broadcasting is supported, and boolean +/// AND/OR/XOR preserves bool dtype when both operands are boolean-valued. +fn call_bitwise_binop( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + op: IntegerBitwiseOp, + name: &str, +) -> RunResult { + let (a_val, b_val) = args.get_two_args(name, vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + + let a_arr = integer_ndarray_or_scalar_from_value(a_val, name, vm)?; + let b_arr = integer_ndarray_or_scalar_from_value(b_val, name, vm)?; + let dtype = bitwise_binop_dtype(op, a_arr.dtype(), b_arr.dtype()); + if a_arr.shape().is_empty() && b_arr.shape().is_empty() { + Ok(scalar_from_integer_result( + apply_integer_bitwise_op(op, f64_to_i64(a_arr.data()[0]), f64_to_i64(b_arr.data()[0])), + dtype, + )) + } else { + let (left, right, shape) = broadcast_pair_data( + a_arr.data(), + a_arr.shape(), + b_arr.data(), + b_arr.shape(), + name, + vm.heap.tracker(), + )?; + let data = left + .iter() + .zip(right.iter()) + .map(|(&a, &b)| i64_to_f64(apply_integer_bitwise_op(op, f64_to_i64(a), f64_to_i64(b)))) + .collect(); + let arr = NdArray::new(data, shape, dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } +} + +/// `numpy.bitwise_not()` / `numpy.invert()` over integer and boolean inputs. +fn call_bitwise_not(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.bitwise_not", vm.heap)?; + defer_drop!(arg, vm); + + if let Ok((data, shape, dtype)) = integer_array_info_with_dtype(arg, "numpy.bitwise_not", vm) { + let result_dtype = bitwise_not_dtype(dtype); + let data = data + .iter() + .map(|&value| i64_to_f64(bitwise_not_value(f64_to_i64(value), dtype))) + .collect(); + let arr = NdArray::new(data, shape, result_dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } else { + let (value, dtype) = integer_scalar_info_with_dtype(arg, "numpy.bitwise_not")?; + let result_dtype = bitwise_not_dtype(dtype); + Ok(scalar_from_integer_result( + bitwise_not_value(value, dtype), + result_dtype, + )) + } +} + +/// `numpy.bitwise_count()` — population count of each integer's absolute value. +fn call_bitwise_count(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.bitwise_count", vm.heap)?; + defer_drop!(arg, vm); + + if let Ok((data, shape, _)) = integer_array_info_with_dtype(arg, "numpy.bitwise_count", vm) { + let data = data + .iter() + .map(|&value| i64_to_f64(numpy_bitwise_count(f64_to_i64(value)))) + .collect(); + let arr = NdArray::new(data, shape, NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } else { + let value = integer_scalar_info(arg, "numpy.bitwise_count")?; + Ok(Value::Int(numpy_bitwise_count(value))) + } +} + +/// `numpy.packbits()` — pack flattened non-zero integer values into big-endian bytes. +fn call_packbits(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.packbits", vm.heap)?; + defer_drop!(arg, vm); + + let bits = if let Ok((data, _, _)) = integer_array_info_with_dtype(arg, "numpy.packbits", vm) { + data.into_iter().map(|value| f64_to_i64(value) != 0).collect() + } else { + vec![integer_scalar_info(arg, "numpy.packbits")? != 0] + }; + let output_len = bits.len().div_ceil(8); + check_array_alloc_size(output_len, vm.heap.tracker())?; + let data = pack_big_endian_bits(&bits); + let arr = NdArray::new(data, vec![output_len], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.unpackbits()` — unpack flattened byte-sized integer values into bits. +/// +/// Monty does not currently model `uint8`, so this accepts integer arrays whose +/// values are in the byte range. That keeps `unpackbits(packbits(x))` useful +/// while still rejecting floats and out-of-range values. +fn call_unpackbits(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.unpackbits", vm.heap)?; + defer_drop!(arg, vm); + let (data, _, dtype) = integer_array_info_with_dtype(arg, "numpy.unpackbits", vm)?; + if dtype == NdArrayDtype::Bool { + return Err(unpackbits_type_error()); + } + let output_len = data + .len() + .checked_mul(8) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.unpackbits() output is too large"))?; + check_array_alloc_size(output_len, vm.heap.tracker())?; + let mut bits = Vec::with_capacity(output_len); + for value in data { + let byte = byte_from_integer_slot(value)?; + for bit in (0..8).rev() { + bits.push(f64::from((byte >> bit) & 1)); + } + } + let arr = NdArray::new(bits, vec![output_len], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// Determines the dtype for a bitwise binary ufunc. +fn bitwise_binop_dtype(op: IntegerBitwiseOp, a: NdArrayDtype, b: NdArrayDtype) -> NdArrayDtype { + match op { + IntegerBitwiseOp::And | IntegerBitwiseOp::Or | IntegerBitwiseOp::Xor + if a == NdArrayDtype::Bool && b == NdArrayDtype::Bool => + { + NdArrayDtype::Bool + } + IntegerBitwiseOp::And + | IntegerBitwiseOp::Or + | IntegerBitwiseOp::Xor + | IntegerBitwiseOp::LeftShift + | IntegerBitwiseOp::RightShift => NdArrayDtype::Int64, + } +} + +/// Applies a single integer bitwise operation with NumPy-style shift edges. +fn apply_integer_bitwise_op(op: IntegerBitwiseOp, a: i64, b: i64) -> i64 { + match op { + IntegerBitwiseOp::And => a & b, + IntegerBitwiseOp::Or => a | b, + IntegerBitwiseOp::Xor => a ^ b, + IntegerBitwiseOp::LeftShift => numpy_left_shift(a, b), + IntegerBitwiseOp::RightShift => numpy_right_shift(a, b), + } +} + +/// NumPy-style fixed-width left shift for signed 64-bit integer loops. +fn numpy_left_shift(value: i64, shift: i64) -> i64 { + if (0..64).contains(&shift) { + value.wrapping_shl(u32::try_from(shift).expect("shift count is in range")) + } else { + 0 + } +} + +/// NumPy-style fixed-width arithmetic right shift for signed 64-bit integer loops. +fn numpy_right_shift(value: i64, shift: i64) -> i64 { + if (0..64).contains(&shift) { + value >> u32::try_from(shift).expect("shift count is in range") + } else if value < 0 { + -1 + } else { + 0 + } +} + +/// Computes the scalar/container dtype for bitwise inversion. +fn bitwise_not_dtype(dtype: NdArrayDtype) -> NdArrayDtype { + if dtype == NdArrayDtype::Bool { + NdArrayDtype::Bool + } else { + NdArrayDtype::Int64 + } +} + +/// Applies unary bitwise inversion to a bool or int slot. +fn bitwise_not_value(value: i64, dtype: NdArrayDtype) -> i64 { + if dtype == NdArrayDtype::Bool { + i64::from(value == 0) + } else { + !value + } +} + +/// Converts an integer ufunc result back to a scalar value with bool preservation. +fn scalar_from_integer_result(value: i64, dtype: NdArrayDtype) -> Value { + if dtype == NdArrayDtype::Bool { + Value::Bool(value != 0) + } else { + Value::Int(value) + } +} + +/// Population count matching `numpy.bitwise_count`, which counts `abs(x)`. +fn numpy_bitwise_count(value: i64) -> i64 { + i64::from(value.unsigned_abs().count_ones()) +} + +/// Packs a flattened bit stream into byte values using NumPy's default big bit order. +fn pack_big_endian_bits(bits: &[bool]) -> Vec { + let mut packed = Vec::with_capacity(bits.len().div_ceil(8)); + for chunk in bits.chunks(8) { + let mut byte = 0u8; + for (index, bit) in chunk.iter().enumerate() { + if *bit { + byte |= 1 << (7 - index); + } + } + packed.push(f64::from(byte)); + } + packed +} + +/// Extracts one byte from Monty's integer ndarray storage for `unpackbits`. +fn byte_from_integer_slot(value: f64) -> RunResult { + let value = f64_to_i64(value); + u8::try_from(value).map_err(|_| unpackbits_type_error()) +} + +/// TypeError used when `unpackbits` input cannot represent unsigned bytes. +fn unpackbits_type_error() -> RunError { + SimpleException::new_msg(ExcType::TypeError, "Expected an input array of unsigned byte data type").into() +} + +/// Supported one-argument NumPy window generators. +#[derive(Clone, Copy)] +enum WindowKind { + /// Bartlett triangular window. + Bartlett, + /// Blackman taper window. + Blackman, + /// Hamming raised-cosine window. + Hamming, + /// Hann raised-cosine window using NumPy's `hanning` spelling. + Hanning, +} + +/// Shared implementation for NumPy's simple floating-point window generators. +fn call_window( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + kind: WindowKind, + name: &str, +) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + let len = window_len(arg, name)?; + check_array_alloc_size(len, vm.heap.tracker())?; + let data = window_values(len, kind); + let arr = NdArray::new(data, vec![len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.kaiser(M, beta)` — Kaiser window using the supported real-valued subset. +fn call_kaiser(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (m_val, beta_val) = args.get_two_args("numpy.kaiser", vm.heap)?; + defer_drop!(m_val, vm); + defer_drop!(beta_val, vm); + let len = window_len(m_val, "numpy.kaiser")?; + let (beta, _) = numeric_scalar_info(beta_val, "numpy.kaiser", vm)?; + check_array_alloc_size(len, vm.heap.tracker())?; + let data = kaiser_values(len, beta); + let arr = NdArray::new(data, vec![len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// Parses a NumPy window length, where non-positive lengths produce an empty array. +fn window_len(value: &Value, name: &str) -> RunResult { + match value { + Value::Int(m) if *m <= 0 => Ok(0), + Value::Int(m) => usize::try_from(*m).map_err(|_| { + SimpleException::new_msg(ExcType::ValueError, format!("{name}() window length is too large")).into() + }), + _ => Err(ExcType::type_error(format!( + "{name}() window length must be an integer" + ))), + } +} + +/// Generates values for one of NumPy's one-argument real windows. +fn window_values(len: usize, kind: WindowKind) -> Vec { + match len { + 0 => Vec::new(), + 1 => vec![1.0], + _ => { + let denom = usize_to_f64(len - 1); + (0..len) + .map(|index| { + let n = usize_to_f64(index); + let phase = 2.0 * PI * n / denom; + match kind { + WindowKind::Bartlett => 1.0 - ((n - denom / 2.0) / (denom / 2.0)).abs(), + WindowKind::Blackman => 0.42 - 0.5 * phase.cos() + 0.08 * (2.0 * phase).cos(), + WindowKind::Hamming => 0.54 - 0.46 * phase.cos(), + WindowKind::Hanning => 0.5 - 0.5 * phase.cos(), + } + }) + .collect() + } + } +} + +/// Generates a Kaiser window using the order-0 modified Bessel approximation. +fn kaiser_values(len: usize, beta: f64) -> Vec { + match len { + 0 => Vec::new(), + 1 => vec![1.0], + _ => { + let alpha = usize_to_f64(len - 1) / 2.0; + let denom = numpy_i0(beta); + (0..len) + .map(|index| { + let ratio = (usize_to_f64(index) - alpha) / alpha; + let inner = (1.0 - ratio * ratio).max(0.0).sqrt(); + numpy_i0(beta * inner) / denom + }) + .collect() + } + } +} + +/// `numpy.base_repr(number, base=2, padding=0)` — convert an integer to a base string. +fn call_base_repr(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.base_repr", vm.heap)?; + defer_drop_mut!(pos, vm); + let number_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.base_repr", 1, 0))?; + defer_drop!(number_val, vm); + let number = value_to_i64_arg(number_val, "numpy.base_repr", "number")?; + let base = if let Some(base_val) = pos.next() { + defer_drop!(base_val, vm); + value_to_i64_arg(base_val, "numpy.base_repr", "base")? + } else { + 2 + }; + let padding = if let Some(padding_val) = pos.next() { + defer_drop!(padding_val, vm); + value_to_i64_arg(padding_val, "numpy.base_repr", "padding")? + } else { + 0 + }; + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.base_repr", 3, 4)); + } + let result = format_base_repr(number, base, padding)?; + allocate_string(result, vm.heap) +} + +/// `numpy.binary_repr(num, width=None)` — convert an integer to a binary string. +fn call_binary_repr(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (num_val, width_val) = args.get_one_two_args("numpy.binary_repr", vm.heap)?; + defer_drop!(num_val, vm); + let num = value_to_i64_arg(num_val, "numpy.binary_repr", "num")?; + let width = if let Some(width_val) = width_val { + defer_drop!(width_val, vm); + if matches!(width_val, Value::None) { + None + } else { + Some(value_to_i64_arg(width_val, "numpy.binary_repr", "width")?) + } + } else { + None + }; + let result = format_binary_repr(num, width)?; + allocate_string(result, vm.heap) +} + +/// Formats `base_repr`, including NumPy's base limits and padding behavior. +fn format_base_repr(number: i64, base: i64, padding: i64) -> RunResult { + let base = validate_base_repr_base(base)?; + let padding = nonnegative_padding(padding)?; + let magnitude = u128::from(number.unsigned_abs()); + let digits = format_unsigned_base(magnitude, base); + let zero_count = if magnitude == 0 { + padding.saturating_sub(1) + } else { + padding + }; + let zeros = "0".repeat(zero_count); + let sign = if number < 0 { "-" } else { "" }; + Ok(format!("{sign}{zeros}{digits}")) +} + +/// Formats `binary_repr`, including two's-complement output when width is supplied. +fn format_binary_repr(num: i64, width: Option) -> RunResult { + let magnitude_digits = format_unsigned_base(u128::from(num.unsigned_abs()), 2); + if let Some(width) = width { + let width_usize = binary_width(width)?; + let needed_width = if num < 0 { + magnitude_digits.len().saturating_add(1) + } else { + magnitude_digits.len() + }; + if width_usize < needed_width { + return Err(SimpleException::new_msg( + ExcType::ValueError, + format!("Insufficient bit width={width} provided for binwidth={needed_width}"), + ) + .into()); + } + if num < 0 { + let value = twos_complement_value(num, width_usize)?; + Ok(left_pad_zeros(format_unsigned_base(value, 2), width_usize)) + } else { + Ok(left_pad_zeros(magnitude_digits, width_usize)) + } + } else if num < 0 { + Ok(format!("-{magnitude_digits}")) + } else { + Ok(magnitude_digits) + } +} + +/// Validates the base accepted by `base_repr`. +fn validate_base_repr_base(base: i64) -> RunResult { + if base < 2 { + Err(SimpleException::new_msg(ExcType::ValueError, "Bases less than 2 not handled in base_repr.").into()) + } else if base > 36 { + Err(SimpleException::new_msg(ExcType::ValueError, "Bases greater than 36 not handled in base_repr.").into()) + } else { + u32::try_from(base).map_err(|_| SimpleException::new_msg(ExcType::ValueError, "invalid base").into()) + } +} + +/// Converts NumPy's `base_repr` padding argument to a repeat count. +fn nonnegative_padding(padding: i64) -> RunResult { + if padding <= 0 { + Ok(0) + } else { + usize::try_from(padding) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, "base_repr() padding is too large").into()) + } +} + +/// Converts and validates a `binary_repr` width. +fn binary_width(width: i64) -> RunResult { + if width < 0 { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("Insufficient bit width={width} provided for binwidth=1"), + ) + .into()) + } else { + usize::try_from(width) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, "binary_repr() width is too large").into()) + } +} + +/// Computes a negative integer's two's-complement value for a requested width. +fn twos_complement_value(num: i64, width: usize) -> RunResult { + if width > 127 { + Err(SimpleException::new_msg(ExcType::ValueError, "binary_repr() width is too large").into()) + } else { + let modulus = 1_u128 + .checked_shl(u32::try_from(width).expect("width is bounded")) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "binary_repr() width is too large"))?; + Ok(modulus - u128::from(num.unsigned_abs())) + } +} + +/// Left-pads a string with zeros up to `width`. +fn left_pad_zeros(mut value: String, width: usize) -> String { + if value.len() < width { + let mut padded = "0".repeat(width - value.len()); + padded.push_str(&value); + value = padded; + } + value +} + +/// Formats an unsigned integer in bases 2 through 36. +fn format_unsigned_base(mut value: u128, base: u32) -> String { + const DIGITS: &[u8; 36] = b"0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; + if value == 0 { + return "0".to_owned(); + } + let base = u128::from(base); + let mut out = Vec::new(); + while value > 0 { + let digit = usize::try_from(value % base).expect("digit is less than base"); + out.push(char::from(DIGITS[digit])); + value /= base; + } + out.iter().rev().collect() +} + +/// Shared implementation for binary NumPy functions that return two results. +/// +/// `numpy.divmod()` is the motivating case: each operand can be a scalar, list, +/// or ndarray, and the quotient and remainder outputs must preserve the same +/// broadcasted shape while being returned as a pair. +fn call_numeric_tuple_binop( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + f: fn(f64, f64) -> (f64, f64), + name: &str, + first_result: BinopResult, + second_result: BinopResult, +) -> RunResult { + let (a_val, b_val) = args.get_two_args(name, vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + + let a_arr = ndarray_or_scalar_from_value(a_val, name, vm)?; + let b_arr = ndarray_or_scalar_from_value(b_val, name, vm)?; + let first_dtype = binop_dtype(first_result, a_arr.dtype(), b_arr.dtype()); + let second_dtype = binop_dtype(second_result, a_arr.dtype(), b_arr.dtype()); + + if a_arr.shape().is_empty() && b_arr.shape().is_empty() { + let (first, second) = f(a_arr.data()[0], b_arr.data()[0]); + tuple_from_scalars(first, second, first_dtype, second_dtype, vm) + } else { + let (left, right, shape) = broadcast_pair_data( + a_arr.data(), + a_arr.shape(), + b_arr.data(), + b_arr.shape(), + name, + vm.heap.tracker(), + )?; + let (first_data, second_data): (Vec, Vec) = + left.iter().zip(right.iter()).map(|(&a, &b)| f(a, b)).unzip(); + tuple_from_arrays(vm, first_data, second_data, shape, first_dtype, second_dtype) + } +} + +/// Allocates a tuple containing two ndarray outputs with a shared shape. +fn tuple_from_arrays( + vm: &mut VM<'_, impl ResourceTracker>, + first_data: Vec, + second_data: Vec, + shape: Vec, + first_dtype: NdArrayDtype, + second_dtype: NdArrayDtype, +) -> RunResult { + let first_arr = NdArray::new(first_data, shape.clone(), first_dtype); + let second_arr = NdArray::new(second_data, shape, second_dtype); + let first = Value::Ref(vm.heap.allocate(HeapData::NdArray(first_arr))?); + let second = Value::Ref(vm.heap.allocate(HeapData::NdArray(second_arr))?); + Ok(allocate_tuple(smallvec::smallvec![first, second], vm.heap)?) +} + +/// Allocates a tuple containing two scalar ufunc outputs. +fn tuple_from_scalars( + first: f64, + second: f64, + first_dtype: NdArrayDtype, + second_dtype: NdArrayDtype, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + Ok(allocate_tuple( + smallvec::smallvec![ + scalar_from_f64(first, first_dtype), + scalar_from_f64(second, second_dtype) + ], + vm.heap, + )?) +} + +/// Extracts an integer scalar accepted by NumPy's integer-only ufunc loops. +fn integer_scalar_info(value: &Value, name: &str) -> RunResult { + integer_scalar_info_with_dtype(value, name).map(|(value, _)| value) +} + +/// Extracts an integer scalar plus the dtype NumPy would infer for it. +fn integer_scalar_info_with_dtype(value: &Value, name: &str) -> RunResult<(i64, NdArrayDtype)> { + match value { + Value::Int(n) => Ok((*n, NdArrayDtype::Int64)), + Value::Bool(b) => Ok((i64::from(*b), NdArrayDtype::Bool)), + _ => Err(integer_ufunc_type_error(name)), + } +} + +/// Converts an integer/boolean scalar, list, or ndarray into an integer ndarray. +/// +/// Scalar values become zero-dimensional arrays so integer ufuncs can use the +/// same broadcasting path as floating-point ufuncs while still rejecting floats. +fn integer_ndarray_or_scalar_from_value( + value: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + if let Ok((data, shape, dtype)) = integer_array_info_with_dtype(value, name, vm) { + Ok(NdArray::new(data, shape, dtype)) + } else { + let (value, dtype) = integer_scalar_info_with_dtype(value, name)?; + Ok(NdArray::new(vec![i64_to_f64(value)], Vec::new(), dtype)) + } +} + +/// Extracts integer ndarray data, accepting lists and rejecting float dtypes. +fn integer_array_info( + value: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult<(Vec, Vec)> { + let (data, shape, _) = integer_array_info_with_dtype(value, name, vm)?; + Ok((data, shape)) +} + +/// Extracts integer ndarray data and dtype, accepting lists and rejecting float dtypes. +fn integer_array_info_with_dtype( + value: &Value, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult<(Vec, Vec, NdArrayDtype)> { + let (data, shape, dtype) = extract_ndarray_info(value, name, vm)?; + if dtype == NdArrayDtype::Float64 { + Err(integer_ufunc_type_error(name)) + } else { + Ok((data, shape, dtype)) + } +} + +/// Builds a compact TypeError for unsupported integer ufunc inputs. +fn integer_ufunc_type_error(name: &str) -> RunError { + let ufunc = name.strip_prefix("numpy.").unwrap_or(name); + SimpleException::new_msg( + ExcType::TypeError, + format!("ufunc '{ufunc}' not supported for the input types"), + ) + .into() +} + +/// Converts an integer-valued ndarray slot back to `i64`. +#[expect( + clippy::cast_possible_truncation, + reason = "integer ndarray values are represented as f64 in Monty's current ndarray storage" +)] +fn f64_to_i64(value: f64) -> i64 { + value as i64 +} + +/// Converts an `i64` integer result into ndarray backing storage. +#[expect( + clippy::cast_precision_loss, + reason = "integer ndarray values are stored as f64 in Monty's current ndarray model" +)] +fn i64_to_f64(value: i64) -> f64 { + value as f64 +} + +/// Approximation for `numpy.i0()`, the modified Bessel function I0. +/// +/// This uses the classic Cephes polynomial split, which is accurate enough for +/// NumPy-compatible window generation while avoiding a new special-functions +/// dependency in the sandbox runtime. +fn numpy_i0(value: f64) -> f64 { + let x = value.abs(); + if x <= 3.75 { + let y = (x / 3.75).powi(2); + 1.0 + y + * (3.515_622_9 + + y * (3.089_942_4 + y * (1.206_749_2 + y * (0.265_973_2 + y * (0.036_076_8 + y * 0.004_581_3))))) + } else { + let y = 3.75 / x; + (x.exp() / x.sqrt()) + * (0.398_942_28 + + y * (0.013_285_92 + + y * (0.002_253_19 + + y * (-0.001_575_65 + + y * (0.009_162_81 + + y * (-0.020_577_06 + y * (0.026_355_37 + y * (-0.016_476_33 + y * 0.003_923_77)))))))) + } +} + +/// `numpy.frexp()` scalar kernel returning exponent as an integer-valued float. +fn numpy_frexp(value: f64) -> (f64, f64) { + let (mantissa, exponent) = libm::frexp(value); + (mantissa, f64::from(exponent)) +} + +/// `numpy.modf()` scalar kernel. +fn numpy_modf(value: f64) -> (f64, f64) { + libm::modf(value) +} + +/// `numpy.ldexp()` scalar kernel with NumPy-style non-raising overflow behavior. +fn numpy_ldexp(value: f64, exponent: f64) -> f64 { + let exponent = f64_to_i64(exponent); + let exponent = i32::try_from(exponent).unwrap_or(if exponent < 0 { i32::MIN } else { i32::MAX }); + libm::ldexp(value, exponent) +} + +/// `numpy.gcd()` scalar kernel using NumPy's wrapping int64 edge behavior. +fn numpy_gcd(a: i64, b: i64) -> i64 { + wrapping_u64_to_i64(gcd_u64(a.unsigned_abs(), b.unsigned_abs())) +} + +/// `numpy.lcm()` scalar kernel using NumPy's wrapping int64 edge behavior. +fn numpy_lcm(a: i64, b: i64) -> i64 { + if a == 0 || b == 0 { + 0 + } else { + let gcd = gcd_u64(a.unsigned_abs(), b.unsigned_abs()); + wrapping_u64_to_i64((a.unsigned_abs() / gcd).wrapping_mul(b.unsigned_abs())) + } +} + +/// Euclidean GCD for unsigned integer magnitudes. +fn gcd_u64(mut a: u64, mut b: u64) -> u64 { + while b != 0 { + let t = b; + b = a % b; + a = t; + } + a +} + +/// Reinterprets a NumPy int64 ufunc magnitude after two's-complement wrapping. +#[expect( + clippy::cast_possible_wrap, + reason = "NumPy int64 integer ufuncs wrap overflowing unsigned magnitudes into int64" +)] +fn wrapping_u64_to_i64(value: u64) -> i64 { + value as i64 +} + +/// Stable scalar kernel for `numpy.logaddexp()`. +fn numpy_logaddexp(a: f64, b: f64) -> f64 { + if a.is_nan() || b.is_nan() { + f64::NAN + } else { + let max = a.max(b); + if max.is_infinite() { + max + } else { + max + ((a - max).exp() + (b - max).exp()).ln() + } + } +} + +/// Stable scalar kernel for `numpy.logaddexp2()`. +fn numpy_logaddexp2(a: f64, b: f64) -> f64 { + if a.is_nan() || b.is_nan() { + f64::NAN + } else { + let max = a.max(b); + if max.is_infinite() { + max + } else { + max + ((a - max).exp2() + (b - max).exp2()).log2() + } + } +} + +/// Scalar kernel for `numpy.spacing()`. +fn numpy_spacing(value: f64) -> f64 { + if value.is_nan() || value.is_infinite() { + f64::NAN + } else if value == 0.0 { + f64::from_bits(1) + } else { + let direction = if value > 0.0 { f64::INFINITY } else { f64::NEG_INFINITY }; + libm::nextafter(value, direction) - value + } +} + +/// Scalar kernel for `numpy.signbit()` using the f64 backing representation. +fn signbit_as_f64(value: f64) -> f64 { + bool_to_f64(value.is_sign_negative()) +} + +/// Scalar kernel for NumPy's normalized `sinc(x) = sin(pi*x)/(pi*x)`. +fn numpy_sinc(value: f64) -> f64 { + if value == 0.0 { + 1.0 + } else { + let scaled = PI * value; + scaled.sin() / scaled + } +} + +/// Scalar kernel for `numpy.heaviside()`. +fn numpy_heaviside(value: f64, zero_value: f64) -> f64 { + if value.is_nan() { + f64::NAN + } else if value < 0.0 { + 0.0 + } else if value == 0.0 { + zero_value + } else { + 1.0 + } +} + +/// Scalar kernel for `numpy.divmod()`. +fn numpy_divmod(a: f64, b: f64) -> (f64, f64) { + ((a / b).floor(), py_mod(a, b)) +} + +/// Rounds a scalar using the factor computed from NumPy's `decimals` argument. +fn round_to_decimals(value: f64, factor: f64) -> f64 { + (value * factor).round() / factor +} + +// =========================== +// Phase 3+: Additional math, aggregation, logical, manipulation, +// sorting, set, linalg, and creation functions +// =========================== + +/// `numpy.nan_to_num(a)` — replace NaN with 0, inf with large finite, -inf with -large finite. +fn call_nan_to_num(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.nan_to_num", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.nan_to_num", vm)?; + let big = f64::MAX; + let data: Vec = arr + .data() + .iter() + .map(|&v| { + if v.is_nan() { + 0.0 + } else if v == f64::INFINITY { + big + } else if v == f64::NEG_INFINITY { + -big + } else { + v + } + }) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +// --- NaN-aware aggregation helpers --- + +/// Filter NaN values from a slice, returning only finite values. +fn filter_nan(data: &[f64]) -> Vec { + data.iter().copied().filter(|v| !v.is_nan()).collect() +} + +fn nan_sum(data: &[f64]) -> f64 { + filter_nan(data).iter().sum() +} +fn nan_prod(data: &[f64]) -> f64 { + filter_nan(data).iter().fold(1.0, |a, &v| a * v) +} +fn nan_mean(data: &[f64]) -> f64 { + let clean = filter_nan(data); + if clean.is_empty() { + f64::NAN + } else { + clean.iter().sum::() / clean.len() as f64 + } +} +fn nan_min(data: &[f64]) -> f64 { + filter_nan(data).iter().copied().fold(f64::INFINITY, f64::min) +} +fn nan_max(data: &[f64]) -> f64 { + filter_nan(data).iter().copied().fold(f64::NEG_INFINITY, f64::max) +} +fn nan_var(data: &[f64]) -> f64 { + let clean = filter_nan(data); + if clean.is_empty() { + return f64::NAN; + } + let mean = clean.iter().sum::() / clean.len() as f64; + clean.iter().map(|v| (v - mean).powi(2)).sum::() / clean.len() as f64 +} +fn nan_std(data: &[f64]) -> f64 { + nan_var(data).sqrt() +} +fn nan_median(data: &[f64]) -> f64 { + let mut clean = filter_nan(data); + if clean.is_empty() { + return f64::NAN; + } + clean.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); + let n = clean.len(); + if n % 2 == 1 { + clean[n / 2] + } else { + f64::midpoint(clean[n / 2 - 1], clean[n / 2]) + } +} + +/// Generic NaN-aware aggregation: extract array, filter NaN, apply function, return float. +fn call_nan_aggregate( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + f: fn(&[f64]) -> f64, + name: &str, +) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, name, vm)?; + Ok(Value::Float(f(arr.data()))) +} + +/// `numpy.nanargmin(a)` — index of minimum, ignoring NaN. +#[expect( + clippy::cast_possible_wrap, + reason = "array indices are small enough that these casts are safe" +)] +fn call_nan_argmin(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.nanargmin", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.nanargmin", vm)?; + let mut best_idx = 0usize; + let mut best_val = f64::INFINITY; + for (i, &v) in arr.data().iter().enumerate() { + if !v.is_nan() && v < best_val { + best_val = v; + best_idx = i; + } + } + Ok(Value::Int(best_idx as i64)) +} + +/// `numpy.nanargmax(a)` — index of maximum, ignoring NaN. +#[expect( + clippy::cast_possible_wrap, + reason = "array indices are small enough that these casts are safe" +)] +fn call_nan_argmax(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.nanargmax", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.nanargmax", vm)?; + let mut best_idx = 0usize; + let mut best_val = f64::NEG_INFINITY; + for (i, &v) in arr.data().iter().enumerate() { + if !v.is_nan() && v > best_val { + best_val = v; + best_idx = i; + } + } + Ok(Value::Int(best_idx as i64)) +} + +/// `numpy.percentile(a, q)` — q-th percentile (q in 0..100). +fn call_percentile(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, q_val) = args.get_two_args("numpy.percentile", vm.heap)?; + defer_drop!(arr_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.percentile", vm)?; + let q = to_f64(&q_val, vm)?; + q_val.drop_with_heap(vm); + Ok(Value::Float(percentile_impl(arr.data(), q / 100.0))) +} + +/// `numpy.quantile(a, q)` — q-th quantile (q in 0..1). +fn call_quantile(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, q_val) = args.get_two_args("numpy.quantile", vm.heap)?; + defer_drop!(arr_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.quantile", vm)?; + let q = to_f64(&q_val, vm)?; + q_val.drop_with_heap(vm); + Ok(Value::Float(percentile_impl(arr.data(), q))) +} + +/// `numpy.nanpercentile(a, q)` — q-th percentile after dropping NaN values. +fn call_nanpercentile(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, q_val) = args.get_two_args("numpy.nanpercentile", vm.heap)?; + defer_drop!(arr_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.nanpercentile", vm)?; + let q = to_f64(&q_val, vm)?; + q_val.drop_with_heap(vm); + let filtered = non_nan_values(arr.data()); + Ok(Value::Float(percentile_impl(&filtered, q / 100.0))) +} + +/// `numpy.nanquantile(a, q)` — q-th quantile after dropping NaN values. +fn call_nanquantile(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, q_val) = args.get_two_args("numpy.nanquantile", vm.heap)?; + defer_drop!(arr_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.nanquantile", vm)?; + let q = to_f64(&q_val, vm)?; + q_val.drop_with_heap(vm); + let filtered = non_nan_values(arr.data()); + Ok(Value::Float(percentile_impl(&filtered, q))) +} + +/// `numpy.histogram(a, bins=10)` — one-dimensional histogram counts and bin edges. +fn call_histogram(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr, bins) = histogram_args(vm, args, "numpy.histogram")?; + let (counts, edges) = histogram_counts_edges(arr.data(), bins, vm.heap.tracker())?; + let counts_value = Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + counts, + vec![bins], + NdArrayDtype::Int64, + )))?); + let edges_value = Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + edges, + vec![bins + 1], + NdArrayDtype::Float64, + )))?); + Ok(allocate_tuple(smallvec::smallvec![counts_value, edges_value], vm.heap)?) +} + +/// `numpy.histogram2d(x, y, bins=10)` — two-dimensional histogram counts and edges. +fn call_histogram2d(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (x, y, bins) = histogram2d_args(vm, args, "numpy.histogram2d")?; + let HistogramNdResult { counts, edges, shape } = + histogram_nd_counts_edges(&[x.data(), y.data()], bins, "numpy.histogram2d", vm.heap.tracker())?; + let counts_value = Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + counts, + shape, + NdArrayDtype::Float64, + )))?); + let xedges_value = Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + edges[0].clone(), + vec![bins + 1], + NdArrayDtype::Float64, + )))?); + let yedges_value = Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + edges[1].clone(), + vec![bins + 1], + NdArrayDtype::Float64, + )))?); + Ok(allocate_tuple( + smallvec::smallvec![counts_value, xedges_value, yedges_value], + vm.heap, + )?) +} + +/// `numpy.histogram_bin_edges(a, bins=10)` — return histogram bin edges only. +fn call_histogram_bin_edges(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr, bins) = histogram_args(vm, args, "numpy.histogram_bin_edges")?; + check_array_alloc_size(bins + 1, vm.heap.tracker())?; + let edges = histogram_edges(arr.data(), bins); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + edges, + vec![bins + 1], + NdArrayDtype::Float64, + )))?)) +} + +/// `numpy.histogramdd(sample, bins=10)` — multi-dimensional histogram counts and edge arrays. +fn call_histogramdd(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (sample, bins) = histogram_args(vm, args, "numpy.histogramdd")?; + let sample_axes = histogramdd_sample_axes(&sample)?; + let sample_axis_refs = sample_axes.iter().map(Vec::as_slice).collect::>(); + let HistogramNdResult { counts, edges, shape } = + histogram_nd_counts_edges(&sample_axis_refs, bins, "numpy.histogramdd", vm.heap.tracker())?; + let counts_value = Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + counts, + shape, + NdArrayDtype::Float64, + )))?); + let mut edge_values = Vec::with_capacity(edges.len()); + for edge in edges { + edge_values.push(Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + edge, + vec![bins + 1], + NdArrayDtype::Float64, + )))?)); + } + let edges_value = Value::Ref(vm.heap.allocate(HeapData::List(List::new(edge_values)))?); + Ok(allocate_tuple(smallvec::smallvec![counts_value, edges_value], vm.heap)?) +} + +/// Compute the q-th quantile (q in [0, 1]) using linear interpolation. +#[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "quantile index is always within array bounds" +)] +fn percentile_impl(data: &[f64], q: f64) -> f64 { + if data.is_empty() { + return f64::NAN; + } + let mut sorted: Vec = data.to_vec(); + sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal)); + let n = sorted.len(); + if n == 1 { + return sorted[0]; + } + let idx = q * (n - 1) as f64; + let lo = idx.floor() as usize; + let hi = idx.ceil() as usize; + if lo == hi { + sorted[lo] + } else { + sorted[lo] + (sorted[hi] - sorted[lo]) * (idx - lo as f64) + } +} + +/// Returns a copied vector containing only non-NaN values. +fn non_nan_values(data: &[f64]) -> Vec { + data.iter().copied().filter(|value| !value.is_nan()).collect() +} + +/// Parses common histogram arguments, currently supporting integer `bins`. +fn histogram_args(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues, name: &str) -> RunResult<(NdArray, usize)> { + let (pos, kwargs) = args.into_parts(); + defer_drop_mut!(pos, vm); + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + + let arr_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 1, 0))?; + defer_drop!(arr_val, vm); + let bins_value = pos.next(); + defer_drop_mut!(bins_value, vm); + if pos.len() != 0 { + return Err(ExcType::type_error_at_most(name, 2, 2 + pos.len())); + } + parse_bins_keyword(kwargs_iter, bins_value, name, vm)?; + + let arr = ndarray_from_value(arr_val, name, vm)?; + let bins = if let Some(value) = bins_value.as_ref() { + histogram_bins_from_value(value, name)? + } else { + 10 + }; + Ok((arr, bins)) +} + +/// Parses the positional arrays and optional integer `bins` for `numpy.histogram2d`. +fn histogram2d_args( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + name: &str, +) -> RunResult<(NdArray, NdArray, usize)> { + let (pos, kwargs) = args.into_parts(); + defer_drop_mut!(pos, vm); + let kwargs_iter = kwargs.into_iter(); + defer_drop_mut!(kwargs_iter, vm); + + let x_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 2, 0))?; + defer_drop!(x_val, vm); + let y_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 2, 1))?; + defer_drop!(y_val, vm); + let bins_value = pos.next(); + defer_drop_mut!(bins_value, vm); + if pos.len() != 0 { + return Err(ExcType::type_error_at_most(name, 3, 3 + pos.len())); + } + parse_bins_keyword(kwargs_iter, bins_value, name, vm)?; + + let x = ndarray_from_value(x_val, name, vm)?; + let y = ndarray_from_value(y_val, name, vm)?; + if x.data().len() == y.data().len() { + let bins = if let Some(value) = bins_value.as_ref() { + histogram_bins_from_value(value, name)? + } else { + 10 + }; + Ok((x, y, bins)) + } else { + Err(SimpleException::new_msg(ExcType::ValueError, "x and y must have the same length").into()) + } +} + +/// Parses a supported `bins` value for histogram helpers. +fn histogram_bins_from_value(value: &Value, name: &str) -> RunResult { + let bins = value_to_i64_arg(value, name, "bins")?; + if bins <= 0 { + Err(SimpleException::new_msg(ExcType::ValueError, "`bins` must be positive").into()) + } else { + usize::try_from(bins) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, format!("{name}() bins is too large")).into()) + } +} + +/// Parses an optional `bins` keyword into the shared bins value slot. +fn parse_bins_keyword( + kwargs_iter: &mut impl Iterator, + bins_value: &mut Option, + name: &str, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult<()> { + for (key, value) in kwargs_iter { + defer_drop!(key, vm); + let Some(keyword_name) = key.as_either_str(vm.heap) else { + value.drop_with_heap(vm); + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + let key_str = keyword_name.as_str(vm.interns); + if key_str == "bins" { + if bins_value.is_some() { + value.drop_with_heap(vm); + return Err(ExcType::type_error_duplicate_arg(name, key_str)); + } + *bins_value = Some(value); + } else { + value.drop_with_heap(vm); + return Err(ExcType::type_error_unexpected_keyword(name, key_str)); + } + } + Ok(()) +} + +/// Computes histogram counts and edges for finite numeric data. +fn histogram_counts_edges( + data: &[f64], + bins: usize, + tracker: &impl ResourceTracker, +) -> RunResult<(Vec, Vec)> { + check_array_alloc_size(bins, tracker)?; + check_array_alloc_size(bins + 1, tracker)?; + let edges = histogram_edges(data, bins); + let mut counts = vec![0.0; bins]; + for value in data.iter().copied() { + if let Some(index) = histogram_bin_index(value, &edges) { + counts[index] += 1.0; + } + } + Ok((counts, edges)) +} + +/// Bundles an n-dimensional histogram's flat counts, per-axis edges, and output shape. +struct HistogramNdResult { + /// Flat row-major counts buffer for the histogram ndarray. + counts: Vec, + /// One edge vector per sampled dimension. + edges: Vec>, + /// Shape of the counts ndarray, with one bin dimension per sampled axis. + shape: Vec, +} + +/// Computes histogram counts and per-axis edges for same-length sample axes. +fn histogram_nd_counts_edges( + sample_axes: &[&[f64]], + bins: usize, + name: &str, + tracker: &impl ResourceTracker, +) -> RunResult { + if sample_axes.is_empty() { + return Err( + SimpleException::new_msg(ExcType::ValueError, format!("{name}() requires sample dimensions")).into(), + ); + } + let sample_len = sample_axes[0].len(); + if sample_axes.iter().any(|axis| axis.len() != sample_len) { + return Err( + SimpleException::new_msg(ExcType::ValueError, format!("{name}() sample dimensions must match")).into(), + ); + } + + let shape = vec![bins; sample_axes.len()]; + let total_bins = checked_shape_product(&shape, name)?; + check_array_alloc_size(total_bins, tracker)?; + check_array_alloc_size((bins + 1).saturating_mul(sample_axes.len()), tracker)?; + let edges = sample_axes + .iter() + .map(|axis| histogram_edges(axis, bins)) + .collect::>(); + let mut counts = vec![0.0; total_bins]; + for sample_index in 0..sample_len { + let mut coords = Vec::with_capacity(sample_axes.len()); + let mut in_range = true; + for (axis, edge) in sample_axes.iter().zip(edges.iter()) { + if let Some(bin_index) = histogram_bin_index(axis[sample_index], edge) { + coords.push(bin_index); + } else { + in_range = false; + break; + } + } + if in_range { + let flat_index = coords_to_flat_index(&coords, &shape); + counts[flat_index] += 1.0; + } + } + Ok(HistogramNdResult { counts, edges, shape }) +} + +/// Splits a two-dimensional `(n_samples, n_dims)` sample array into per-axis vectors. +fn histogramdd_sample_axes(sample: &NdArray) -> RunResult>> { + match sample.shape() { + [_, 0] => { + Err(SimpleException::new_msg(ExcType::ValueError, "sample must include at least one dimension").into()) + } + [sample_count, dimensions] => { + let mut axes = vec![Vec::with_capacity(*sample_count); *dimensions]; + for sample_index in 0..*sample_count { + let row_start = sample_index * dimensions; + for (dimension, axis) in axes.iter_mut().enumerate() { + axis.push(sample.data()[row_start + dimension]); + } + } + Ok(axes) + } + _ => Err(SimpleException::new_msg(ExcType::ValueError, "sample must be a 2D array").into()), + } +} + +/// Returns the bin index for one value against monotonically increasing edges. +fn histogram_bin_index(value: f64, edges: &[f64]) -> Option { + if !value.is_finite() { + return None; + } + let bins = edges.len().saturating_sub(1); + let first = edges[0]; + let last = edges[bins]; + let width = (last - first) / bins as f64; + if width <= 0.0 || value < first || value > last { + None + } else if matches!(value.total_cmp(&last), Ordering::Equal) { + Some(bins - 1) + } else { + #[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "bin index is checked against the edge range" + )] + let index = ((value - first) / width).floor() as usize; + Some(index.min(bins - 1)) + } +} + +/// Computes evenly spaced histogram bin edges. +fn histogram_edges(data: &[f64], bins: usize) -> Vec { + let finite = data + .iter() + .copied() + .filter(|value| value.is_finite()) + .collect::>(); + let (mut first, mut last) = if finite.is_empty() { + (0.0, 1.0) + } else { + let first = finite.iter().copied().fold(f64::INFINITY, f64::min); + let last = finite.iter().copied().fold(f64::NEG_INFINITY, f64::max); + (first, last) + }; + if matches!(first.total_cmp(&last), Ordering::Equal) { + first -= 0.5; + last += 0.5; + } + let width = (last - first) / bins as f64; + (0..=bins) + .map(|index| { + if index == bins { + last + } else { + first + width * index as f64 + } + }) + .collect() +} + +/// `numpy.ptp(a)` — peak-to-peak: max(a) - min(a). +fn call_ptp(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.ptp", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.ptp", vm)?; + let (min, max) = arr + .data() + .iter() + .fold((f64::INFINITY, f64::NEG_INFINITY), |(mn, mx), &v| { + (mn.min(v), mx.max(v)) + }); + Ok(Value::Float(max - min)) +} + +/// `numpy.cumprod(a)` — cumulative product. +fn call_cumprod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.cumprod", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.cumprod", vm)?; + let mut acc = 1.0; + let data: Vec = arr + .data() + .iter() + .map(|&v| { + acc *= v; + acc + }) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.nancumsum` / `numpy.nancumprod` — cumulative ops treating NaN as identity. +fn call_nancumop(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues, is_sum: bool, name: &str) -> RunResult { + let arg = args.get_one_arg(name, vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, name, vm)?; + let identity = if is_sum { 0.0 } else { 1.0 }; + let mut acc = identity; + let data: Vec = arr + .data() + .iter() + .map(|&v| { + let clean = if v.is_nan() { identity } else { v }; + if is_sum { + acc += clean; + } else { + acc *= clean; + } + acc + }) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +// --- Phase 5: Logical and testing --- + +/// Generic logical binary operation on broadcast-compatible inputs. +fn call_logical_binop( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + op: fn(bool, bool) -> bool, + name: &str, +) -> RunResult { + call_numeric_binop( + vm, + args, + |a, b| { + if op(a != 0.0, b != 0.0) { 1.0 } else { 0.0 } + }, + name, + BinopResult::Bool, + ) +} + +/// `numpy.logical_not(a)` — element-wise logical NOT. +fn call_logical_not(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.logical_not", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.logical_not", vm)?; + let data: Vec = arr.data().iter().map(|&v| if v == 0.0 { 1.0 } else { 0.0 }).collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Bool); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.allclose(a, b, rtol=1e-5, atol=1e-8)` — true if all elements are close. +fn call_allclose(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.allclose", vm.heap)?; + defer_drop_mut!(pos, vm); + let a_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.allclose() requires at least 2 arguments"))?; + defer_drop!(a_val, vm); + let b_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.allclose() requires at least 2 arguments"))?; + defer_drop!(b_val, vm); + let rtol = pos + .next() + .map(|v| { + let result = to_f64(&v, vm); + v.drop_with_heap(vm); + result + }) + .transpose()? + .unwrap_or(1e-5); + let atol = pos + .next() + .map(|v| { + let result = to_f64(&v, vm); + v.drop_with_heap(vm); + result + }) + .transpose()? + .unwrap_or(1e-8); + for extra in pos { + extra.drop_with_heap(vm); + } + let a = ndarray_or_scalar_from_value(a_val, "numpy.allclose", vm)?; + let b = ndarray_or_scalar_from_value(b_val, "numpy.allclose", vm)?; + let (left, right, _) = broadcast_pair_data( + a.data(), + a.shape(), + b.data(), + b.shape(), + "numpy.allclose", + vm.heap.tracker(), + )?; + let close = left + .iter() + .zip(right.iter()) + .all(|(&x, &y)| (x - y).abs() <= atol + rtol * y.abs()); + Ok(Value::Bool(close)) +} + +/// `numpy.isclose(a, b, rtol=1e-5, atol=1e-8)` — element-wise closeness test. +fn call_isclose(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.isclose", vm.heap)?; + defer_drop_mut!(pos, vm); + let a_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.isclose() requires at least 2 arguments"))?; + defer_drop!(a_val, vm); + let b_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.isclose() requires at least 2 arguments"))?; + defer_drop!(b_val, vm); + let rtol = pos + .next() + .map(|v| { + let result = to_f64(&v, vm); + v.drop_with_heap(vm); + result + }) + .transpose()? + .unwrap_or(1e-5); + let atol = pos + .next() + .map(|v| { + let result = to_f64(&v, vm); + v.drop_with_heap(vm); + result + }) + .transpose()? + .unwrap_or(1e-8); + for extra in pos { + extra.drop_with_heap(vm); + } + let a = ndarray_or_scalar_from_value(a_val, "numpy.isclose", vm)?; + let b = ndarray_or_scalar_from_value(b_val, "numpy.isclose", vm)?; + let (left, right, shape) = broadcast_pair_data( + a.data(), + a.shape(), + b.data(), + b.shape(), + "numpy.isclose", + vm.heap.tracker(), + )?; + if shape.is_empty() { + Ok(Value::Bool((left[0] - right[0]).abs() <= atol + rtol * right[0].abs())) + } else { + let data: Vec = left + .iter() + .zip(right.iter()) + .map(|(&x, &y)| { + if (x - y).abs() <= atol + rtol * y.abs() { + 1.0 + } else { + 0.0 + } + }) + .collect(); + let result = NdArray::new(data, shape, NdArrayDtype::Bool); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) + } +} + +/// `numpy.isin(element, test_elements)` — test membership. +fn call_isin(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (elem_val, test_val) = args.get_two_args("numpy.isin", vm.heap)?; + defer_drop!(elem_val, vm); + defer_drop!(test_val, vm); + let elems = ndarray_from_value(elem_val, "numpy.isin", vm)?; + let tests = ndarray_from_value(test_val, "numpy.isin", vm)?; + let test_set: Vec = tests.data().to_vec(); + let data: Vec = elems + .data() + .iter() + .map(|&v| if test_set.contains(&v) { 1.0 } else { 0.0 }) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Bool); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +// --- Phase 6: Manipulation and shape --- + +/// `numpy.flip(a)` — reverse array elements. +fn call_flip(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.flip", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.flip", vm)?; + let mut data = arr.data().to_vec(); + data.reverse(); + let result = NdArray::new(data, arr.shape().to_vec(), arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.fliplr(a)` — flip left-right. For 2D: reverse each row. +fn call_fliplr(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.fliplr", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.fliplr", vm)?; + if arr.shape().len() < 2 { + return Err(SimpleException::new_msg(ExcType::ValueError, "Input must be >= 2-d.").into()); + } + let cols = arr.shape()[1]; + let mut data = arr.data().to_vec(); + for row in data.chunks_mut(cols) { + row.reverse(); + } + let result = NdArray::new(data, arr.shape().to_vec(), arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.flipud(a)` — flip up-down. For 2D: reverse row order. +fn call_flipud(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.flipud", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.flipud", vm)?; + if arr.shape().len() < 2 { + // For 1D, flipud is just reverse + let mut data = arr.data().to_vec(); + data.reverse(); + let result = NdArray::new(data, arr.shape().to_vec(), arr.dtype()); + return Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + let cols = arr.shape()[1]; + let mut rows: Vec<&[f64]> = arr.data().chunks(cols).collect(); + rows.reverse(); + let data: Vec = rows.into_iter().flatten().copied().collect(); + let result = NdArray::new(data, arr.shape().to_vec(), arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.roll(a, shift)` — roll elements by `shift` positions. +#[expect( + clippy::cast_possible_wrap, + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "array indices are small enough that these casts are safe" +)] +fn call_roll(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, shift_val) = args.get_two_args("numpy.roll", vm.heap)?; + defer_drop!(arr_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.roll", vm)?; + let Value::Int(shift) = &shift_val else { + shift_val.drop_with_heap(vm); + return Err(ExcType::type_error("shift must be integer")); + }; + let shift = *shift; + shift_val.drop_with_heap(vm); + let data = arr.data(); + let n = data.len(); + if n == 0 { + let result = NdArray::new(Vec::new(), vec![0], arr.dtype()); + return Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + let shift = ((shift % n as i64) + n as i64) as usize % n; + let mut new_data = Vec::with_capacity(n); + new_data.extend_from_slice(&data[n - shift..]); + new_data.extend_from_slice(&data[..n - shift]); + let result = NdArray::new(new_data, arr.shape().to_vec(), arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.expand_dims(a, axis)` — insert a new axis at `axis`. +#[expect( + clippy::cast_possible_wrap, + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "array indices are small enough that these casts are safe" +)] +fn call_expand_dims(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, axis_val) = args.get_two_args("numpy.expand_dims", vm.heap)?; + defer_drop!(arr_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.expand_dims", vm)?; + let Value::Int(axis) = &axis_val else { + axis_val.drop_with_heap(vm); + return Err(ExcType::type_error("axis must be integer")); + }; + let axis = *axis; + axis_val.drop_with_heap(vm); + let mut shape = arr.shape().to_vec(); + let ndim = shape.len() as i64 + 1; + let axis = if axis < 0 { + (axis + ndim).max(0) as usize + } else { + axis.min(ndim - 1) as usize + }; + shape.insert(axis, 1); + let result = NdArray::new(arr.data().to_vec(), shape, arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.squeeze(a)` — remove length-1 axes. +fn call_squeeze(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.squeeze", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.squeeze", vm)?; + let shape: Vec = arr.shape().iter().copied().filter(|&s| s != 1).collect(); + let shape = if shape.is_empty() { vec![1] } else { shape }; + let result = NdArray::new(arr.data().to_vec(), shape, arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.ravel(a)` — module-level flatten. +fn call_ravel_mod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.ravel", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.ravel", vm)?; + let len = arr.data().len(); + let result = NdArray::new(arr.data().to_vec(), vec![len], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.delete(arr, indices)` — delete elements at given indices. +#[expect( + clippy::cast_possible_wrap, + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "array indices are small enough that these casts are safe" +)] +fn call_delete(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, idx_val) = args.get_two_args("numpy.delete", vm.heap)?; + defer_drop!(arr_val, vm); + defer_drop!(idx_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.delete", vm)?; + let n = arr.data().len(); + // Build set of indices to delete + let del_indices: Vec = if let Value::Int(i) = idx_val { + let i = if *i < 0 { (*i + n as i64) as usize } else { *i as usize }; + vec![i] + } else { + let idx_arr = ndarray_from_value(idx_val, "numpy.delete", vm)?; + idx_arr + .data() + .iter() + .map(|&v| if v < 0.0 { (v + n as f64) as usize } else { v as usize }) + .collect() + }; + let data: Vec = arr + .data() + .iter() + .enumerate() + .filter(|(i, _)| !del_indices.contains(i)) + .map(|(_, &v)| v) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.insert(arr, index, values)` — insert values before the given index. +#[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "array indices are small enough that these casts are safe" +)] +fn call_insert(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.insert", vm.heap)?; + defer_drop_mut!(pos, vm); + let arr_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.insert() requires 3 arguments"))?; + defer_drop!(arr_val, vm); + let idx_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.insert() requires 3 arguments"))?; + defer_drop!(idx_val, vm); + let vals_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.insert() requires 3 arguments"))?; + defer_drop!(vals_val, vm); + for extra in pos { + extra.drop_with_heap(vm); + } + let arr = ndarray_from_value(arr_val, "numpy.insert", vm)?; + let Value::Int(idx) = idx_val else { + return Err(ExcType::type_error("index must be integer")); + }; + let idx = *idx as usize; + // vals_val can be a scalar or an array + let (vals_data, vals_dtype) = match vals_val { + Value::Float(f) => (vec![*f], NdArrayDtype::Float64), + Value::Int(n) => (vec![*n as f64], NdArrayDtype::Int64), + _ => { + let v = ndarray_from_value(vals_val, "numpy.insert", vm)?; + (v.data().to_vec(), v.dtype()) + } + }; + let mut data = arr.data().to_vec(); + let insert_at = idx.min(data.len()); + for (i, &v) in vals_data.iter().enumerate() { + data.insert(insert_at + i, v); + } + let len = data.len(); + let dtype = promote_dtype(arr.dtype(), vals_dtype); + let result = NdArray::new(data, vec![len], dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.diag(v)` — for 1D input: create diagonal matrix. For 2D input: extract diagonal. +fn call_diag(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.diag", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.diag", vm)?; + if arr.shape().len() == 1 { + // Create diagonal matrix + let n = arr.data().len(); + check_array_alloc_size(n * n, vm.heap.tracker())?; + let mut data = vec![0.0; n * n]; + for (i, &v) in arr.data().iter().enumerate() { + data[i * n + i] = v; + } + let result = NdArray::new(data, vec![n, n], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) + } else { + // Extract diagonal from 2D + let rows = arr.shape()[0]; + let cols = arr.shape()[1]; + let n = rows.min(cols); + let data: Vec = (0..n).map(|i| arr.data()[i * cols + i]).collect(); + let result = NdArray::new(data, vec![n], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) + } +} + +/// `numpy.diagflat(v, k=0)` — create a diagonal matrix from flattened input. +fn call_diagflat(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arg, k_val) = args.get_one_two_args("numpy.diagflat", vm.heap)?; + defer_drop!(arg, vm); + let k = if let Some(k_val) = k_val { + defer_drop!(k_val, vm); + value_to_i64_arg(k_val, "numpy.diagflat", "k")? + } else { + 0 + }; + let arr = ndarray_from_value(arg, "numpy.diagflat", vm)?; + let offset = diagflat_offset(k)?; + let size = arr + .data() + .len() + .checked_add(offset) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.diagflat() dimensions overflow"))?; + let total = size + .checked_mul(size) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.diagflat() dimensions overflow"))?; + check_array_alloc_size(total, vm.heap.tracker())?; + + let mut data = vec![0.0; total]; + for (index, value) in arr.data().iter().enumerate() { + let (row, col) = diagflat_position(index, offset, k); + data[row * size + col] = *value; + } + let result = NdArray::new(data, vec![size, size], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Converts a diagonal offset into a positive matrix-size expansion. +fn diagflat_offset(k: i64) -> RunResult { + usize::try_from(k.unsigned_abs()) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, "numpy.diagflat() k is too large").into()) +} + +/// Computes a row/column pair for one flattened input item. +fn diagflat_position(index: usize, offset: usize, k: i64) -> (usize, usize) { + if k >= 0 { + (index, index + offset) + } else { + (index + offset, index) + } +} + +/// `numpy.diagonal(a)` — extract diagonal of 2D array. +fn call_diagonal(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + // For our purposes, same as diag on 2D + call_diag(vm, args) +} + +/// `numpy.trace(a)` — sum of diagonal elements. +fn call_trace(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.trace", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.trace", vm)?; + if arr.shape().len() < 2 { + return Err(SimpleException::new_msg(ExcType::ValueError, "trace requires 2-d array").into()); + } + let cols = arr.shape()[1]; + let n = arr.shape()[0].min(cols); + let sum: f64 = (0..n).map(|i| arr.data()[i * cols + i]).sum(); + Ok(Value::Float(sum)) +} + +/// `numpy.flatnonzero(a)` — indices of non-zero elements in flattened array. +fn call_flatnonzero(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.flatnonzero", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.flatnonzero", vm)?; + let data: Vec = arr + .data() + .iter() + .enumerate() + .filter(|&(_, v)| *v != 0.0) + .map(|(i, _)| i as f64) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.asarray(a)` — convert a list or ndarray to an ndarray. +/// +/// Monty does not currently model NumPy views, so ndarray input is copied rather +/// than returned as the identical object. The observable numeric contents, shape, +/// and dtype are preserved for the safe ndarray subset. +fn call_asarray(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.asarray", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.asarray", vm)?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// Compatibility conversion for layout/order helpers that are no-ops in Monty's ndarray model. +fn call_asarray_compat(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.asarray", vm.heap)?; + defer_drop_mut!(pos, vm); + let arg = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.asarray", 1, 0))?; + defer_drop!(arg, vm); + for extra in pos.by_ref() { + extra.drop_with_heap(vm); + } + let arr = ndarray_from_value(arg, "numpy.asarray", vm)?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) +} + +/// `numpy.ix_(*args)` — construct open mesh index arrays from 1-D sequences. +fn call_ix(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.ix_", vm.heap)?; + defer_drop_mut!(pos, vm); + + let mut arrays = Vec::new(); + for arg in pos.by_ref() { + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.ix_", vm)?; + if arr.shape().len() != 1 { + return Err(SimpleException::new_msg(ExcType::ValueError, "Cross index must be 1 dimensional").into()); + } + arrays.push(arr); + } + + let total_len = arrays + .iter() + .try_fold(0usize, |acc, arr| acc.checked_add(arr.data().len())) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.ix_() dimensions overflow"))?; + check_array_alloc_size(total_len, vm.heap.tracker())?; + + let ndim = arrays.len(); + let mut values: SmallVec<[Value; 3]> = SmallVec::new(); + for (axis, arr) in arrays.iter().enumerate() { + let shape = ix_output_shape(axis, arr.data().len(), ndim); + let result = NdArray::new(arr.data().to_vec(), shape, arr.dtype()); + values.push(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + allocate_tuple(values, vm.heap).map_err(Into::into) +} + +/// Computes the broadcastable shape for one `ix_` output array. +fn ix_output_shape(axis: usize, len: usize, ndim: usize) -> Vec { + let mut shape = vec![1; ndim]; + if let Some(dim) = shape.get_mut(axis) { + *dim = len; + } + shape +} + +/// `numpy.mask_indices(n, mask_func, k=0)` — indices selected by a triangular mask. +fn call_mask_indices(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.mask_indices", vm.heap)?; + defer_drop_mut!(pos, vm); + + let n_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.mask_indices", 2, 0))?; + defer_drop!(n_val, vm); + let n = value_to_nonnegative_usize(n_val, "numpy.mask_indices", "n")?; + + let mask_func = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.mask_indices", 2, 1))?; + defer_drop!(mask_func, vm); + let kind = triangle_kind_from_mask_func(mask_func)?; + + let k = if let Some(k_val) = pos.next() { + defer_drop!(k_val, vm); + value_to_i64_arg(k_val, "numpy.mask_indices", "k")? + } else { + 0 + }; + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.mask_indices", 3, 4)); + } + + triangle_indices_tuple(n, k, n, kind, vm) +} + +/// Extracts the supported triangular mask function for `mask_indices()`. +fn triangle_kind_from_mask_func(value: &Value) -> RunResult { + match value { + Value::ModuleFunction(ModuleFunctions::Numpy(NumpyFunctions::Triu)) => Ok(TriangleKind::Upper), + Value::ModuleFunction(ModuleFunctions::Numpy(NumpyFunctions::Tril)) => Ok(TriangleKind::Lower), + _ => Err(ExcType::type_error( + "numpy.mask_indices() only supports numpy.triu or numpy.tril mask functions", + )), + } +} + +/// `numpy.isfortran(a)` — Monty arrays are currently stored only in row-major order. +fn call_isfortran(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.isfortran", vm.heap)?; + arg.drop_with_heap(vm); + Ok(Value::Bool(false)) +} + +/// `numpy.shares_memory()` / `numpy.may_share_memory()` for Monty's copy-based ndarray model. +fn call_memory_overlap(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues, name: &str) -> RunResult { + let (a, b) = args.get_two_args(name, vm.heap)?; + defer_drop!(a, vm); + defer_drop!(b, vm); + Ok(Value::Bool(same_ndarray_ref(a, b, vm))) +} + +/// Checks whether both values are the exact same ndarray heap object. +fn same_ndarray_ref(a: &Value, b: &Value, vm: &VM<'_, impl ResourceTracker>) -> bool { + match (a, b) { + (Value::Ref(a_id), Value::Ref(b_id)) if a_id == b_id => matches!(vm.heap.get(*a_id), HeapData::NdArray(_)), + _ => false, + } +} + +/// `numpy.asarray_chkfinite(a)` — convert to array and reject NaN or infinity. +fn call_asarray_chkfinite(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.asarray_chkfinite", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.asarray_chkfinite", vm)?; + if arr.data().iter().all(|value| value.is_finite()) { + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(arr))?)) + } else { + Err(SimpleException::new_msg(ExcType::ValueError, "array must not contain infs or NaNs").into()) + } +} + +/// `numpy.column_stack(arrays)` — stack 1D arrays as columns into 2D. +fn call_column_stack(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let list_val = args.get_one_arg("numpy.column_stack", vm.heap)?; + defer_drop!(list_val, vm); + let list_items = match list_val { + Value::Ref(id) => match vm.heap.get(*id) { + HeapData::List(list) => list + .as_slice() + .iter() + .map(|v| v.clone_with_heap(vm)) + .collect::>(), + _ => return Err(ExcType::type_error("numpy.column_stack() requires a list")), + }, + _ => return Err(ExcType::type_error("numpy.column_stack() requires a list")), + }; + if list_items.is_empty() { + return Err(SimpleException::new_msg(ExcType::ValueError, "need at least one array to stack").into()); + } + // Extract all arrays + let mut arrays: Vec = Vec::new(); + for item in &list_items { + arrays.push(ndarray_from_value(item, "numpy.column_stack", vm)?); + } + for item in list_items { + item.drop_with_heap(vm); + } + let rows = arrays[0].data().len(); + let cols = arrays.len(); + check_array_alloc_size(rows * cols, vm.heap.tracker())?; + let mut data = vec![0.0; rows * cols]; + for (c, arr) in arrays.iter().enumerate() { + for (r, &v) in arr.data().iter().enumerate() { + data[r * cols + c] = v; + } + } + let dtype = arrays + .iter() + .fold(NdArrayDtype::Int64, |d, a| promote_dtype(d, a.dtype())); + let result = NdArray::new(data, vec![rows, cols], dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.hsplit(a, n)` — split horizontally (for 1D: split into n parts). +fn call_hsplit(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + // For 1D, hsplit is same as split + call_split(vm, args) +} + +/// `numpy.vsplit(a, n)` — split vertically. +fn call_vsplit(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + call_split(vm, args) +} + +/// `numpy.dsplit(a, indices_or_sections)` — split arrays along depth axis. +fn call_dsplit(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, idx_val) = args.get_two_args("numpy.dsplit", vm.heap)?; + defer_drop!(arr_val, vm); + defer_drop!(idx_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.dsplit", vm)?; + if arr.ndim() < 3 { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "dsplit only works on arrays of 3 or more dimensions", + ) + .into()); + } + + let split_indices = split_indices_for_axis(idx_val, arr.shape()[2], "numpy.dsplit", vm)?; + split_ndarray_along_axis_to_list(&arr, 2, &split_indices, vm) +} + +/// Extracts split points for a fixed array axis from an integer or index sequence. +fn split_indices_for_axis( + value: &Value, + axis_len: usize, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult> { + match value { + Value::Int(sections) => equal_split_indices(*sections, axis_len), + Value::Ref(id) => match vm.heap.get(*id) { + HeapData::List(list) => list + .as_slice() + .iter() + .map(|value| split_index_value_to_usize(value, axis_len, name)) + .collect(), + HeapData::Tuple(tuple) => tuple + .as_slice() + .iter() + .map(|value| split_index_value_to_usize(value, axis_len, name)) + .collect(), + HeapData::NdArray(indices) => indices + .data() + .iter() + .map(|&value| split_index_f64_to_usize(value, axis_len, name)) + .collect(), + _ => Err(ExcType::type_error(format!("{name}() second arg must be int or list"))), + }, + _ => Err(ExcType::type_error(format!("{name}() second arg must be int or list"))), + } +} + +/// Computes split points for equal-sized axis sections. +fn equal_split_indices(sections: i64, axis_len: usize) -> RunResult> { + if sections <= 0 { + return Err(SimpleException::new_msg(ExcType::ValueError, "number sections must be larger than 0").into()); + } + let sections = usize::try_from(sections) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, "number sections is too large"))?; + if !axis_len.is_multiple_of(sections) { + return Err( + SimpleException::new_msg(ExcType::ValueError, "array split does not result in an equal division").into(), + ); + } + let chunk_size = axis_len / sections; + Ok((1..sections).map(|index| index * chunk_size).collect()) +} + +/// Converts one Python split index to a clamped axis offset. +fn split_index_value_to_usize(value: &Value, axis_len: usize, name: &str) -> RunResult { + match value { + Value::Int(index) => split_index_i64_to_usize(*index, axis_len, name), + _ => Err(ExcType::type_error("split indices must be integers")), + } +} + +/// Converts one ndarray-backed split index to a clamped axis offset. +#[expect( + clippy::cast_possible_truncation, + reason = "integer ndarray values are stored as f64 in Monty's current ndarray model" +)] +fn split_index_f64_to_usize(value: f64, axis_len: usize, name: &str) -> RunResult { + if value.is_finite() { + split_index_i64_to_usize(value as i64, axis_len, name) + } else { + Err(SimpleException::new_msg(ExcType::ValueError, format!("{name}() split index must be finite")).into()) + } +} + +/// Converts one signed split index to a NumPy-style clamped axis offset. +fn split_index_i64_to_usize(index: i64, axis_len: usize, name: &str) -> RunResult { + let axis_len = i64::try_from(axis_len) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, format!("{name}() axis is too large")))?; + let resolved = if index < 0 { index + axis_len } else { index }; + usize::try_from(resolved.clamp(0, axis_len)) + .map_err(|_| SimpleException::new_msg(ExcType::ValueError, format!("{name}() split index is too large")).into()) +} + +/// Builds a list of ndarray chunks for a set of split points along one axis. +fn split_ndarray_along_axis_to_list( + arr: &NdArray, + axis: usize, + split_indices: &[usize], + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let mut parts = Vec::new(); + let mut previous = 0; + for &index in split_indices { + let end = index.min(arr.shape()[axis]); + parts.push(axis_chunk_value(arr, axis, previous, end, vm)?); + previous = end; + } + parts.push(axis_chunk_value(arr, axis, previous, arr.shape()[axis], vm)?); + let list = List::new(parts); + Ok(Value::Ref(vm.heap.allocate(HeapData::List(list))?)) +} + +/// Allocates one ndarray chunk for a half-open range along a fixed axis. +fn axis_chunk_value( + arr: &NdArray, + axis: usize, + start_axis: usize, + end_axis: usize, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let data = slice_ndarray_along_axis(arr, axis, start_axis, end_axis); + check_array_alloc_size(data.len(), vm.heap.tracker())?; + let mut shape = arr.shape().to_vec(); + shape[axis] = end_axis.saturating_sub(start_axis); + let chunk = NdArray::new(data, shape, arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(chunk))?)) +} + +/// `numpy.array_split(a, n)` — split into possibly unequal parts. +fn call_array_split(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, n_val) = args.get_two_args("numpy.array_split", vm.heap)?; + defer_drop!(arr_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.array_split", vm)?; + let Value::Int(n) = &n_val else { + n_val.drop_with_heap(vm); + return Err(ExcType::type_error("sections must be integer")); + }; + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "sections from user" + )] + let n = *n as usize; + n_val.drop_with_heap(vm); + if n == 0 { + return Err(SimpleException::new_msg(ExcType::ValueError, "number sections must be larger than 0").into()); + } + let data = arr.data(); + let dtype = arr.dtype(); + let total = data.len(); + let base_size = total / n; + let remainder = total % n; + let mut parts = Vec::new(); + let mut offset = 0; + for i in 0..n { + let size = base_size + usize::from(i < remainder); + let chunk = data[offset..offset + size].to_vec(); + let len = chunk.len(); + parts.push(Value::Ref(vm.heap.allocate(HeapData::NdArray(NdArray::new( + chunk, + vec![len], + dtype, + )))?)); + offset += size; + } + let list = List::new(parts); + Ok(Value::Ref(vm.heap.allocate(HeapData::List(list))?)) +} + +/// `numpy.full_like(a, fill_value)` — array of same shape filled with value. +fn call_full_like(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arr_val, fill_val) = args.get_two_args("numpy.full_like", vm.heap)?; + defer_drop!(arr_val, vm); + let arr = ndarray_from_value(arr_val, "numpy.full_like", vm)?; + let (fill, dtype) = match &fill_val { + Value::Int(n) => (*n as f64, NdArrayDtype::Int64), + Value::Float(f) => (*f, NdArrayDtype::Float64), + Value::Bool(b) => (if *b { 1.0 } else { 0.0 }, NdArrayDtype::Bool), + _ => { + fill_val.drop_with_heap(vm); + return Err(ExcType::type_error("fill_value must be a number")); + } + }; + fill_val.drop_with_heap(vm); + let size = arr.data().len(); + let result = NdArray::new(vec![fill; size], arr.shape().to_vec(), dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +// --- Phase 7: Sorting, searching, set ops --- + +/// `numpy.argsort(a)` — module-level argsort. +fn call_argsort_mod(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.argsort", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.argsort", vm)?; + let result_data = argsort_index_data(arr.data()); + let len = result_data.len(); + let result = NdArray::new(result_data, vec![len], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.argpartition(a, kth)` — deterministic argsort-compatible subset. +fn call_argpartition(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arg, kth) = args.get_two_args("numpy.argpartition", vm.heap)?; + defer_drop!(arg, vm); + defer_drop!(kth, vm); + let arr = ndarray_from_value(arg, "numpy.argpartition", vm)?; + validate_partition_kth(kth, arr.data().len(), "numpy.argpartition")?; + let result_data = argsort_index_data(arr.data()); + let len = result_data.len(); + let result = NdArray::new(result_data, vec![len], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.partition(a, kth)` — deterministic sorted-output subset for 1-D arrays. +fn call_partition(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arg, kth) = args.get_two_args("numpy.partition", vm.heap)?; + defer_drop!(arg, vm); + defer_drop!(kth, vm); + let arr = ndarray_from_value(arg, "numpy.partition", vm)?; + validate_partition_kth(kth, arr.data().len(), "numpy.partition")?; + let mut data = arr.data().to_vec(); + data.sort_by(nan_last_cmp); + let result = NdArray::new(data, arr.shape().to_vec(), arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.lexsort(keys)` — indirect stable sort using the last key as primary. +fn call_lexsort(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let keys_val = args.get_one_arg("numpy.lexsort", vm.heap)?; + defer_drop!(keys_val, vm); + let key_values = sequence_items(keys_val, "numpy.lexsort", vm)?; + defer_drop!(key_values, vm); + let keys = key_values + .iter() + .map(|value| ndarray_from_value(value, "numpy.lexsort", vm)) + .collect::>>()?; + let Some(first) = keys.first() else { + let result = NdArray::new(Vec::new(), vec![0], NdArrayDtype::Int64); + return Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + }; + let len = first.data().len(); + for key in &keys { + if key.shape().len() != 1 || key.data().len() != len { + return Err(SimpleException::new_msg(ExcType::ValueError, "all keys need to be the same shape").into()); + } + } + check_array_alloc_size(len, vm.heap.tracker())?; + + let mut indices: Vec = (0..len).collect(); + indices.sort_by(|&lhs, &rhs| compare_lexsort_indices(&keys, lhs, rhs)); + let result_data = indices.into_iter().map(usize_to_f64).collect(); + let result = NdArray::new(result_data, vec![len], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Produces stable argsort indices encoded in Monty's integer ndarray storage. +fn argsort_index_data(data: &[f64]) -> Vec { + let mut indices: Vec = (0..data.len()).collect(); + indices.sort_by(|&a, &b| nan_last_cmp(&data[a], &data[b])); + indices.into_iter().map(usize_to_f64).collect() +} + +/// Validates the scalar `kth` accepted by the supported partition subset. +fn validate_partition_kth(kth: &Value, len: usize, name: &str) -> RunResult<()> { + let kth = value_to_i64_arg(kth, name, "kth")?; + let len_i64 = usize_to_i64(len)?; + let normalized = if kth < 0 { len_i64.saturating_add(kth) } else { kth }; + if normalized < 0 || normalized >= len_i64 { + Err(SimpleException::new_msg(ExcType::ValueError, "kth out of bounds").into()) + } else { + Ok(()) + } +} + +/// Compares two row indices across `lexsort` keys, with later keys taking priority. +fn compare_lexsort_indices(keys: &[NdArray], lhs: usize, rhs: usize) -> Ordering { + for key in keys.iter().rev() { + let ordering = nan_last_cmp(&key.data()[lhs], &key.data()[rhs]); + if ordering != Ordering::Equal { + return ordering; + } + } + lhs.cmp(&rhs) +} + +/// `numpy.cov(m)` — covariance for 1-D or row-wise 2-D real arrays. +fn call_cov(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.cov", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.cov", vm)?; + let (rows, _) = covariance_shape(&arr, "numpy.cov")?; + let data = covariance_matrix_data(&arr, "numpy.cov", vm.heap.tracker())?; + if rows == 1 { + Ok(Value::Float(data[0])) + } else { + let result = NdArray::new(data, vec![rows, rows], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) + } +} + +/// `numpy.corrcoef(x)` — correlation coefficients for 1-D or row-wise 2-D arrays. +fn call_corrcoef(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.corrcoef", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.corrcoef", vm)?; + let (rows, _) = covariance_shape(&arr, "numpy.corrcoef")?; + let cov = covariance_matrix_data(&arr, "numpy.corrcoef", vm.heap.tracker())?; + if rows == 1 { + Ok(Value::Float(1.0)) + } else { + let mut data = Vec::with_capacity(cov.len()); + for row in 0..rows { + for col in 0..rows { + let denom = (cov[row * rows + row] * cov[col * rows + col]).sqrt(); + data.push(if denom.is_nan() || denom <= 0.0 { + f64::NAN + } else { + cov[row * rows + col] / denom + }); + } + } + let result = NdArray::new(data, vec![rows, rows], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) + } +} + +/// Returns `(variables, observations)` for covariance-style helpers. +fn covariance_shape(arr: &NdArray, name: &str) -> RunResult<(usize, usize)> { + match arr.shape() { + [cols] => Ok((1, *cols)), + [rows, cols] => Ok((*rows, *cols)), + _ => Err(SimpleException::new_msg( + ExcType::ValueError, + format!("{name}() input has more than 2 dimensions"), + ) + .into()), + } +} + +/// Computes the row-wise sample covariance matrix, matching NumPy's default `bias=False`. +fn covariance_matrix_data(arr: &NdArray, name: &str, tracker: &impl ResourceTracker) -> RunResult> { + let (rows, cols) = covariance_shape(arr, name)?; + check_array_alloc_size(rows.saturating_mul(rows), tracker)?; + let means = (0..rows) + .map(|row| covariance_row_mean(arr, row, cols)) + .collect::>(); + let denom = if cols > 1 { usize_to_f64(cols - 1) } else { f64::NAN }; + let mut data = Vec::with_capacity(rows * rows); + for lhs in 0..rows { + for rhs in 0..rows { + let mut sum = 0.0; + for col in 0..cols { + let lhs_delta = covariance_value(arr, lhs, col, cols) - means[lhs]; + let rhs_delta = covariance_value(arr, rhs, col, cols) - means[rhs]; + sum += lhs_delta * rhs_delta; + } + data.push(sum / denom); + } + } + Ok(data) +} + +/// Computes one variable row mean for covariance-style helpers. +fn covariance_row_mean(arr: &NdArray, row: usize, cols: usize) -> f64 { + let sum = (0..cols).map(|col| covariance_value(arr, row, col, cols)).sum::(); + sum / usize_to_f64(cols) +} + +/// Reads a row/column value from either 1-D or row-wise 2-D covariance input. +fn covariance_value(arr: &NdArray, row: usize, col: usize, cols: usize) -> f64 { + if arr.shape().len() == 1 { + arr.data()[col] + } else { + arr.data()[row * cols + col] + } +} + +/// `numpy.searchsorted(a, v)` — find insertion points for `v` in sorted array `a`. +#[expect( + clippy::cast_possible_wrap, + reason = "array indices are small enough that these casts are safe" +)] +fn call_searchsorted(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, v_val) = args.get_two_args("numpy.searchsorted", vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(v_val, vm); + let a = ndarray_from_value(a_val, "numpy.searchsorted", vm)?; + let sorted = a.data(); + // v can be scalar or array + match v_val { + Value::Int(n) => { + let v = *n as f64; + let idx = sorted.partition_point(|&x| x < v); + Ok(Value::Int(idx as i64)) + } + Value::Float(f) => { + let idx = sorted.partition_point(|&x| x < *f); + Ok(Value::Int(idx as i64)) + } + _ => { + let v_arr = ndarray_from_value(v_val, "numpy.searchsorted", vm)?; + let data: Vec = v_arr + .data() + .iter() + .map(|&v| sorted.partition_point(|&x| x < v) as f64) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) + } + } +} + +/// `numpy.extract(condition, arr)` — extract elements where condition is True. +fn call_extract(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (cond_val, arr_val) = args.get_two_args("numpy.extract", vm.heap)?; + defer_drop!(cond_val, vm); + defer_drop!(arr_val, vm); + let cond = ndarray_from_value(cond_val, "numpy.extract", vm)?; + let arr = ndarray_from_value(arr_val, "numpy.extract", vm)?; + let data: Vec = cond + .data() + .iter() + .zip(arr.data().iter()) + .filter(|(c, _)| **c != 0.0) + .map(|(_, v)| *v) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.trim_zeros(filt, trim='fb')` — trim leading and/or trailing zeros from a 1-D input. +fn call_trim_zeros(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (arg, trim) = args.get_one_two_args("numpy.trim_zeros", vm.heap)?; + defer_drop!(arg, vm); + let trim = if let Some(trim) = trim { + defer_drop!(trim, vm); + string_arg(trim, "numpy.trim_zeros", vm)? + } else { + "fb".to_owned() + }; + let arr = ndarray_from_value(arg, "numpy.trim_zeros", vm)?; + let trim_front = trim.contains('f') || trim.contains('F'); + let trim_back = trim.contains('b') || trim.contains('B'); + let mut start = 0usize; + let mut end = arr.data().len(); + if trim_front { + start = arr.data().iter().position(|value| *value != 0.0).unwrap_or(end); + } + if trim_back { + end = arr + .data() + .iter() + .rposition(|value| *value != 0.0) + .map_or(start, |index| index + 1); + } + if end < start { + end = start; + } + let data = arr.data()[start..end].to_vec(); + let len = data.len(); + let result = NdArray::new(data, vec![len], arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.unwrap(p, discont=None, axis=-1)` over Monty's real ndarray values. +fn call_unwrap(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.unwrap", vm.heap)?; + defer_drop_mut!(pos, vm); + + let arg = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.unwrap", 1, 0))?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.unwrap", vm)?; + if arr.ndim() == 0 { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "diff requires input that is at least one dimensional", + ) + .into()); + } + + let discont = if let Some(discont_val) = pos.next() { + defer_drop!(discont_val, vm); + if matches!(discont_val, Value::None) { + None + } else { + Some(to_f64(discont_val, vm)?) + } + } else { + None + }; + if let Some(axis_val) = pos.next() { + defer_drop!(axis_val, vm); + if !matches!(axis_val, Value::None) { + let axis = value_to_i64_arg(axis_val, "numpy.unwrap", "axis")?; + let axis = normalize_axis(axis, arr.ndim(), "numpy.unwrap")?; + if axis != arr.ndim() - 1 { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "numpy.unwrap() only supports the last axis", + ) + .into()); + } + } + } + if let Some(extra) = pos.next() { + extra.drop_with_heap(vm); + return Err(ExcType::type_error_at_most("numpy.unwrap", 3, 4)); + } + + check_array_alloc_size(arr.data().len(), vm.heap.tracker())?; + let data = unwrap_phase_values(arr.data(), discont); + let result = NdArray::new(data, arr.shape().to_vec(), NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Computes NumPy-style phase unwrapping with the default `2*pi` period. +fn unwrap_phase_values(values: &[f64], discont: Option) -> Vec { + let Some(&first) = values.first() else { + return Vec::new(); + }; + let period = 2.0 * PI; + let threshold = discont.unwrap_or(PI).max(PI); + let mut output = Vec::with_capacity(values.len()); + output.push(first); + let mut correction = 0.0; + for pair in values.windows(2) { + correction += unwrap_delta_correction(pair[1] - pair[0], threshold, period); + output.push(pair[1] + correction); + } + output +} + +/// Correction needed to map one phase delta into the requested discontinuity interval. +fn unwrap_delta_correction(delta: f64, threshold: f64, period: f64) -> f64 { + if delta.abs() <= threshold { + 0.0 + } else { + let half_period = period / 2.0; + let mut delta_mod = (delta + half_period).rem_euclid(period) - half_period; + if delta_mod.to_bits() == (-half_period).to_bits() && delta > 0.0 { + delta_mod = half_period; + } + delta_mod - delta + } +} + +/// Extracts a string argument from heap or interned string values. +fn string_arg(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + match value { + Value::InternString(id) => Ok(vm.interns.get_str(*id).to_owned()), + Value::Ref(id) => match vm.heap.get(*id) { + HeapData::Str(value) => Ok(value.as_str().to_owned()), + _ => Err(ExcType::type_error(format!("{name}() expected a string argument"))), + }, + _ => Err(ExcType::type_error(format!("{name}() expected a string argument"))), + } +} + +/// Set operation type. +#[derive(Clone, Copy)] +enum SetOp { + Intersect, + Union, + Diff, + Xor, +} + +/// Generic set operation on two sorted-unique arrays. +fn call_set_op(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues, op: SetOp, name: &str) -> RunResult { + let (a_val, b_val) = args.get_two_args(name, vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + let a_arr = ndarray_from_value(a_val, name, vm)?; + let b_arr = ndarray_from_value(b_val, name, vm)?; + let mut a: Vec = a_arr.data().to_vec(); + let mut b: Vec = b_arr.data().to_vec(); + a.sort_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal)); + a.dedup(); + b.sort_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal)); + b.dedup(); + let data: Vec = match op { + SetOp::Intersect => a.iter().filter(|v| b.contains(v)).copied().collect(), + SetOp::Union => { + let mut u = a.clone(); + for v in &b { + if !u.contains(v) { + u.push(*v); + } + } + u.sort_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal)); + u + } + SetOp::Diff => a.iter().filter(|v| !b.contains(v)).copied().collect(), + SetOp::Xor => { + let mut r: Vec = a.iter().filter(|v| !b.contains(v)).copied().collect(); + r.extend(b.iter().filter(|v| !a.contains(v))); + r.sort_by(|x, y| x.partial_cmp(y).unwrap_or(Ordering::Equal)); + r + } + }; + let len = data.len(); + let dtype = promote_dtype(a_arr.dtype(), b_arr.dtype()); + let result = NdArray::new(data, vec![len], dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.bincount(a)` — count occurrences of each non-negative integer value. +fn call_bincount(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.bincount", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.bincount", vm)?; + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "index from user data" + )] + let max_val = arr.data().iter().fold(0usize, |m, &v| m.max(v as usize)); + let mut counts = vec![0.0; max_val + 1]; + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "index from user data" + )] + for &v in arr.data() { + counts[v as usize] += 1.0; + } + let len = counts.len(); + let result = NdArray::new(counts, vec![len], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.digitize(x, bins)` — indices of bins to which each value belongs. +fn call_digitize(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (x_val, bins_val) = args.get_two_args("numpy.digitize", vm.heap)?; + defer_drop!(x_val, vm); + defer_drop!(bins_val, vm); + let x = ndarray_from_value(x_val, "numpy.digitize", vm)?; + let bins = ndarray_from_value(bins_val, "numpy.digitize", vm)?; + let bins_data = bins.data(); + let data: Vec = x + .data() + .iter() + .map(|&v| bins_data.partition_point(|&b| b <= v) as f64) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Int64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +// --- Phase 8: Linear algebra --- + +/// `numpy.outer(a, b)` — outer product of two vectors. +fn call_outer(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, b_val) = args.get_two_args("numpy.outer", vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + let a = ndarray_from_value(a_val, "numpy.outer", vm)?; + let b = ndarray_from_value(b_val, "numpy.outer", vm)?; + let m = a.data().len(); + let n = b.data().len(); + check_array_alloc_size(m * n, vm.heap.tracker())?; + let mut data = Vec::with_capacity(m * n); + for &ai in a.data() { + for &bj in b.data() { + data.push(ai * bj); + } + } + let dtype = promote_dtype(a.dtype(), b.dtype()); + let result = NdArray::new(data, vec![m, n], dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.cross(a, b)` — cross product of 3-element vectors. +fn call_cross(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, b_val) = args.get_two_args("numpy.cross", vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + let a = ndarray_from_value(a_val, "numpy.cross", vm)?; + let b = ndarray_from_value(b_val, "numpy.cross", vm)?; + if a.data().len() != 3 || b.data().len() != 3 { + return Err(SimpleException::new_msg(ExcType::ValueError, "cross product requires 3-element vectors").into()); + } + let (a0, a1, a2) = (a.data()[0], a.data()[1], a.data()[2]); + let (b0, b1, b2) = (b.data()[0], b.data()[1], b.data()[2]); + let data = vec![a1 * b2 - a2 * b1, a2 * b0 - a0 * b2, a0 * b1 - a1 * b0]; + let dtype = promote_dtype(a.dtype(), b.dtype()); + let result = NdArray::new(data, vec![3], dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.kron(a, b)` — Kronecker product for numeric ndarrays. +fn call_kron(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, b_val) = args.get_two_args("numpy.kron", vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(b_val, vm); + let a = ndarray_from_value(a_val, "numpy.kron", vm)?; + let b = ndarray_from_value(b_val, "numpy.kron", vm)?; + let result = kron_arrays(&a, &b, vm.heap.tracker())?; + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.tensordot(a, b, axes=2)` — generalized real-valued tensor contraction. +fn call_tensordot(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (pos, kwargs) = args.into_parts(); + defer_drop_mut!(pos, vm); + + let a_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.tensordot", 2, 0))?; + defer_drop!(a_val, vm); + let b_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.tensordot", 2, 1))?; + defer_drop!(b_val, vm); + let axes_pos = pos.next(); + defer_drop_mut!(axes_pos, vm); + if pos.len() != 0 { + return Err(ExcType::type_error_at_most("numpy.tensordot", 3, 3 + pos.len())); + } + + let a = ndarray_or_scalar_from_value(a_val, "numpy.tensordot", vm)?; + let b = ndarray_or_scalar_from_value(b_val, "numpy.tensordot", vm)?; + let mut axes = if let Some(value) = axes_pos.as_ref() { + Some(tensordot_axes_from_value(value, a.ndim(), b.ndim(), vm)?) + } else { + None + }; + + let kwargs = kwargs.into_iter(); + defer_drop_mut!(kwargs, vm); + for (key, value) in kwargs { + defer_drop!(key, vm); + defer_drop!(value, vm); + let Some(keyword_name) = key.as_either_str(vm.heap) else { + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + let key_str = keyword_name.as_str(vm.interns); + match key_str { + "axes" => { + if axes.is_some() { + return Err(ExcType::type_error_multiple_values("numpy.tensordot", "axes")); + } + axes = Some(tensordot_axes_from_value(value, a.ndim(), b.ndim(), vm)?); + } + _ => return Err(ExcType::type_error_unexpected_keyword("numpy.tensordot", key_str)), + } + } + + let axes = if let Some(axes) = axes { + axes + } else { + tensordot_axes_from_count(2, a.ndim(), b.ndim(), "numpy.tensordot")? + }; + let result = tensordot_arrays(&a, &b, &axes, vm.heap.tracker())?; + if result.shape().is_empty() { + Ok(scalar_from_f64(result.data()[0], result.dtype())) + } else { + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) + } +} + +/// Axis mapping for a `tensordot` contraction. +struct TensordotAxes { + /// Axes in the left input to contract. + left: Vec, + /// Axes in the right input to contract. + right: Vec, +} + +/// Parses NumPy's `tensordot(axes=...)` argument. +fn tensordot_axes_from_value( + value: &Value, + left_ndim: usize, + right_ndim: usize, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + match value { + Value::Int(count) => { + let count = i64_to_nonnegative_usize(*count, "numpy.tensordot", "axes")?; + tensordot_axes_from_count(count, left_ndim, right_ndim, "numpy.tensordot") + } + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(items) => tensordot_axes_from_pair(items.as_slice(), left_ndim, right_ndim, vm), + HeapData::Tuple(items) => tensordot_axes_from_pair(items.as_slice(), left_ndim, right_ndim, vm), + _ => Err(ExcType::type_error( + "numpy.tensordot() axes must be an integer or pair of axis lists", + )), + }, + _ => Err(ExcType::type_error( + "numpy.tensordot() axes must be an integer or pair of axis lists", + )), + } +} + +/// Builds the default integer-axis contraction used by `tensordot`. +fn tensordot_axes_from_count( + count: usize, + left_ndim: usize, + right_ndim: usize, + name: &str, +) -> RunResult { + if count > left_ndim || count > right_ndim { + Err(SimpleException::new_msg(ExcType::IndexError, "tuple index out of range").into()) + } else { + Ok(TensordotAxes { + left: (left_ndim - count..left_ndim).collect(), + right: (0..count).collect(), + }) + } + .and_then(|axes| validate_tensordot_axes(axes, left_ndim, right_ndim, name)) +} + +/// Parses the two axis specifications accepted by `tensordot`. +fn tensordot_axes_from_pair( + items: &[Value], + left_ndim: usize, + right_ndim: usize, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult { + if items.len() != 2 { + return Err(ExcType::type_error( + "numpy.tensordot() axes pair must contain two entries", + )); + } + let axes = TensordotAxes { + left: tensordot_axis_vector_from_value(&items[0], left_ndim, "numpy.tensordot", vm)?, + right: tensordot_axis_vector_from_value(&items[1], right_ndim, "numpy.tensordot", vm)?, + }; + validate_tensordot_axes(axes, left_ndim, right_ndim, "numpy.tensordot") +} + +/// Parses one side of a `tensordot` axis pair. +fn tensordot_axis_vector_from_value( + value: &Value, + ndim: usize, + name: &str, + vm: &VM<'_, impl ResourceTracker>, +) -> RunResult> { + match value { + Value::Int(axis) => Ok(vec![normalize_axis(*axis, ndim, name)?]), + Value::Ref(heap_id) => match vm.heap.get(*heap_id) { + HeapData::List(items) => axis_sequence_from_items(items.as_slice(), ndim, name, "axes"), + HeapData::Tuple(items) => axis_sequence_from_items(items.as_slice(), ndim, name, "axes"), + _ => Err(ExcType::type_error( + "numpy.tensordot() axes entries must be integers or integer sequences", + )), + }, + _ => Err(ExcType::type_error( + "numpy.tensordot() axes entries must be integers or integer sequences", + )), + } +} + +/// Validates `tensordot` axis arity, uniqueness, and bounds. +fn validate_tensordot_axes( + axes: TensordotAxes, + left_ndim: usize, + right_ndim: usize, + name: &str, +) -> RunResult { + if axes.left.len() != axes.right.len() { + return Err(SimpleException::new_msg(ExcType::ValueError, "shape-mismatch for sum").into()); + } + ensure_unique_axes(&axes.left, name)?; + ensure_unique_axes(&axes.right, name)?; + if axes.left.iter().any(|&axis| axis >= left_ndim) || axes.right.iter().any(|&axis| axis >= right_ndim) { + Err(SimpleException::new_msg(ExcType::IndexError, "tuple index out of range").into()) + } else { + Ok(axes) + } +} + +/// Computes the real-valued `tensordot` contraction using row-major ndarray storage. +fn tensordot_arrays( + left: &NdArray, + right: &NdArray, + axes: &TensordotAxes, + tracker: &impl ResourceTracker, +) -> RunResult { + for (&left_axis, &right_axis) in axes.left.iter().zip(axes.right.iter()) { + if left.shape()[left_axis] != right.shape()[right_axis] { + return Err(SimpleException::new_msg(ExcType::ValueError, "shape-mismatch for sum").into()); + } + } + + let left_outer_axes = complement_axes(left.ndim(), &axes.left); + let right_outer_axes = complement_axes(right.ndim(), &axes.right); + let contract_shape = axes.left.iter().map(|&axis| left.shape()[axis]).collect::>(); + let output_shape = left_outer_axes + .iter() + .map(|&axis| left.shape()[axis]) + .chain(right_outer_axes.iter().map(|&axis| right.shape()[axis])) + .collect::>(); + + let output_len = checked_shape_product(&output_shape, "numpy.tensordot")?; + let contract_len = checked_shape_product(&contract_shape, "numpy.tensordot")?; + check_array_alloc_size(output_len, tracker)?; + + let mut data = Vec::with_capacity(output_len); + for output_flat in 0..output_len { + let output_coords = flat_index_to_coords(output_flat, &output_shape); + let (left_outer_coords, right_outer_coords) = output_coords.split_at(left_outer_axes.len()); + let mut total = 0.0; + for contract_flat in 0..contract_len { + let contract_coords = flat_index_to_coords(contract_flat, &contract_shape); + let left_index = tensordot_input_index( + left.shape(), + &left_outer_axes, + left_outer_coords, + &axes.left, + &contract_coords, + ); + let right_index = tensordot_input_index( + right.shape(), + &right_outer_axes, + right_outer_coords, + &axes.right, + &contract_coords, + ); + total += left.data()[left_index] * right.data()[right_index]; + } + data.push(total); + } + + Ok(NdArray::new( + data, + output_shape, + promote_dtype(left.dtype(), right.dtype()), + )) +} + +/// Returns all axes not selected for contraction, preserving axis order. +fn complement_axes(ndim: usize, selected: &[usize]) -> Vec { + (0..ndim).filter(|axis| !selected.contains(axis)).collect() +} + +/// Builds one row-major flat index from outer and contracted coordinate components. +fn tensordot_input_index( + shape: &[usize], + outer_axes: &[usize], + outer_coords: &[usize], + contract_axes: &[usize], + contract_coords: &[usize], +) -> usize { + let mut coords = vec![0; shape.len()]; + for (&axis, &coord) in outer_axes.iter().zip(outer_coords.iter()) { + coords[axis] = coord; + } + for (&axis, &coord) in contract_axes.iter().zip(contract_coords.iter()) { + coords[axis] = coord; + } + coords_to_flat_index(&coords, shape) +} + +/// `numpy.einsum(subscripts, *operands)` — explicit-subscript real-valued contraction. +fn call_einsum(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (subscripts, operands) = parse_einsum_operands(args, "numpy.einsum", vm)?; + let spec = parse_einsum_spec(&subscripts, &operands, "numpy.einsum")?; + let result = einsum_arrays(&operands, &spec, vm.heap.tracker())?; + if result.shape().is_empty() { + Ok(scalar_from_f64(result.data()[0], result.dtype())) + } else { + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) + } +} + +/// `numpy.einsum_path(subscripts, *operands)` — simple compatible path result. +fn call_einsum_path(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (subscripts, operands) = parse_einsum_operands(args, "numpy.einsum_path", vm)?; + let spec = parse_einsum_spec(&subscripts, &operands, "numpy.einsum_path")?; + let path = simple_einsum_path(operands.len(), vm)?; + let details = format!( + " Complete contraction: {subscripts}\n Monty contraction path: simple left-to-right\n Result shape: {:?}", + spec.output_shape() + ); + let details = allocate_string(details, vm.heap)?; + allocate_tuple(SmallVec::from_vec(vec![path, details]), vm.heap).map_err(Into::into) +} + +/// Parsed representation of the supported `einsum` subscript subset. +struct EinsumSpec { + /// Label sequence for each input operand. + inputs: Vec>, + /// Output label sequence. + output: Vec, + /// Dimension associated with each label. + label_dims: BTreeMap, + /// Labels reduced by summation. + contracted: Vec, +} + +impl EinsumSpec { + /// Returns the output ndarray shape implied by the output labels. + fn output_shape(&self) -> Vec { + self.output.iter().map(|label| self.label_dims[label]).collect() + } + + /// Returns the shape iterated by the contraction labels. + fn contracted_shape(&self) -> Vec { + self.contracted.iter().map(|label| self.label_dims[label]).collect() + } +} + +/// Parses the common `einsum` call form and accepts but ignores `optimize`. +fn parse_einsum_operands( + args: ArgValues, + name: &'static str, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult<(String, Vec)> { + let (pos, kwargs) = args.into_parts(); + defer_drop_mut!(pos, vm); + let subscripts_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 1, 0))?; + defer_drop!(subscripts_val, vm); + let subscripts = string_from_value(subscripts_val, name, vm)?; + + let mut operands = Vec::new(); + for value in pos { + defer_drop!(value, vm); + operands.push(ndarray_or_scalar_from_value(value, name, vm)?); + } + + let kwargs = kwargs.into_iter(); + defer_drop_mut!(kwargs, vm); + for (key, value) in kwargs { + defer_drop!(key, vm); + defer_drop!(value, vm); + let Some(keyword_name) = key.as_either_str(vm.heap) else { + return Err(ExcType::type_error_kwargs_nonstring_key()); + }; + let key_str = keyword_name.as_str(vm.interns); + if key_str != "optimize" { + return Err(ExcType::type_error_unexpected_keyword(name, key_str)); + } + } + + Ok((subscripts, operands)) +} + +/// Parses labels for Monty's no-ellipsis `einsum` subset. +fn parse_einsum_spec(subscripts: &str, operands: &[NdArray], name: &str) -> RunResult { + let cleaned = subscripts.chars().filter(|ch| !ch.is_whitespace()).collect::(); + if cleaned.contains("...") { + return Err(SimpleException::new_msg(ExcType::NotImplementedError, "einsum ellipsis is not supported").into()); + } + + let parts = cleaned.split("->").collect::>(); + if parts.len() > 2 { + return Err(SimpleException::new_msg(ExcType::ValueError, "invalid einsum subscript").into()); + } + let input_specs = parts[0].split(',').collect::>(); + if input_specs.len() != operands.len() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "number of subscripts must match number of operands", + ) + .into()); + } + + let mut inputs = Vec::with_capacity(input_specs.len()); + let mut label_dims = BTreeMap::new(); + let mut label_counts: BTreeMap = BTreeMap::new(); + for (spec, operand) in input_specs.iter().zip(operands.iter()) { + let labels = parse_einsum_labels(spec, name)?; + if labels.len() != operand.ndim() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "operand has more dimensions than subscripts given in einstein sum", + ) + .into()); + } + for (&label, &dim) in labels.iter().zip(operand.shape().iter()) { + match label_dims.get(&label) { + Some(&existing) if existing != dim => { + return Err(SimpleException::new_msg(ExcType::ValueError, "shape-mismatch for sum").into()); + } + Some(_) => {} + None => { + label_dims.insert(label, dim); + } + } + *label_counts.entry(label).or_default() += 1; + } + inputs.push(labels); + } + + let output = if parts.len() == 2 { + let labels = parse_einsum_labels(parts[1], name)?; + ensure_unique_einsum_labels(&labels)?; + for label in &labels { + if !label_dims.contains_key(label) { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "einstein sum subscripts string included output subscript never seen in an input", + ) + .into()); + } + } + labels + } else { + label_counts + .iter() + .filter_map(|(&label, &count)| (count == 1).then_some(label)) + .collect() + }; + let contracted = label_dims + .keys() + .copied() + .filter(|label| !output.contains(label)) + .collect(); + Ok(EinsumSpec { + inputs, + output, + label_dims, + contracted, + }) +} + +/// Parses one comma-separated `einsum` label component. +fn parse_einsum_labels(spec: &str, name: &str) -> RunResult> { + spec.chars() + .map(|ch| { + if ch.is_ascii_alphabetic() { + Ok(ch) + } else { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!("{name}() subscripts must use ASCII letters"), + ) + .into()) + } + }) + .collect() +} + +/// Rejects repeated explicit output labels. +fn ensure_unique_einsum_labels(labels: &[char]) -> RunResult<()> { + for (index, label) in labels.iter().enumerate() { + if labels[..index].contains(label) { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "einstein sum subscripts string includes output subscript multiple times", + ) + .into()); + } + } + Ok(()) +} + +/// Executes the supported real-valued `einsum` contraction. +fn einsum_arrays(operands: &[NdArray], spec: &EinsumSpec, tracker: &impl ResourceTracker) -> RunResult { + let output_shape = spec.output_shape(); + let contracted_shape = spec.contracted_shape(); + let output_len = checked_shape_product(&output_shape, "numpy.einsum")?; + let contracted_len = checked_shape_product(&contracted_shape, "numpy.einsum")?; + check_array_alloc_size(output_len, tracker)?; + let dtype = operands + .iter() + .map(NdArray::dtype) + .reduce(promote_dtype) + .unwrap_or(NdArrayDtype::Float64); + + let mut data = Vec::with_capacity(output_len); + for output_flat in 0..output_len { + let output_coords = flat_index_to_coords(output_flat, &output_shape); + let mut total = 0.0; + for contracted_flat in 0..contracted_len { + let contracted_coords = flat_index_to_coords(contracted_flat, &contracted_shape); + let mut product = 1.0; + for (operand, labels) in operands.iter().zip(spec.inputs.iter()) { + let index = einsum_operand_index( + operand.shape(), + labels, + &spec.output, + &output_coords, + &spec.contracted, + &contracted_coords, + ); + product *= operand.data()[index]; + } + total += product; + } + data.push(total); + } + Ok(NdArray::new(data, output_shape, dtype)) +} + +/// Builds one operand index by combining output and contracted label coordinates. +fn einsum_operand_index( + shape: &[usize], + labels: &[char], + output_labels: &[char], + output_coords: &[usize], + contracted_labels: &[char], + contracted_coords: &[usize], +) -> usize { + let coords = labels + .iter() + .map(|label| { + output_labels + .iter() + .position(|output| output == label) + .map(|index| output_coords[index]) + .or_else(|| { + contracted_labels + .iter() + .position(|contracted| contracted == label) + .map(|index| contracted_coords[index]) + }) + .unwrap_or(0) + }) + .collect::>(); + coords_to_flat_index(&coords, shape) +} + +/// Builds a simple left-to-right path list compatible with `numpy.einsum_path`. +fn simple_einsum_path(operand_count: usize, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + let mut items = Vec::new(); + items.push(allocate_string("einsum_path".to_string(), vm.heap)?); + for _ in 1..operand_count { + let pair = SmallVec::from_vec(vec![Value::Int(0), Value::Int(1)]); + items.push(allocate_tuple(pair, vm.heap)?); + } + Ok(Value::Ref(vm.heap.allocate(HeapData::List(List::new(items)))?)) +} + +/// Computes the Kronecker product using NumPy's left-padded shape alignment. +fn kron_arrays(a: &NdArray, b: &NdArray, tracker: &impl ResourceTracker) -> RunResult { + let ndim = a.ndim().max(b.ndim()); + let a_shape = left_padded_shape(a.shape(), ndim); + let b_shape = left_padded_shape(b.shape(), ndim); + let output_shape = a_shape + .iter() + .zip(b_shape.iter()) + .map(|(&lhs, &rhs)| lhs.checked_mul(rhs)) + .collect::>>() + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.kron() dimensions overflow"))?; + let output_len = checked_shape_product(&output_shape, "numpy.kron")?; + check_array_alloc_size(output_len, tracker)?; + + let mut data = vec![0.0; output_len]; + for (a_index, &a_value) in a.data().iter().enumerate() { + let a_coords = flat_index_to_coords(a_index, &a_shape); + for (b_index, &b_value) in b.data().iter().enumerate() { + let b_coords = flat_index_to_coords(b_index, &b_shape); + let output_coords = a_coords + .iter() + .zip(b_coords.iter()) + .zip(b_shape.iter()) + .map(|((&a_coord, &b_coord), &b_dim)| a_coord * b_dim + b_coord) + .collect::>(); + let output_index = coords_to_flat_index(&output_coords, &output_shape); + data[output_index] = a_value * b_value; + } + } + Ok(NdArray::new(data, output_shape, promote_dtype(a.dtype(), b.dtype()))) +} + +/// Left-pads an ndarray shape with ones to participate in NumPy-style shape alignment. +fn left_padded_shape(shape: &[usize], ndim: usize) -> Vec { + let mut padded = vec![1; ndim.saturating_sub(shape.len())]; + padded.extend_from_slice(shape); + padded +} + +/// Converts a row-major flat index to coordinates for a shape. +fn flat_index_to_coords(mut index: usize, shape: &[usize]) -> Vec { + let mut coords = vec![0; shape.len()]; + for axis in (0..shape.len()).rev() { + let dim = shape[axis]; + if dim > 0 { + coords[axis] = index % dim; + index /= dim; + } + } + coords +} + +/// Converts row-major coordinates to a flat index. +fn coords_to_flat_index(coords: &[usize], shape: &[usize]) -> usize { + coords + .iter() + .zip(shape.iter()) + .fold(0usize, |index, (&coord, &dim)| index * dim + coord) +} + +/// `numpy.trapezoid(y, x=None, dx=1.0)` — integrate 1-D samples by the trapezoidal rule. +fn call_trapezoid(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.trapezoid", vm.heap)?; + defer_drop_mut!(pos, vm); + + let y_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.trapezoid", 1, 0))?; + defer_drop!(y_val, vm); + let x_val = pos.next(); + let dx_val = pos.next(); + for extra in pos { + extra.drop_with_heap(vm); + } + + let y = ndarray_from_value(y_val, "numpy.trapezoid", vm)?; + let x = if let Some(x_val) = x_val { + defer_drop!(x_val, vm); + if matches!(x_val, Value::None) { + None + } else { + Some(ndarray_from_value(x_val, "numpy.trapezoid", vm)?) + } + } else { + None + }; + let dx = if let Some(dx_val) = dx_val { + defer_drop!(dx_val, vm); + to_f64(dx_val, vm)? + } else { + 1.0 + }; + + let result = trapezoid_1d(&y, x.as_ref(), dx)?; + Ok(Value::Float(result)) +} + +/// Integrates flattened samples using either explicit x-coordinates or a fixed spacing. +fn trapezoid_1d(y: &NdArray, x: Option<&NdArray>, dx: f64) -> RunResult { + if let Some(x) = x + && x.len() != y.len() + { + return Err( + SimpleException::new_msg(ExcType::ValueError, "numpy.trapezoid() x and y must have same length").into(), + ); + } + + let mut total = 0.0; + for index in 1..y.len() { + let width = x.map_or(dx, |coords| coords.data()[index] - coords.data()[index - 1]); + total += (y.data()[index - 1] + y.data()[index]) * 0.5 * width; + } + Ok(total) +} + +/// `numpy.vander(x, N=None, increasing=False)` — construct a Vandermonde matrix. +fn call_vander(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.vander", vm.heap)?; + defer_drop_mut!(pos, vm); + + let x_val = pos + .next() + .ok_or_else(|| ExcType::type_error_at_least("numpy.vander", 1, 0))?; + defer_drop!(x_val, vm); + let n_val = pos.next(); + let increasing_val = pos.next(); + for extra in pos { + extra.drop_with_heap(vm); + } + + let x = ndarray_from_value(x_val, "numpy.vander", vm)?; + let n = if let Some(n_val) = n_val { + defer_drop!(n_val, vm); + if matches!(n_val, Value::None) { + x.len() + } else { + value_to_nonnegative_usize(n_val, "numpy.vander", "N")? + } + } else { + x.len() + }; + let increasing = if let Some(increasing_val) = increasing_val { + defer_drop!(increasing_val, vm); + value_to_bool_arg(increasing_val, "numpy.vander", "increasing")? + } else { + false + }; + + vander_1d(&x, n, increasing, vm.heap) +} + +/// Builds a Vandermonde matrix for a 1-D numeric input. +fn vander_1d(x: &NdArray, n: usize, increasing: bool, heap: &Heap) -> RunResult { + if x.ndim() == 1 { + let len = x.len(); + check_array_alloc_size(len * n, heap.tracker())?; + let mut data = Vec::with_capacity(len * n); + for &value in x.data() { + for col in 0..n { + let power = if increasing { col } else { n - 1 - col }; + data.push(pow_usize(value, power)); + } + } + let result = NdArray::new(data, vec![len, n], x.dtype()); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) + } else { + Err(SimpleException::new_msg(ExcType::ValueError, "numpy.vander() x must be a one-dimensional array").into()) + } +} + +/// Raises a base to a non-negative integer exponent without lossy casts. +fn pow_usize(base: f64, exponent: usize) -> f64 { + let mut result = 1.0; + for _ in 0..exponent { + result *= base; + } + result +} + +/// `numpy.poly(seq_of_zeros)` — build descending-power coefficients from real roots. +fn call_poly(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let roots_val = args.get_one_arg("numpy.poly", vm.heap)?; + defer_drop!(roots_val, vm); + let roots = polynomial_1d(roots_val, "numpy.poly", vm)?; + check_array_alloc_size(roots.len() + 1, vm.heap.tracker())?; + let mut coeffs = vec![1.0]; + for &root in roots.data() { + let mut next = vec![0.0; coeffs.len() + 1]; + for (index, &coeff) in coeffs.iter().enumerate() { + next[index] += coeff; + next[index + 1] -= coeff * root; + } + coeffs = next; + } + allocate_polynomial_array(coeffs, NdArrayDtype::Float64, vm) +} + +/// `numpy.polyadd()` / `numpy.polysub()` — combine descending-power coefficient arrays. +fn call_poly_binary( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + name: &str, + operation: impl Fn(f64, f64) -> f64, +) -> RunResult { + let (lhs_val, rhs_val) = args.get_two_args(name, vm.heap)?; + defer_drop!(lhs_val, vm); + defer_drop!(rhs_val, vm); + let lhs = polynomial_1d(lhs_val, name, vm)?; + let rhs = polynomial_1d(rhs_val, name, vm)?; + let len = lhs.len().max(rhs.len()); + check_array_alloc_size(len, vm.heap.tracker())?; + let mut data = Vec::with_capacity(len); + for index in 0..len { + let lhs_value = polynomial_aligned_coeff(&lhs, index, len); + let rhs_value = polynomial_aligned_coeff(&rhs, index, len); + data.push(operation(lhs_value, rhs_value)); + } + allocate_polynomial_array( + trim_leading_zero_coeffs(&data), + promote_dtype(lhs.dtype(), rhs.dtype()), + vm, + ) +} + +/// `numpy.polymul(a, b)` — multiply descending-power coefficient arrays. +fn call_polymul(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (lhs_val, rhs_val) = args.get_two_args("numpy.polymul", vm.heap)?; + defer_drop!(lhs_val, vm); + defer_drop!(rhs_val, vm); + let lhs = polynomial_1d(lhs_val, "numpy.polymul", vm)?; + let rhs = polynomial_1d(rhs_val, "numpy.polymul", vm)?; + let len = lhs + .len() + .checked_add(rhs.len()) + .and_then(|value| value.checked_sub(1)) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "numpy.polymul() coefficients overflow"))?; + check_array_alloc_size(len, vm.heap.tracker())?; + let mut data = vec![0.0; len]; + for (lhs_index, &lhs_value) in lhs.data().iter().enumerate() { + for (rhs_index, &rhs_value) in rhs.data().iter().enumerate() { + data[lhs_index + rhs_index] += lhs_value * rhs_value; + } + } + allocate_polynomial_array( + trim_leading_zero_coeffs(&data), + promote_dtype(lhs.dtype(), rhs.dtype()), + vm, + ) +} + +/// `numpy.polydiv(u, v)` — divide descending-power coefficient arrays. +fn call_polydiv(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (dividend_val, divisor_val) = args.get_two_args("numpy.polydiv", vm.heap)?; + defer_drop!(dividend_val, vm); + defer_drop!(divisor_val, vm); + let dividend = polynomial_1d(dividend_val, "numpy.polydiv", vm)?; + let divisor = polynomial_1d(divisor_val, "numpy.polydiv", vm)?; + let dividend_data = trim_leading_zero_coeffs(dividend.data()); + let divisor_data = trim_leading_zero_coeffs(divisor.data()); + if is_zero_polynomial(&divisor_data) { + Err(SimpleException::new_msg(ExcType::ZeroDivisionError, "polynomial division by zero").into()) + } else if dividend_data.len() < divisor_data.len() { + let quotient = allocate_polynomial_array(vec![0.0], NdArrayDtype::Float64, vm)?; + let remainder = allocate_polynomial_array(dividend_data, NdArrayDtype::Float64, vm)?; + Ok(allocate_tuple(smallvec::smallvec![quotient, remainder], vm.heap)?) + } else { + let quotient_len = dividend_data.len() - divisor_data.len() + 1; + check_array_alloc_size(quotient_len, vm.heap.tracker())?; + check_array_alloc_size(divisor_data.len().saturating_sub(1), vm.heap.tracker())?; + let mut remainder_work = dividend_data; + let mut quotient = vec![0.0; quotient_len]; + for index in 0..quotient_len { + let coeff = remainder_work[index] / divisor_data[0]; + quotient[index] = coeff; + for (divisor_index, &divisor_coeff) in divisor_data.iter().enumerate() { + remainder_work[index + divisor_index] -= coeff * divisor_coeff; + } + } + let remainder_start = quotient_len; + let quotient = allocate_polynomial_array(trim_leading_zero_coeffs("ient), NdArrayDtype::Float64, vm)?; + let remainder = allocate_polynomial_array( + trim_leading_zero_coeffs(&remainder_work[remainder_start..]), + NdArrayDtype::Float64, + vm, + )?; + Ok(allocate_tuple(smallvec::smallvec![quotient, remainder], vm.heap)?) + } +} + +/// `numpy.polyint(p, m=1)` — integrate coefficients repeatedly with zero constants. +fn call_polyint(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (poly, order) = polynomial_unary_args(vm, args, "numpy.polyint")?; + let mut data = poly.data().to_vec(); + for _ in 0..order { + let len = data.len(); + check_array_alloc_size(len + 1, vm.heap.tracker())?; + let mut integrated = Vec::with_capacity(len + 1); + for (index, &coeff) in data.iter().enumerate() { + integrated.push(coeff / usize_to_f64(len - index)); + } + integrated.push(0.0); + data = integrated; + } + allocate_polynomial_array(data, NdArrayDtype::Float64, vm) +} + +/// `numpy.polyder(p, m=1)` — differentiate coefficients repeatedly. +fn call_polyder(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (poly, order) = polynomial_unary_args(vm, args, "numpy.polyder")?; + let mut data = poly.data().to_vec(); + for _ in 0..order { + if data.len() <= 1 { + data = vec![0.0]; + break; + } + let degree = data.len() - 1; + check_array_alloc_size(degree, vm.heap.tracker())?; + let mut derivative = Vec::with_capacity(degree); + for (index, &coeff) in data.iter().take(degree).enumerate() { + derivative.push(coeff * usize_to_f64(degree - index)); + } + data = trim_leading_zero_coeffs(&derivative); + } + allocate_polynomial_array(data, poly.dtype(), vm) +} + +/// `numpy.polyval(p, x)` — evaluate coefficients using Horner's method. +fn call_polyval(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (poly_val, x_val) = args.get_two_args("numpy.polyval", vm.heap)?; + defer_drop!(poly_val, vm); + defer_drop!(x_val, vm); + let poly = polynomial_1d(poly_val, "numpy.polyval", vm)?; + if let Ok((x_data, x_shape, x_dtype)) = extract_ndarray_info(x_val, "numpy.polyval", vm) { + check_array_alloc_size(x_data.len(), vm.heap.tracker())?; + let data = x_data + .iter() + .map(|&value| polynomial_eval(poly.data(), value)) + .collect(); + let result = NdArray::new(data, x_shape, promote_dtype(poly.dtype(), x_dtype)); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) + } else { + let x = to_f64(x_val, vm)?; + Ok(Value::Float(polynomial_eval(poly.data(), x))) + } +} + +/// Parses one coefficient array plus an optional non-negative derivative/integral order. +fn polynomial_unary_args( + vm: &mut VM<'_, impl ResourceTracker>, + args: ArgValues, + name: &str, +) -> RunResult<(NdArray, usize)> { + let pos = args.into_pos_only(name, vm.heap)?; + defer_drop_mut!(pos, vm); + let poly_val = pos.next().ok_or_else(|| ExcType::type_error_at_least(name, 1, 0))?; + defer_drop!(poly_val, vm); + let order_val = pos.next(); + for extra in pos { + extra.drop_with_heap(vm); + } + let poly = polynomial_1d(poly_val, name, vm)?; + let order = if let Some(order_val) = order_val { + defer_drop!(order_val, vm); + value_to_nonnegative_usize(order_val, name, "m")? + } else { + 1 + }; + Ok((poly, order)) +} + +/// Converts one coefficient argument into a one-dimensional ndarray copy. +fn polynomial_1d(value: &Value, name: &str, vm: &VM<'_, impl ResourceTracker>) -> RunResult { + let arr = ndarray_from_value(value, name, vm)?; + if arr.ndim() == 1 { + Ok(arr) + } else { + Err(SimpleException::new_msg(ExcType::ValueError, format!("{name}() expects a 1D coefficient array")).into()) + } +} + +/// Reads a coefficient from a shorter polynomial after aligning by lowest powers. +fn polynomial_aligned_coeff(poly: &NdArray, output_index: usize, output_len: usize) -> f64 { + let offset = output_len.saturating_sub(poly.len()); + if output_index >= offset { + poly.data()[output_index - offset] + } else { + 0.0 + } +} + +/// Allocates a polynomial coefficient vector as a one-dimensional ndarray. +fn allocate_polynomial_array( + data: Vec, + dtype: NdArrayDtype, + vm: &mut VM<'_, impl ResourceTracker>, +) -> RunResult { + let len = data.len(); + let result = NdArray::new(data, vec![len], dtype); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// Removes leading zero coefficients while preserving at least one coefficient. +fn trim_leading_zero_coeffs(data: &[f64]) -> Vec { + let first_non_zero = data + .iter() + .position(|value| !matches!(value.classify(), FpCategory::Zero)) + .unwrap_or_else(|| data.len().saturating_sub(1)); + data[first_non_zero..].to_vec() +} + +/// Returns true when every coefficient is exactly positive or negative zero. +fn is_zero_polynomial(data: &[f64]) -> bool { + data.iter().all(|value| matches!(value.classify(), FpCategory::Zero)) +} + +/// Evaluates descending-power polynomial coefficients for one numeric x value. +fn polynomial_eval(coefficients: &[f64], x: f64) -> f64 { + coefficients.iter().fold(0.0, |acc, &coeff| acc * x + coeff) +} + +/// Converts a Python truth value argument used by NumPy option flags. +fn value_to_bool_arg(value: &Value, name: &str, arg_name: &str) -> RunResult { + match value { + Value::Bool(value) => Ok(*value), + Value::Int(value) => Ok(*value != 0), + _ => Err(ExcType::type_error(format!("{name}() {arg_name} must be a boolean"))), + } +} + +// --- Phase 10: Additional creation and numerical --- + +/// `numpy.logspace(start, stop, num)` — log-spaced values. +#[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "array indices are small enough that these casts are safe" +)] +fn call_logspace(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.logspace", vm.heap)?; + defer_drop_mut!(pos, vm); + let start_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.logspace() requires 3 arguments"))?; + defer_drop!(start_val, vm); + let stop_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.logspace() requires 3 arguments"))?; + defer_drop!(stop_val, vm); + let num_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.logspace() requires 3 arguments"))?; + defer_drop!(num_val, vm); + for extra in pos { + extra.drop_with_heap(vm); + } + let start = to_f64(start_val, vm)?; + let stop = to_f64(stop_val, vm)?; + let Value::Int(num) = num_val else { + return Err(ExcType::type_error("num must be integer")); + }; + let num = *num as usize; + check_array_alloc_size(num, vm.heap.tracker())?; + // logspace: 10^linspace(start, stop, num) + let data: Vec = if num == 0 { + Vec::new() + } else if num == 1 { + vec![10.0f64.powf(start)] + } else { + let step = (stop - start) / (num - 1) as f64; + (0..num).map(|i| 10.0f64.powf(start + step * i as f64)).collect() + }; + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.geomspace(start, stop, num)` — geometrically spaced values. +#[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "array indices are small enough that these casts are safe" +)] +fn call_geomspace(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.geomspace", vm.heap)?; + defer_drop_mut!(pos, vm); + let start_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.geomspace() requires 3 arguments"))?; + defer_drop!(start_val, vm); + let stop_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.geomspace() requires 3 arguments"))?; + defer_drop!(stop_val, vm); + let num_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.geomspace() requires 3 arguments"))?; + defer_drop!(num_val, vm); + for extra in pos { + extra.drop_with_heap(vm); + } + let start = to_f64(start_val, vm)?; + let stop = to_f64(stop_val, vm)?; + let Value::Int(num) = num_val else { + return Err(ExcType::type_error("num must be integer")); + }; + let num = *num as usize; + check_array_alloc_size(num, vm.heap.tracker())?; + let data: Vec = if num == 0 { + Vec::new() + } else if num == 1 { + vec![start] + } else { + let log_start = start.ln(); + let log_stop = stop.ln(); + let step = (log_stop - log_start) / (num - 1) as f64; + (0..num).map(|i| (log_start + step * i as f64).exp()).collect() + }; + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.tri(N)` — NxN array with ones at and below diagonal. +fn call_tri(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.tri", vm.heap)?; + let n = extract_size(arg, "numpy.tri", vm)?; + check_array_alloc_size(n * n, vm.heap.tracker())?; + let mut data = vec![0.0; n * n]; + for i in 0..n { + for j in 0..=i { + data[i * n + j] = 1.0; + } + } + let result = NdArray::new(data, vec![n, n], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.tril(m)` — lower triangle of array. +fn call_tril(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.tril", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.tril", vm)?; + if arr.shape().len() < 2 { + return Err(SimpleException::new_msg(ExcType::ValueError, "tril requires 2-d array").into()); + } + let rows = arr.shape()[0]; + let cols = arr.shape()[1]; + let mut data = arr.data().to_vec(); + for i in 0..rows { + for j in (i + 1)..cols { + data[i * cols + j] = 0.0; + } + } + let result = NdArray::new(data, arr.shape().to_vec(), arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.triu(m)` — upper triangle of array. +fn call_triu(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.triu", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.triu", vm)?; + if arr.shape().len() < 2 { + return Err(SimpleException::new_msg(ExcType::ValueError, "triu requires 2-d array").into()); + } + let rows = arr.shape()[0]; + let cols = arr.shape()[1]; + let mut data = arr.data().to_vec(); + for i in 0..rows { + for j in 0..i.min(cols) { + data[i * cols + j] = 0.0; + } + } + let result = NdArray::new(data, arr.shape().to_vec(), arr.dtype()); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.meshgrid(*xi)` — coordinate matrices from coordinate vectors. +fn call_meshgrid(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.meshgrid", vm.heap)?; + defer_drop_mut!(pos, vm); + let mut arrays: Vec = Vec::new(); + for val in pos { + let arr = ndarray_from_value(&val, "numpy.meshgrid", vm)?; + val.drop_with_heap(vm); + arrays.push(arr); + } + if arrays.len() != 2 { + return Err( + SimpleException::new_msg(ExcType::ValueError, "meshgrid currently supports exactly 2 arrays").into(), + ); + } + let x = &arrays[0]; + let y = &arrays[1]; + let nx = x.data().len(); + let ny = y.data().len(); + check_array_alloc_size(nx * ny * 2, vm.heap.tracker())?; + // XX: repeat x for each row + let mut xx_data = Vec::with_capacity(ny * nx); + for _ in 0..ny { + xx_data.extend_from_slice(x.data()); + } + // YY: repeat each y value nx times + let mut yy_data = Vec::with_capacity(ny * nx); + for &yv in y.data() { + for _ in 0..nx { + yy_data.push(yv); + } + } + let dtype = promote_dtype(x.dtype(), y.dtype()); + let xx = NdArray::new(xx_data, vec![ny, nx], dtype); + let yy = NdArray::new(yy_data, vec![ny, nx], dtype); + let xx_val = Value::Ref(vm.heap.allocate(HeapData::NdArray(xx))?); + let yy_val = Value::Ref(vm.heap.allocate(HeapData::NdArray(yy))?); + let values: SmallVec<[Value; 3]> = smallvec::smallvec![xx_val, yy_val]; + allocate_tuple(values, vm.heap).map_err(Into::into) +} + +/// `numpy.gradient(f)` — numerical gradient using central differences. +fn call_gradient(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let arg = args.get_one_arg("numpy.gradient", vm.heap)?; + defer_drop!(arg, vm); + let arr = ndarray_from_value(arg, "numpy.gradient", vm)?; + let data = arr.data(); + let n = data.len(); + if n < 2 { + let result = NdArray::new(vec![0.0; n], vec![n], NdArrayDtype::Float64); + return Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)); + } + let mut grad = Vec::with_capacity(n); + // Forward difference for first element + grad.push(data[1] - data[0]); + // Central differences for interior + for i in 1..n - 1 { + grad.push((data[i + 1] - data[i - 1]) / 2.0); + } + // Backward difference for last element + grad.push(data[n - 1] - data[n - 2]); + let result = NdArray::new(grad, vec![n], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.convolve(a, v)` — discrete linear convolution (mode='full'). +fn call_convolve(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, v_val) = args.get_two_args("numpy.convolve", vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(v_val, vm); + let a = ndarray_from_value(a_val, "numpy.convolve", vm)?; + let v = ndarray_from_value(v_val, "numpy.convolve", vm)?; + let na = a.data().len(); + let nv = v.data().len(); + let out_len = na + nv - 1; + check_array_alloc_size(out_len, vm.heap.tracker())?; + let mut result_data = vec![0.0; out_len]; + for i in 0..na { + for j in 0..nv { + result_data[i + j] += a.data()[i] * v.data()[j]; + } + } + let result = NdArray::new(result_data, vec![out_len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.correlate(a, v)` — cross-correlation (mode='valid'). +fn call_correlate(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let (a_val, v_val) = args.get_two_args("numpy.correlate", vm.heap)?; + defer_drop!(a_val, vm); + defer_drop!(v_val, vm); + let a = ndarray_from_value(a_val, "numpy.correlate", vm)?; + let v = ndarray_from_value(v_val, "numpy.correlate", vm)?; + let na = a.data().len(); + let nv = v.data().len(); + if na < nv { + return Err(SimpleException::new_msg(ExcType::ValueError, "a must be at least as long as v").into()); + } + let out_len = na - nv + 1; + let mut result_data = Vec::with_capacity(out_len); + for i in 0..out_len { + let sum: f64 = (0..nv).map(|j| a.data()[i + j] * v.data()[j]).sum(); + result_data.push(sum); + } + let result = NdArray::new(result_data, vec![out_len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.interp(x, xp, fp)` — 1D linear interpolation. +fn call_interp(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.interp", vm.heap)?; + defer_drop_mut!(pos, vm); + let x_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.interp() requires 3 arguments"))?; + defer_drop!(x_val, vm); + let xp_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.interp() requires 3 arguments"))?; + defer_drop!(xp_val, vm); + let fp_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.interp() requires 3 arguments"))?; + defer_drop!(fp_val, vm); + for extra in pos { + extra.drop_with_heap(vm); + } + let x = ndarray_from_value(x_val, "numpy.interp", vm)?; + let xp = ndarray_from_value(xp_val, "numpy.interp", vm)?; + let fp = ndarray_from_value(fp_val, "numpy.interp", vm)?; + let xp_data = xp.data(); + let fp_data = fp.data(); + let data: Vec = x + .data() + .iter() + .map(|&xi| { + if xi <= xp_data[0] { + return fp_data[0]; + } + if xi >= xp_data[xp_data.len() - 1] { + return fp_data[fp_data.len() - 1]; + } + let idx = xp_data.partition_point(|&xv| xv < xi); + if idx == 0 { + return fp_data[0]; + } + let x0 = xp_data[idx - 1]; + let x1 = xp_data[idx]; + let f0 = fp_data[idx - 1]; + let f1 = fp_data[idx]; + f0 + (f1 - f0) * (xi - x0) / (x1 - x0) + }) + .collect(); + let len = data.len(); + let result = NdArray::new(data, vec![len], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} + +/// `numpy.select(condlist, choicelist, default=0)` — conditional selection. +fn call_select(vm: &mut VM<'_, impl ResourceTracker>, args: ArgValues) -> RunResult { + let pos = args.into_pos_only("numpy.select", vm.heap)?; + defer_drop_mut!(pos, vm); + let condlist_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.select() requires 2-3 arguments"))?; + defer_drop!(condlist_val, vm); + let choicelist_val = pos + .next() + .ok_or_else(|| ExcType::type_error("numpy.select() requires 2-3 arguments"))?; + defer_drop!(choicelist_val, vm); + let default_val = pos + .next() + .map(|v| { + let result = to_f64(&v, vm); + v.drop_with_heap(vm); + result + }) + .transpose()? + .unwrap_or(0.0); + for extra in pos { + extra.drop_with_heap(vm); + } + // Extract conditions and choices from lists + let conds: Vec = match condlist_val { + Value::Ref(id) => match vm.heap.get(*id) { + HeapData::List(list) => { + let items: Vec = list.as_slice().iter().map(|v| v.clone_with_heap(vm)).collect(); + let result: Vec = items + .iter() + .map(|v| ndarray_from_value(v, "numpy.select", vm)) + .collect::>>()?; + for item in items { + item.drop_with_heap(vm); + } + result + } + _ => return Err(ExcType::type_error("condlist must be a list")), + }, + _ => return Err(ExcType::type_error("condlist must be a list")), + }; + let choices: Vec = match choicelist_val { + Value::Ref(id) => match vm.heap.get(*id) { + HeapData::List(list) => { + let items: Vec = list.as_slice().iter().map(|v| v.clone_with_heap(vm)).collect(); + let result: Vec = items + .iter() + .map(|v| ndarray_from_value(v, "numpy.select", vm)) + .collect::>>()?; + for item in items { + item.drop_with_heap(vm); + } + result + } + _ => return Err(ExcType::type_error("choicelist must be a list")), + }, + _ => return Err(ExcType::type_error("choicelist must be a list")), + }; + if conds.is_empty() || conds.len() != choices.len() { + return Err( + SimpleException::new_msg(ExcType::ValueError, "condlist and choicelist must have same length").into(), + ); + } + let n = conds[0].data().len(); + let mut data = vec![default_val; n]; + // Process in reverse order so first matching condition wins + for (cond, choice) in conds.iter().zip(choices.iter()).rev() { + for (i, (&c, &v)) in cond.data().iter().zip(choice.data().iter()).enumerate() { + if c != 0.0 { + data[i] = v; + } + } + } + let result = NdArray::new(data, vec![n], NdArrayDtype::Float64); + Ok(Value::Ref(vm.heap.allocate(HeapData::NdArray(result))?)) +} diff --git a/crates/monty/src/object.rs b/crates/monty/src/object.rs index 4dd01e518..c4a4b65ab 100644 --- a/crates/monty/src/object.rs +++ b/crates/monty/src/object.rs @@ -790,6 +790,9 @@ impl MontyObject { HeapReadOutput::GatherFuture(gather) => { Self::Repr(format!("", gather.get(vm.heap).item_count())) } + HeapReadOutput::NdArray(_) | HeapReadOutput::RePattern(_) | HeapReadOutput::ReMatch(_) => { + repr_or_error(object, vm) + } HeapReadOutput::Path(path) => Self::Path(path.get(vm.heap).as_str().to_owned()), HeapReadOutput::ExtFunction(name) => Self::Function { name: name.get(vm.heap).clone(), diff --git a/crates/monty/src/parse.rs b/crates/monty/src/parse.rs index 116d70b61..958ed420a 100644 --- a/crates/monty/src/parse.rs +++ b/crates/monty/src/parse.rs @@ -173,7 +173,7 @@ pub struct Parser<'a> { impl<'a> Parser<'a> { fn new(code: &'a str, filename: &'a str, mut interner: InternerBuilder) -> Self { - let filename_id = interner.intern(filename); + let filename_id = interner.intern_dynamic(filename); Self { code, filename_id, diff --git a/crates/monty/src/resource.rs b/crates/monty/src/resource.rs index 3906f454b..af6dc9cd1 100644 --- a/crates/monty/src/resource.rs +++ b/crates/monty/src/resource.rs @@ -2,6 +2,7 @@ use std::{ cell::Cell, error::Error, fmt, + mem::size_of, time::{Duration, Instant}, }; @@ -114,6 +115,15 @@ pub fn check_replace_size( check_estimated_size(estimated, tracker) } +/// Pre-checks that creating an ndarray with `num_elements` f64 values won't exceed resource limits. +/// +/// NdArray data is stored as a `Vec`, so the allocation size is `num_elements * 8` bytes. +/// This must be called **before** allocating the `Vec` so that user-controlled sizes +/// (e.g. `np.zeros(10**9)`) are rejected before the Rust heap allocation happens. +pub fn check_array_alloc_size(num_elements: usize, tracker: &impl ResourceTracker) -> Result<(), ResourceError> { + check_estimated_size(num_elements.saturating_mul(size_of::()), tracker) +} + /// Checks an estimated result size against the resource tracker. /// /// Only calls the tracker when the estimate exceeds `LARGE_RESULT_THRESHOLD` diff --git a/crates/monty/src/types/iter.rs b/crates/monty/src/types/iter.rs index 9efe6c276..70473e5de 100644 --- a/crates/monty/src/types/iter.rs +++ b/crates/monty/src/types/iter.rs @@ -80,6 +80,13 @@ impl MontyIter { /// For strings, copies the string content for byte-offset based iteration. /// For ranges, the data is copied so the heap reference is dropped immediately. pub fn new(mut value: Value, vm: &mut VM<'_, impl ResourceTracker>) -> RunResult { + if let Value::Ref(heap_id) = &value + && let HeapData::NdArray(arr) = vm.heap.get(*heap_id) + && arr.shape().is_empty() + { + value.drop_with_heap(vm); + return Err(ExcType::type_error("iteration over a 0-d array")); + } if let Some(iter_value) = IterValue::new(&value, vm) { // For Range, we copy next/step/len into ForIterValue::Range, so we don't need // to keep the heap object alive during iteration. Drop it immediately to avoid @@ -400,6 +407,11 @@ fn get_heap_item( .expect("index should be valid") .clone_with_heap(vm), )), + HeapData::NdArray(arr) => { + // For 1D arrays, yield scalars; for multi-dimensional arrays, yield sub-arrays (rows). + #[expect(clippy::cast_possible_wrap, reason = "index won't exceed i64::MAX")] + Ok(Some(arr.getitem_int(index as i64, vm.heap)?)) + } _ => panic!("get_heap_item: unexpected heap data type"), } } @@ -593,6 +605,18 @@ impl IterValue { len: Some(set.len()), checks_mutation: true, }), + // NdArray: iterate over first dimension (scalars for 1D, sub-arrays otherwise) + HeapData::NdArray(arr) => { + if arr.shape().is_empty() { + None + } else { + Some(Self::HeapRef { + heap_id, + len: Some(arr.shape()[0]), + checks_mutation: false, + }) + } + } // String: copy content for iteration HeapData::Str(s) => Some(Self::from_str(s.as_str())), // Range: copy values for iteration diff --git a/crates/monty/src/types/mod.rs b/crates/monty/src/types/mod.rs index 1f9b8024e..1b8df40e8 100644 --- a/crates/monty/src/types/mod.rs +++ b/crates/monty/src/types/mod.rs @@ -16,6 +16,7 @@ pub mod list; pub mod long_int; pub mod module; pub mod namedtuple; +pub mod ndarray; pub mod path; pub mod property; pub mod py_trait; @@ -39,6 +40,7 @@ pub(crate) use list::List; pub(crate) use long_int::LongInt; pub(crate) use module::Module; pub(crate) use namedtuple::NamedTuple; +pub(crate) use ndarray::NdArray; pub(crate) use path::Path; pub(crate) use property::Property; pub(crate) use py_trait::{AttrCallResult, PyTrait}; diff --git a/crates/monty/src/types/ndarray.rs b/crates/monty/src/types/ndarray.rs new file mode 100644 index 000000000..eac24457e --- /dev/null +++ b/crates/monty/src/types/ndarray.rs @@ -0,0 +1,2578 @@ +//! NumPy-compatible ndarray type for the Monty interpreter. +//! +//! Provides a multi-dimensional array of f64 values that emulates the subset of +//! `numpy.ndarray` commonly used by LLMs. Backed by a flat `Vec` with shape +//! metadata, supporting element-wise arithmetic, comparisons, indexing, and +//! aggregation methods. +//! +//! This is a built-in type (like `list` or `dict`) rather than a user-defined class, +//! so operator overloading and method dispatch are hardcoded in the VM — no class +//! support is required. +//! +//! # Supported operations +//! +//! - Element-wise arithmetic: `+`, `-`, `*`, `/`, `//`, `%`, `**`, unary `-` +//! - NumPy-style broadcasting: `arr + 5`, `matrix + vector`, singleton dimensions +//! - Comparisons: `>`, `<`, `==`, `>=`, `<=`, `!=` (return boolean arrays) +//! - Boolean indexing: `arr[arr > 3]` +//! - Integer indexing: `arr[0]`, `arr[1][2]` for 2D +//! - Aggregation: `sum()`, `mean()`, `min()`, `max()`, `std()` +//! - Shape manipulation: `reshape()`, `flatten()` +//! - Element-wise transforms: `cumsum()`, `cumprod()`, `abs()`, `round()`, `clip()`, `sort()` +//! - Selection: `take()`, `compress()`, `diagonal()`, `item()`, `squeeze()` +//! - In-place: `fill()` +//! - Linear algebra: `trace()`, `swapaxes()` +//! - Conversion: `tolist()` +//! - Attributes: `.shape`, `.dtype`, `.size`, `.ndim`, `.T`, `.nbytes`, `.itemsize` + +use std::{ + cmp::Ordering, + fmt::{self, Write}, + mem::size_of, + string::ToString, +}; + +use ahash::AHashSet; +use smallvec::{SmallVec, smallvec}; + +use crate::{ + args::ArgValues, + bytecode::{CallResult, VM}, + defer_drop, + exception_private::{ExcType, RunResult, SimpleException}, + heap::{DropWithHeap, Heap, HeapData, HeapId, HeapItem, HeapRead}, + intern::StaticStrings, + resource::{ResourceError, ResourceTracker, check_array_alloc_size}, + types::{List, PyTrait, Slice, Str, Type, allocate_tuple}, + value::{EitherStr, Value}, +}; + +/// The element type stored in an ndarray. +/// +/// NumPy arrays have a dtype that determines how elements are stored and displayed. +/// We support the two most common dtypes: 64-bit integers and 64-bit floats. +/// Boolean arrays (from comparisons) use `Bool`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub(crate) enum NdArrayDtype { + /// 64-bit signed integer (`int64` / `numpy.int64`). + Int64, + /// 64-bit floating point (`float64` / `numpy.float64`). + Float64, + /// Boolean array (`bool` / `numpy.bool_`), used for comparison results and masks. + Bool, +} + +impl fmt::Display for NdArrayDtype { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Int64 => f.write_str("int64"), + Self::Float64 => f.write_str("float64"), + Self::Bool => f.write_str("bool"), + } + } +} + +/// A multi-dimensional array of numeric values, emulating `numpy.ndarray`. +/// +/// Data is stored as a flat `Vec` with shape metadata. Even integer arrays +/// store values as f64 internally — the `dtype` field controls display formatting +/// (integers show without decimal points) and type promotion rules. +/// +/// Boolean arrays store 0.0 for `False` and 1.0 for `True`. +/// +/// # Memory layout +/// +/// Row-major (C-contiguous) order, matching NumPy's default. A 2D array with +/// shape `(3, 2)` stores elements as `[row0_col0, row0_col1, row1_col0, ...]`. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub(crate) struct NdArray { + /// Flat storage of all elements in row-major order. + pub(crate) data: Vec, + /// Dimensions of the array (e.g., `[3]` for 1D, `[2, 3]` for 2D). + pub(crate) shape: Vec, + /// Element type, controlling display format and type promotion. + pub(crate) dtype: NdArrayDtype, + /// Whether this array is the materialized backing object for `ndarray.flat`. + /// + /// Monty represents `flatiter` with the existing ndarray storage so the hot + /// `Value::Ref` path remains unchanged for non-NumPy workloads. This marker + /// only affects the reported Python type. + #[serde(default)] + is_flatiter: bool, +} + +// =========================== +// Public constructors and accessors +// =========================== + +impl NdArray { + /// Creates a new ndarray from flat data with the given shape and dtype. + /// + /// The caller must ensure `data.len() == shape.iter().product()`. + pub fn new(data: Vec, shape: Vec, dtype: NdArrayDtype) -> Self { + debug_assert_eq!( + data.len(), + shape.iter().product::(), + "data length must match shape product" + ); + Self { + data, + shape, + dtype, + is_flatiter: false, + } + } + + /// Marks a materialized one-dimensional array as the public `numpy.flatiter` result. + /// + /// The backing behavior intentionally stays ndarray-like because Monty does not + /// yet have view objects, but `type(arr.flat)` and `isinstance(arr.flat, + /// np.flatiter)` can distinguish it from a normal ndarray. + #[must_use] + pub fn into_flatiter(mut self) -> Self { + self.is_flatiter = true; + self + } + + /// Returns the total number of elements in the array. + pub fn len(&self) -> usize { + self.data.len() + } + + /// Returns a reference to the raw f64 data backing this array. + pub fn data(&self) -> &[f64] { + &self.data + } + + /// Returns the shape as a slice of dimensions. + pub fn shape(&self) -> &[usize] { + &self.shape + } + + /// Returns the dtype of the array. + pub fn dtype(&self) -> NdArrayDtype { + self.dtype + } + + /// Returns the number of dimensions (ndim). + pub fn ndim(&self) -> usize { + self.shape.len() + } +} + +/// Computes the shared broadcast shape for NumPy-compatible array operands. +/// +/// Shapes are aligned from the right. Each aligned dimension is compatible when +/// the dimensions are equal or either side is `1`; otherwise the operands cannot +/// be broadcast. A zero dimension can broadcast with `1`, preserving NumPy's +/// empty-result behavior for shapes such as `(0,)` and `(1,)`. +pub(crate) fn broadcast_shape(shapes: &[&[usize]], _name: &str) -> RunResult> { + let ndim = shapes.iter().map(|shape| shape.len()).max().unwrap_or(0); + let mut result = vec![1; ndim]; + for shape in shapes { + let offset = ndim - shape.len(); + for (axis, &dim) in shape.iter().enumerate() { + let result_dim = &mut result[offset + axis]; + if *result_dim == 1 { + *result_dim = dim; + } else if dim != 1 && dim != *result_dim { + return Err(SimpleException::new_msg(ExcType::ValueError, broadcast_error_message(shapes)).into()); + } + } + } + Ok(result) +} + +/// Materializes an array's data in a target broadcast shape. +/// +/// NumPy represents most broadcasts as views, but Monty's ndarray stores +/// contiguous owned data. This helper therefore expands singleton dimensions +/// into an owned `Vec` while checking the projected allocation against the +/// active resource limits before reserving memory. +pub(crate) fn broadcast_array_data( + data: &[f64], + from_shape: &[usize], + to_shape: &[usize], + name: &str, + tracker: &impl ResourceTracker, +) -> RunResult> { + let actual_shape = broadcast_shape(&[from_shape, to_shape], name)?; + if actual_shape != to_shape { + return Err( + SimpleException::new_msg(ExcType::ValueError, broadcast_error_message(&[from_shape, to_shape])).into(), + ); + } + let total = checked_shape_product(to_shape)?; + check_array_alloc_size(total, tracker)?; + if from_shape == to_shape { + Ok(data.to_vec()) + } else if total == 0 { + Ok(Vec::new()) + } else if from_shape.is_empty() { + Ok(vec![data[0]; total]) + } else { + let input_strides = row_major_strides(from_shape); + let output_strides = row_major_strides(to_shape); + let offset = to_shape.len() - from_shape.len(); + let mut output = Vec::with_capacity(total); + for flat_index in 0..total { + let coords = coords_from_flat_index(flat_index, to_shape, &output_strides); + let mut input_index = 0; + for (axis, (&dim, &stride)) in from_shape.iter().zip(input_strides.iter()).enumerate() { + let coord = if dim == 1 { 0 } else { coords[offset + axis] }; + input_index += coord * stride; + } + output.push(data[input_index]); + } + Ok(output) + } +} + +/// Broadcasts a pair of arrays and returns materialized data plus the result shape. +/// +/// This is used by ndarray operators and NumPy ufunc-style functions so they +/// share the same shape rules, allocation checks, and mismatch messages. +pub(crate) fn broadcast_pair_data( + left_data: &[f64], + left_shape: &[usize], + right_data: &[f64], + right_shape: &[usize], + name: &str, + tracker: &impl ResourceTracker, +) -> RunResult<(Vec, Vec, Vec)> { + let shape = broadcast_shape(&[left_shape, right_shape], name)?; + let total = checked_shape_product(&shape)?; + // Monty materializes both broadcast operands and the final result as owned + // vectors, so the pre-check needs to cover peak storage rather than just + // the two expanded inputs. + check_array_alloc_size(total.saturating_mul(3), tracker)?; + let left = broadcast_array_data(left_data, left_shape, &shape, name, tracker)?; + let right = broadcast_array_data(right_data, right_shape, &shape, name, tracker)?; + Ok((left, right, shape)) +} + +/// Returns the row-major strides for a shape. +fn row_major_strides(shape: &[usize]) -> Vec { + let mut strides = vec![1; shape.len()]; + let mut stride = 1usize; + for axis in (0..shape.len()).rev() { + strides[axis] = stride; + stride = stride.saturating_mul(shape[axis]); + } + strides +} + +/// Converts a flat row-major index into coordinates for the given shape. +fn coords_from_flat_index(flat_index: usize, shape: &[usize], strides: &[usize]) -> Vec { + shape + .iter() + .zip(strides.iter()) + .map(|(&dim, &stride)| { + if dim == 0 || stride == 0 { + 0 + } else { + (flat_index / stride) % dim + } + }) + .collect() +} + +/// Computes a shape product while turning overflow into a Python exception. +fn checked_shape_product(shape: &[usize]) -> RunResult { + shape.iter().try_fold(1usize, |acc, &dim| { + acc.checked_mul(dim) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "broadcast dimensions overflow").into()) + }) +} + +/// Formats NumPy's compact two-operand broadcast mismatch message. +fn broadcast_error_message(shapes: &[&[usize]]) -> String { + if shapes.len() == 2 { + format!( + "operands could not be broadcast together with shapes {} {} ", + format_broadcast_shape(shapes[0]), + format_broadcast_shape(shapes[1]) + ) + } else { + "operands could not be broadcast together".to_string() + } +} + +/// Formats a shape the way NumPy displays it inside ufunc broadcast errors. +fn format_broadcast_shape(shape: &[usize]) -> String { + match shape { + [] => "()".to_string(), + [dim] => format!("({dim},)"), + _ => { + let mut formatted = String::from("("); + for (index, dim) in shape.iter().enumerate() { + if index > 0 { + formatted.push(','); + } + write!(formatted, "{dim}").expect("writing to a string cannot fail"); + } + formatted.push(')'); + formatted + } + } +} + +// =========================== +// Indexing operations +// =========================== + +impl NdArray { + /// Indexes a 1D array by integer, returning a scalar Value. + /// + /// For multi-dimensional arrays, returns a sub-array (row slice). + pub fn getitem_int(&self, index: i64, heap: &Heap) -> RunResult { + if self.ndim() == 1 { + let idx = resolve_index(index, self.shape[0])?; + Ok(self.element_to_value(self.data[idx])) + } else { + // For multi-dimensional arrays, return a sub-array (row) + let idx = resolve_index(index, self.shape[0])?; + let row_size: usize = self.shape[1..].iter().product(); + let start = idx * row_size; + let end = start + row_size; + let row_data = self.data[start..end].to_vec(); + let row_shape = self.shape[1..].to_vec(); + let row = Self::new(row_data, row_shape, self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(row))?)) + } + } + + /// Indexes by a boolean mask array, returning elements where mask is true. + pub fn getitem_bool_mask(&self, mask: &Self, heap: &Heap) -> RunResult { + if mask.len() != self.len() { + return Err( + SimpleException::new_msg(ExcType::IndexError, "boolean index did not match indexed array").into(), + ); + } + let filtered: Vec = self + .data + .iter() + .zip(mask.data.iter()) + .filter(|(_, m)| **m != 0.0) + .map(|(v, _)| *v) + .collect(); + let len = filtered.len(); + let result = Self::new(filtered, vec![len], self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) + } + + /// Indexes by an integer ndarray (fancy indexing), gathering elements at the specified indices. + pub fn getitem_int_array(&self, idx_arr: &Self, heap: &Heap) -> RunResult { + if idx_arr.dtype() != NdArrayDtype::Int64 { + return Err(SimpleException::new_msg( + ExcType::IndexError, + "arrays used as indices must be of integer (or boolean) type", + ) + .into()); + } + if self.shape.is_empty() { + return Err(SimpleException::new_msg( + ExcType::IndexError, + "too many indices for array: array is 0-dimensional, but 1 were indexed", + ) + .into()); + } + + if self.ndim() == 1 && idx_arr.shape().is_empty() { + #[expect(clippy::cast_possible_truncation, reason = "integer ndarray stores indices as f64")] + let idx = idx_arr.data()[0] as i64; + let resolved = resolve_index(idx, self.shape[0])?; + return Ok(self.element_to_value(self.data[resolved])); + } + + let row_size = self.shape[1..].iter().product::(); + let mut result_shape = idx_arr.shape().to_vec(); + result_shape.extend_from_slice(&self.shape[1..]); + let result_len = checked_shape_product(&result_shape)?; + check_array_alloc_size(result_len, heap.tracker())?; + let mut data = Vec::with_capacity(result_len); + for &idx_f in idx_arr.data() { + #[expect(clippy::cast_possible_truncation, reason = "index from f64")] + let idx = idx_f as i64; + let resolved = resolve_index(idx, self.shape[0])?; + let start = resolved * row_size; + let end = start + row_size; + data.extend_from_slice(&self.data[start..end]); + } + let result = Self::new(data, result_shape, self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) + } + + /// Indexes by a Python slice object (e.g. `arr[1:3]`, `arr[::2]`, `arr[::-1]`). + pub fn getitem_slice(&self, slice: &Slice, heap: &Heap) -> RunResult { + let len = self.data.len(); + let (start, stop, step) = slice.indices(len)?; + + let mut data = Vec::new(); + if step > 0 { + let mut i = start; + while i < stop { + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "positive-step slice indices are clamped to the array bounds" + )] + { + data.push(self.data[i as usize]); + } + i += step; + } + } else { + let mut i = start; + while i > stop { + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "negative-step slice indices visited here are in bounds" + )] + { + data.push(self.data[i as usize]); + } + i += step; + } + } + + let result_len = data.len(); + let result = Self::new(data, vec![result_len], self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(result))?)) + } + + /// Converts a single f64 element to the appropriate Value based on dtype. + fn element_to_value(&self, val: f64) -> Value { + match self.dtype { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is the intended int conversion" + )] + NdArrayDtype::Int64 => Value::Int(val as i64), + NdArrayDtype::Float64 => Value::Float(val), + NdArrayDtype::Bool => Value::Bool(val != 0.0), + } + } +} + +// =========================== +// Element-wise binary and comparison operations +// =========================== + +impl NdArray { + /// Element-wise binary operation between two broadcast-compatible arrays. + fn elementwise_op( + &self, + other: &Self, + op: fn(f64, f64) -> f64, + heap: &Heap, + ) -> RunResult { + let result_dtype = promote_dtype(self.dtype, other.dtype); + let (left, right, shape) = broadcast_pair_data( + &self.data, + &self.shape, + &other.data, + &other.shape, + "ndarray binary operation", + heap.tracker(), + )?; + let data: Vec = left.iter().zip(right.iter()).map(|(&a, &b)| op(a, b)).collect(); + let arr = Self::new(data, shape, result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Element-wise operation with a scalar on the right. + fn scalar_op_right( + &self, + scalar: f64, + op: fn(f64, f64) -> f64, + result_dtype: NdArrayDtype, + heap: &Heap, + ) -> RunResult { + check_array_alloc_size(self.len(), heap.tracker())?; + let data: Vec = self.data.iter().map(|&a| op(a, scalar)).collect(); + let arr = Self::new(data, self.shape.clone(), result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Element-wise operation with a scalar on the left (scalar op array). + /// + /// Used for non-commutative operations like `5 - arr` or `10 / arr`. + fn scalar_op_left( + &self, + scalar: f64, + op: fn(f64, f64) -> f64, + result_dtype: NdArrayDtype, + heap: &Heap, + ) -> RunResult { + check_array_alloc_size(self.len(), heap.tracker())?; + let data: Vec = self.data.iter().map(|&a| op(scalar, a)).collect(); + let arr = Self::new(data, self.shape.clone(), result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Element-wise comparison between two broadcast-compatible arrays. + fn elementwise_cmp( + &self, + other: &Self, + cmp: fn(f64, f64) -> bool, + heap: &Heap, + ) -> RunResult { + let (left, right, shape) = broadcast_pair_data( + &self.data, + &self.shape, + &other.data, + &other.shape, + "ndarray comparison", + heap.tracker(), + )?; + let data: Vec = left + .iter() + .zip(right.iter()) + .map(|(&a, &b)| if cmp(a, b) { 1.0 } else { 0.0 }) + .collect(); + let arr = Self::new(data, shape, NdArrayDtype::Bool); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Scalar comparison, producing a boolean array. + fn scalar_cmp( + &self, + scalar: f64, + cmp: fn(f64, f64) -> bool, + heap: &Heap, + ) -> RunResult { + check_array_alloc_size(self.len(), heap.tracker())?; + let data: Vec = self + .data + .iter() + .map(|&a| if cmp(a, scalar) { 1.0 } else { 0.0 }) + .collect(); + let arr = Self::new(data, self.shape.clone(), NdArrayDtype::Bool); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Element-wise addition with another array. + pub fn add(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_op(other, |a, b| a + b, heap) + } + + /// Addition with a scalar value. + pub fn add_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_right(scalar, |a, b| a + b, dtype, heap) + } + + /// Element-wise subtraction. + pub fn sub(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_op(other, |a, b| a - b, heap) + } + + /// Subtraction with a scalar value. + pub fn sub_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_right(scalar, |a, b| a - b, dtype, heap) + } + + /// Element-wise multiplication. + pub fn mul(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_op(other, |a, b| a * b, heap) + } + + /// Multiplication with a scalar value. + pub fn mul_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_right(scalar, |a, b| a * b, dtype, heap) + } + + /// Element-wise true division (always returns float). + pub fn div(&self, other: &Self, heap: &Heap) -> RunResult { + let (left, right, shape) = broadcast_pair_data( + &self.data, + &self.shape, + &other.data, + &other.shape, + "ndarray true division", + heap.tracker(), + )?; + let data: Vec = left.iter().zip(right.iter()).map(|(&a, &b)| a / b).collect(); + let arr = Self::new(data, shape, NdArrayDtype::Float64); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Division with a scalar value (always returns float). + pub fn div_scalar(&self, scalar: f64, heap: &Heap) -> RunResult { + self.scalar_op_right(scalar, |a, b| a / b, NdArrayDtype::Float64, heap) + } + + /// Element-wise floor division. + pub fn floordiv(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_op(other, |a, b| (a / b).floor(), heap) + } + + /// Floor division with a scalar value. + pub fn floordiv_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_right(scalar, |a, b| (a / b).floor(), dtype, heap) + } + + /// Element-wise modulo. + pub fn modulo(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_op(other, py_mod, heap) + } + + /// Modulo with a scalar value. + pub fn modulo_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_right(scalar, py_mod, dtype, heap) + } + + /// Element-wise power. + pub fn pow(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_op(other, f64::powf, heap) + } + + /// Power with a scalar exponent. + pub fn pow_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_right(scalar, f64::powf, dtype, heap) + } + + /// Reverse subtraction with scalar: `scalar - array`. + pub fn rsub_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_left(scalar, |a, b| a - b, dtype, heap) + } + + /// Reverse division with scalar: `scalar / array`. + pub fn rdiv_scalar(&self, scalar: f64, heap: &Heap) -> RunResult { + self.scalar_op_left(scalar, |a, b| a / b, NdArrayDtype::Float64, heap) + } + + /// Reverse floor division with scalar: `scalar // array`. + pub fn rfloordiv_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_left(scalar, |a, b| (a / b).floor(), dtype, heap) + } + + /// Reverse modulo with scalar: `scalar % array`. + pub fn rmod_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_left(scalar, py_mod, dtype, heap) + } + + /// Reverse power with scalar: `scalar ** array`. + pub fn rpow_scalar( + &self, + scalar: f64, + scalar_is_float: bool, + heap: &Heap, + ) -> RunResult { + let dtype = promote_dtype_with_scalar(self.dtype, scalar_is_float); + self.scalar_op_left(scalar, f64::powf, dtype, heap) + } + + /// Element-wise bitwise AND between two arrays. + /// + /// - **Bool arrays**: element-wise logical AND. + /// - **Int arrays**: bitwise AND on each pair of elements cast to `i64`. + /// - **Float arrays**: raises `TypeError`, matching NumPy's behavior. + #[expect( + clippy::cast_possible_truncation, + reason = "f64→i64 truncation is intentional for int-typed ndarray elements" + )] + pub fn bitand(&self, other: &Self, heap: &Heap) -> RunResult { + check_bitwise_dtype(self.dtype, "&")?; + check_bitwise_dtype(other.dtype, "&")?; + let result_dtype = if self.dtype == NdArrayDtype::Bool && other.dtype == NdArrayDtype::Bool { + NdArrayDtype::Bool + } else { + NdArrayDtype::Int64 + }; + let (left, right, shape) = broadcast_pair_data( + &self.data, + &self.shape, + &other.data, + &other.shape, + "ndarray bitwise and", + heap.tracker(), + )?; + let data: Vec = left + .iter() + .zip(right.iter()) + .map(|(&a, &b)| (a as i64 & b as i64) as f64) + .collect(); + let arr = Self::new(data, shape, result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Bitwise AND with a scalar value. + #[expect( + clippy::cast_possible_truncation, + reason = "f64→i64 truncation is intentional for int-typed ndarray elements" + )] + pub fn bitand_scalar(&self, scalar: i64, heap: &Heap) -> RunResult { + check_bitwise_dtype(self.dtype, "&")?; + let result_dtype = if self.dtype == NdArrayDtype::Bool { + NdArrayDtype::Bool + } else { + NdArrayDtype::Int64 + }; + check_array_alloc_size(self.len(), heap.tracker())?; + let data: Vec = self.data.iter().map(|&a| (a as i64 & scalar) as f64).collect(); + let arr = Self::new(data, self.shape.clone(), result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Element-wise bitwise OR between two arrays. + /// + /// - **Bool arrays**: element-wise logical OR. + /// - **Int arrays**: bitwise OR on each pair of elements cast to `i64`. + /// - **Float arrays**: raises `TypeError`, matching NumPy's behavior. + #[expect( + clippy::cast_possible_truncation, + reason = "f64→i64 truncation is intentional for int-typed ndarray elements" + )] + pub fn bitor(&self, other: &Self, heap: &Heap) -> RunResult { + check_bitwise_dtype(self.dtype, "|")?; + check_bitwise_dtype(other.dtype, "|")?; + let result_dtype = if self.dtype == NdArrayDtype::Bool && other.dtype == NdArrayDtype::Bool { + NdArrayDtype::Bool + } else { + NdArrayDtype::Int64 + }; + let (left, right, shape) = broadcast_pair_data( + &self.data, + &self.shape, + &other.data, + &other.shape, + "ndarray bitwise or", + heap.tracker(), + )?; + let data: Vec = left + .iter() + .zip(right.iter()) + .map(|(&a, &b)| (a as i64 | b as i64) as f64) + .collect(); + let arr = Self::new(data, shape, result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Bitwise OR with a scalar value. + #[expect( + clippy::cast_possible_truncation, + reason = "f64→i64 truncation is intentional for int-typed ndarray elements" + )] + pub fn bitor_scalar(&self, scalar: i64, heap: &Heap) -> RunResult { + check_bitwise_dtype(self.dtype, "|")?; + let result_dtype = if self.dtype == NdArrayDtype::Bool { + NdArrayDtype::Bool + } else { + NdArrayDtype::Int64 + }; + check_array_alloc_size(self.len(), heap.tracker())?; + let data: Vec = self.data.iter().map(|&a| (a as i64 | scalar) as f64).collect(); + let arr = Self::new(data, self.shape.clone(), result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Element-wise bitwise XOR between two arrays. + /// + /// - **Bool arrays**: element-wise logical XOR. + /// - **Int arrays**: bitwise XOR on each pair of elements cast to `i64`. + /// - **Float arrays**: raises `TypeError`, matching NumPy's behavior. + #[expect( + clippy::cast_possible_truncation, + reason = "f64→i64 truncation is intentional for int-typed ndarray elements" + )] + pub fn bitxor(&self, other: &Self, heap: &Heap) -> RunResult { + check_bitwise_dtype(self.dtype, "^")?; + check_bitwise_dtype(other.dtype, "^")?; + let result_dtype = if self.dtype == NdArrayDtype::Bool && other.dtype == NdArrayDtype::Bool { + NdArrayDtype::Bool + } else { + NdArrayDtype::Int64 + }; + let (left, right, shape) = broadcast_pair_data( + &self.data, + &self.shape, + &other.data, + &other.shape, + "ndarray bitwise xor", + heap.tracker(), + )?; + let data: Vec = left + .iter() + .zip(right.iter()) + .map(|(&a, &b)| (a as i64 ^ b as i64) as f64) + .collect(); + let arr = Self::new(data, shape, result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Bitwise XOR with a scalar value. + #[expect( + clippy::cast_possible_truncation, + reason = "f64→i64 truncation is intentional for int-typed ndarray elements" + )] + pub fn bitxor_scalar(&self, scalar: i64, heap: &Heap) -> RunResult { + check_bitwise_dtype(self.dtype, "^")?; + let result_dtype = if self.dtype == NdArrayDtype::Bool { + NdArrayDtype::Bool + } else { + NdArrayDtype::Int64 + }; + check_array_alloc_size(self.len(), heap.tracker())?; + let data: Vec = self.data.iter().map(|&a| (a as i64 ^ scalar) as f64).collect(); + let arr = Self::new(data, self.shape.clone(), result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Matrix multiplication (the `@` operator). + /// + /// - **1D @ 1D**: dot product, returns a scalar. + /// - **2D @ 2D**: standard matrix multiplication, returns a 2D array. + /// - **2D @ 1D**: matrix-vector product, returns a 1D array. + /// - **1D @ 2D**: vector-matrix product, returns a 1D array. + pub fn matmul(&self, other: &Self, heap: &Heap) -> RunResult { + let result_dtype = promote_dtype(self.dtype, other.dtype); + match (self.ndim(), other.ndim()) { + (1, 1) => { + // Dot product + if self.data.len() != other.data.len() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "matmul: Input operand 1 does not have enough dimensions", + ) + .into()); + } + let result: f64 = self.data.iter().zip(other.data.iter()).map(|(&a, &b)| a * b).sum(); + if result_dtype == NdArrayDtype::Int64 { + #[expect(clippy::cast_possible_truncation, reason = "intended int truncation")] + return Ok(Value::Int(result as i64)); + } + Ok(Value::Float(result)) + } + (2, 2) => { + // Matrix multiplication + let (m, k1) = (self.shape[0], self.shape[1]); + let (k2, n) = (other.shape[0], other.shape[1]); + if k1 != k2 { + return Err(SimpleException::new_msg( + ExcType::ValueError, + format!("matmul: Input operand 1 has a mismatch in its core dimension 0, (size {k1} is different from {k2})"), + ) + .into()); + } + let mut data = Vec::with_capacity(m * n); + for i in 0..m { + for j in 0..n { + let mut sum = 0.0; + for p in 0..k1 { + sum += self.data[i * k1 + p] * other.data[p * n + j]; + } + data.push(sum); + } + } + let arr = Self::new(data, vec![m, n], result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + (2, 1) => { + // Matrix-vector product + let (m, k1) = (self.shape[0], self.shape[1]); + if k1 != other.data.len() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "matmul: Input operand 1 has a mismatch in its core dimension 0", + ) + .into()); + } + let mut data = Vec::with_capacity(m); + for i in 0..m { + let mut sum = 0.0; + for p in 0..k1 { + sum += self.data[i * k1 + p] * other.data[p]; + } + data.push(sum); + } + let arr = Self::new(data, vec![m], result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + (1, 2) => { + // Vector-matrix product + let (k2, n) = (other.shape[0], other.shape[1]); + if self.data.len() != k2 { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "matmul: Input operand 1 has a mismatch in its core dimension 0", + ) + .into()); + } + let mut data = Vec::with_capacity(n); + for j in 0..n { + let mut sum = 0.0; + for p in 0..k2 { + sum += self.data[p] * other.data[p * n + j]; + } + data.push(sum); + } + let arr = Self::new(data, vec![n], result_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + _ => Err(ExcType::type_error("matmul not supported for these array dimensions")), + } + } + + /// Unary negation. + pub fn neg(&self, heap: &Heap) -> RunResult { + let data: Vec = self.data.iter().map(|&a| -a).collect(); + let arr = Self::new(data, self.shape.clone(), self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Unary bitwise invert (`~`). + /// + /// - **Bool arrays**: flips True↔False (returns Bool dtype). + /// - **Int arrays**: bitwise NOT on each element cast to `i64` (returns Int64 dtype). + /// - **Float arrays**: raises `TypeError`, matching NumPy's behavior. + #[expect( + clippy::cast_possible_truncation, + reason = "f64→i64 truncation is intentional for int-typed ndarray elements" + )] + pub fn invert(&self, heap: &Heap) -> RunResult { + match self.dtype { + NdArrayDtype::Bool => { + let data: Vec = self.data.iter().map(|&a| if a == 0.0 { 1.0 } else { 0.0 }).collect(); + let arr = Self::new(data, self.shape.clone(), NdArrayDtype::Bool); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + NdArrayDtype::Int64 => { + let data: Vec = self.data.iter().map(|&a| !(a as i64) as f64).collect(); + let arr = Self::new(data, self.shape.clone(), NdArrayDtype::Int64); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + NdArrayDtype::Float64 => Err(SimpleException::new_msg( + ExcType::TypeError, + "ufunc 'invert' not supported for the input types", + ) + .into()), + } + } + + /// Element-wise greater-than comparison. + pub fn gt(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_cmp(other, |a, b| a > b, heap) + } + + /// Greater-than comparison with scalar. + pub fn gt_scalar(&self, scalar: f64, heap: &Heap) -> RunResult { + self.scalar_cmp(scalar, |a, b| a > b, heap) + } + + /// Element-wise less-than comparison. + pub fn lt(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_cmp(other, |a, b| a < b, heap) + } + + /// Less-than comparison with scalar. + pub fn lt_scalar(&self, scalar: f64, heap: &Heap) -> RunResult { + self.scalar_cmp(scalar, |a, b| a < b, heap) + } + + /// Element-wise equality comparison. + pub fn eq_array(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_cmp(other, |a, b| a == b, heap) + } + + /// Equality comparison with scalar. + pub fn eq_scalar(&self, scalar: f64, heap: &Heap) -> RunResult { + self.scalar_cmp(scalar, |a, b| a == b, heap) + } + + /// Element-wise greater-than-or-equal comparison. + pub fn gte(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_cmp(other, |a, b| a >= b, heap) + } + + /// Greater-than-or-equal comparison with scalar. + pub fn gte_scalar(&self, scalar: f64, heap: &Heap) -> RunResult { + self.scalar_cmp(scalar, |a, b| a >= b, heap) + } + + /// Element-wise less-than-or-equal comparison. + pub fn lte(&self, other: &Self, heap: &Heap) -> RunResult { + self.elementwise_cmp(other, |a, b| a <= b, heap) + } + + /// Less-than-or-equal comparison with scalar. + pub fn lte_scalar(&self, scalar: f64, heap: &Heap) -> RunResult { + self.scalar_cmp(scalar, |a, b| a <= b, heap) + } + + /// Element-wise not-equal comparison. + pub fn ne_array(&self, other: &Self, heap: &Heap) -> RunResult { + #[expect(clippy::float_cmp, reason = "exact equality is correct for numpy != semantics")] + self.elementwise_cmp(other, |a, b| a != b, heap) + } + + /// Not-equal comparison with scalar. + pub fn ne_scalar(&self, scalar: f64, heap: &Heap) -> RunResult { + #[expect(clippy::float_cmp, reason = "exact equality is correct for numpy != semantics")] + self.scalar_cmp(scalar, |a, b| a != b, heap) + } +} + +// =========================== +// Aggregation methods +// =========================== + +impl NdArray { + /// Returns the sum of all elements. + pub fn sum(&self) -> f64 { + self.data.iter().sum() + } + + /// Returns the arithmetic mean of all elements. + pub fn mean(&self) -> f64 { + self.sum() / self.len() as f64 + } + + /// Returns the minimum element. + /// + /// If any element is NaN, returns NaN — matching NumPy's propagation semantics. + /// Uses a NaN-propagating reduction: once a NaN is seen, the result is NaN. + pub fn min_val(&self) -> RunResult { + self.data + .iter() + .copied() + .reduce(|acc, v| { + if acc.is_nan() || v.is_nan() { + f64::NAN + } else { + acc.min(v) + } + }) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "zero-size array has no minimum").into()) + } + + /// Returns the maximum element. + /// + /// If any element is NaN, returns NaN — matching NumPy's propagation semantics. + pub fn max_val(&self) -> RunResult { + self.data + .iter() + .copied() + .reduce(|acc, v| { + if acc.is_nan() || v.is_nan() { + f64::NAN + } else { + acc.max(v) + } + }) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "zero-size array has no maximum").into()) + } + + /// Returns the population standard deviation. + pub fn std_dev(&self) -> f64 { + let mean = self.mean(); + let variance = self.data.iter().map(|&x| (x - mean).powi(2)).sum::() / self.len() as f64; + variance.sqrt() + } + + /// Returns the index of the minimum element. + pub fn argmin(&self) -> RunResult { + self.data + .iter() + .enumerate() + .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal)) + .map(|(i, _)| i) + .ok_or_else(|| { + SimpleException::new_msg(ExcType::ValueError, "attempt to get argmin of an empty sequence").into() + }) + } + + /// Returns the index of the maximum element. + pub fn argmax(&self) -> RunResult { + self.data + .iter() + .enumerate() + .reduce(|(i_max, v_max), (i, v)| { + if v.partial_cmp(v_max).unwrap_or(Ordering::Equal) == Ordering::Greater { + (i, v) + } else { + (i_max, v_max) + } + }) + .map(|(i, _)| i) + .ok_or_else(|| { + SimpleException::new_msg(ExcType::ValueError, "attempt to get argmax of an empty sequence").into() + }) + } + + /// Returns true if all elements are truthy (non-zero). + pub fn all(&self) -> bool { + self.data.iter().all(|&x| x != 0.0) + } + + /// Returns true if any element is truthy (non-zero). + pub fn any(&self) -> bool { + self.data.iter().any(|&x| x != 0.0) + } + + /// Returns the product of all elements. + pub fn prod(&self) -> f64 { + self.data.iter().copied().fold(1.0, |acc, v| acc * v) + } + + /// Returns the population variance (ddof=0). + pub fn var(&self) -> f64 { + let mean = self.mean(); + self.data.iter().map(|&x| (x - mean).powi(2)).sum::() / self.len() as f64 + } +} + +// =========================== +// Shape manipulation and transform methods +// =========================== + +impl NdArray { + /// Reshapes the array to a new shape, returning a new NdArray. + /// + /// The total number of elements must remain the same. + pub fn reshape(&self, new_shape: Vec, heap: &Heap) -> RunResult { + let new_size: usize = new_shape + .iter() + .try_fold(1usize, |acc, &d| acc.checked_mul(d)) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "reshape dimensions overflow"))?; + if new_size != self.len() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + format!( + "cannot reshape array of size {} into shape ({})", + self.len(), + new_shape.iter().map(ToString::to_string).collect::>().join(", ") + ), + ) + .into()); + } + let arr = Self::new(self.data.clone(), new_shape, self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Flattens the array to 1D, returning a new NdArray. + pub fn flatten(&self, heap: &Heap) -> RunResult { + let arr = Self::new(self.data.clone(), vec![self.len()], self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Returns the transpose of a 2D array. For 1D arrays, returns a copy. + pub fn transpose(&self, heap: &Heap) -> RunResult { + if self.ndim() <= 1 { + // 1D arrays are returned as-is (copy) + let arr = Self::new(self.data.clone(), self.shape.clone(), self.dtype); + return Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)); + } + if self.ndim() == 2 { + let rows = self.shape[0]; + let cols = self.shape[1]; + let mut data = vec![0.0; self.data.len()]; + for r in 0..rows { + for c in 0..cols { + data[c * rows + r] = self.data[r * cols + c]; + } + } + let arr = Self::new(data, vec![cols, rows], self.dtype); + return Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)); + } + Err(ExcType::type_error("transpose not supported for arrays with ndim > 2")) + } + + /// `cumsum()` — returns a 1D array of cumulative sums. + pub fn cumsum(&self, heap: &Heap) -> RunResult { + let mut sum = 0.0; + let data: Vec = self + .data + .iter() + .map(|&v| { + sum += v; + sum + }) + .collect(); + let arr = Self::new(data, vec![self.len()], self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `cumprod()` — returns a 1D array of cumulative products. + /// + /// Each element is the product of all preceding elements (inclusive). + /// The result is always 1D, matching NumPy's behavior when no axis is specified. + pub fn cumprod(&self, heap: &Heap) -> RunResult { + let mut acc = 1.0; + let data: Vec = self + .data + .iter() + .map(|&v| { + acc *= v; + acc + }) + .collect(); + let arr = Self::new(data, vec![self.len()], self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `item()` — return the single element of a size-1 array as a Python scalar. + /// + /// Raises `ValueError` if the array has more than one element, matching NumPy. + pub fn item(&self) -> RunResult { + if self.data.len() != 1 { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "can only convert an array of size 1 to a Python scalar", + ) + .into()); + } + Ok(self.element_to_value(self.data[0])) + } + + /// `squeeze()` — remove axes of length 1, returning a new array. + /// + /// If all axes have length 1, the result is a 1-element array with shape `(1,)`. + pub fn squeeze(&self, heap: &Heap) -> RunResult { + let shape: Vec = self.shape.iter().copied().filter(|&s| s != 1).collect(); + let shape = if shape.is_empty() { vec![1] } else { shape }; + let arr = Self::new(self.data.clone(), shape, self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `take(indices)` — gather elements at the given integer indices. + /// + /// The indices array is flattened and used to index into the flattened source array. + /// Negative indices are supported. + #[expect( + clippy::cast_possible_truncation, + reason = "index from f64 is intentional for int-typed ndarray elements" + )] + pub fn take_indices(&self, indices: &Self, heap: &Heap) -> RunResult { + let mut data = Vec::with_capacity(indices.len()); + for &idx_f in indices.data() { + let idx = idx_f as i64; + let resolved = resolve_index(idx, self.data.len())?; + data.push(self.data[resolved]); + } + let len = data.len(); + let arr = Self::new(data, vec![len], self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `diagonal()` — return the diagonal of a 2D array. + /// + /// Raises `ValueError` for arrays with fewer than 2 dimensions. + pub fn diagonal(&self, heap: &Heap) -> RunResult { + if self.shape.len() < 2 { + return Err(SimpleException::new_msg(ExcType::ValueError, "diagonal requires 2-d array").into()); + } + let rows = self.shape[0]; + let cols = self.shape[1]; + let n = rows.min(cols); + let data: Vec = (0..n).map(|i| self.data[i * cols + i]).collect(); + let arr = Self::new(data, vec![n], self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `trace()` — return the sum of the diagonal elements of a 2D array. + /// + /// Returns an int for int arrays and a float for float arrays, matching NumPy. + /// Raises `ValueError` for arrays with fewer than 2 dimensions. + pub fn trace(&self) -> RunResult { + if self.shape.len() < 2 { + return Err(SimpleException::new_msg(ExcType::ValueError, "trace requires 2-d array").into()); + } + let cols = self.shape[1]; + let n = self.shape[0].min(cols); + let sum: f64 = (0..n).map(|i| self.data[i * cols + i]).sum(); + Ok(self.element_to_value(sum)) + } + + /// `fill(value)` — fill the array in-place with the given scalar value. + pub fn fill(&mut self, value: f64) { + self.data.fill(value); + } + + /// `compress(condition)` — return elements where the boolean condition array is true. + /// + /// Operates on the flattened array. The condition array's truthy elements select + /// corresponding elements from the source. + pub fn compress(&self, condition: &Self, heap: &Heap) -> RunResult { + let data: Vec = self + .data + .iter() + .zip(condition.data.iter()) + .filter(|pair| *pair.1 != 0.0) + .map(|pair| *pair.0) + .collect(); + let len = data.len(); + let arr = Self::new(data, vec![len], self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `swapaxes(a, b)` — swap two axes of the array. + /// + /// For 2D arrays with axes 0 and 1, this is equivalent to a transpose. + /// For 1D arrays (or swapping an axis with itself), returns a copy. + pub fn swapaxes(&self, axis_a: usize, axis_b: usize, heap: &Heap) -> RunResult { + if axis_a >= self.ndim() || axis_b >= self.ndim() { + return Err(SimpleException::new_msg( + ExcType::ValueError, + format!("bad axis for array with {} dimensions", self.ndim()), + ) + .into()); + } + if axis_a == axis_b || self.ndim() <= 1 { + // No-op: return a copy + let arr = Self::new(self.data.clone(), self.shape.clone(), self.dtype); + return Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)); + } + // For 2D with axes (0, 1), this is a transpose + if self.ndim() == 2 { + self.transpose(heap) + } else { + Err(ExcType::type_error("swapaxes not supported for arrays with ndim > 2")) + } + } + + /// `repeat(n)` — repeat each element `n` times, returning a 1D array. + pub fn repeat_array(&self, n: usize, heap: &Heap) -> RunResult { + let len = self + .data + .len() + .checked_mul(n) + .ok_or_else(|| SimpleException::new_msg(ExcType::ValueError, "array dimensions overflow"))?; + check_array_alloc_size(len, heap.tracker())?; + let mut data = Vec::with_capacity(len); + for &v in &self.data { + for _ in 0..n { + data.push(v); + } + } + let arr = Self::new(data, vec![len], self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `nonzero()` — returns a tuple of arrays, one per dimension, containing non-zero indices. + /// + /// For 1D arrays, returns a 1-element tuple of an index array. + /// For 2D arrays, returns a 2-element tuple of row and column index arrays. + pub fn nonzero_method(&self, heap: &Heap) -> RunResult { + if self.ndim() <= 1 { + let indices: Vec = self + .data + .iter() + .enumerate() + .filter(|(_, v)| **v != 0.0) + .map(|(i, _)| i as f64) + .collect(); + let len = indices.len(); + let arr = Self::new(indices, vec![len], NdArrayDtype::Int64); + let arr_val = Value::Ref(heap.allocate(HeapData::NdArray(arr))?); + let tup = allocate_tuple(smallvec![arr_val], heap)?; + Ok(tup) + } else if self.ndim() == 2 { + let rows = self.shape[0]; + let cols = self.shape[1]; + let mut row_indices = Vec::new(); + let mut col_indices = Vec::new(); + for r in 0..rows { + for c in 0..cols { + if self.data[r * cols + c] != 0.0 { + row_indices.push(r as f64); + col_indices.push(c as f64); + } + } + } + let row_len = row_indices.len(); + let col_len = col_indices.len(); + let row_arr = Self::new(row_indices, vec![row_len], NdArrayDtype::Int64); + let col_arr = Self::new(col_indices, vec![col_len], NdArrayDtype::Int64); + let row_val = Value::Ref(heap.allocate(HeapData::NdArray(row_arr))?); + let col_val = Value::Ref(heap.allocate(HeapData::NdArray(col_arr))?); + let tup = allocate_tuple(smallvec![row_val, col_val], heap)?; + Ok(tup) + } else { + Err(ExcType::type_error("nonzero() not supported for arrays with ndim > 2")) + } + } + + /// `round(decimals)` — returns a new array with each element rounded. + pub fn round_array(&self, decimals: i32, heap: &Heap) -> RunResult { + let factor = 10f64.powi(decimals); + let data: Vec = self.data.iter().map(|&v| (v * factor).round() / factor).collect(); + let arr = Self::new(data, self.shape.clone(), NdArrayDtype::Float64); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `clip(min, max)` — returns a new array with each element clamped to `[min, max]`. + pub fn clip_array(&self, min: f64, max: f64, heap: &Heap) -> RunResult { + let data: Vec = self.data.iter().map(|&v| v.clamp(min, max)).collect(); + let arr = Self::new(data, self.shape.clone(), self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `sort()` — sort the array data in-place (mutates `self`). + /// + /// NaN values sort to the end, matching NumPy's `ndarray.sort()` behavior. + pub fn sort_in_place(&mut self) { + self.data.sort_by(nan_last_cmp); + } + + /// `argsort()` — returns indices that would sort the array. + /// + /// NaN values sort to the end, matching NumPy's behavior. + pub fn argsort(&self, heap: &Heap) -> RunResult { + let mut indices: Vec = (0..self.data.len()).collect(); + indices.sort_by(|&a, &b| nan_last_cmp(&self.data[a], &self.data[b])); + let data: Vec = indices.iter().map(|&i| i as f64).collect(); + let arr = Self::new(data, vec![self.data.len()], NdArrayDtype::Int64); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `astype(dtype_str)` — cast array to a new dtype. + /// + /// Accepts NumPy dtype aliases that Monty maps onto its compact dtype set. + pub fn astype(&self, dtype_str: &str, heap: &Heap) -> RunResult { + let Some(new_dtype) = ndarray_dtype_from_numpy_name(dtype_str) else { + return Err(SimpleException::new_msg(ExcType::TypeError, format!("unsupported dtype: {dtype_str}")).into()); + }; + let data = match new_dtype { + NdArrayDtype::Bool => self.data.iter().map(|&v| if v == 0.0 { 0.0 } else { 1.0 }).collect(), + NdArrayDtype::Int64 => + { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is the intended int conversion" + )] + self.data.iter().map(|&v| (v as i64) as f64).collect() + } + NdArrayDtype::Float64 => self.data.clone(), + }; + let arr = Self::new(data, self.shape.clone(), new_dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// `dot(other)` — dot product of two 1D arrays, returning a scalar. + pub fn dot(&self, other: &Self) -> RunResult { + if self.data.len() != other.data.len() { + return Err(SimpleException::new_msg(ExcType::ValueError, "shapes are not aligned for dot product").into()); + } + let result: f64 = self.data.iter().zip(other.data.iter()).map(|(&a, &b)| a * b).sum(); + let result_dtype = promote_dtype(self.dtype, other.dtype); + if result_dtype == NdArrayDtype::Int64 { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is intended for int dot product" + )] + return Ok(Value::Int(result as i64)); + } + Ok(Value::Float(result)) + } + + /// Returns a copy of this ndarray. + pub fn copy_array(&self, heap: &Heap) -> RunResult { + let arr = Self::new(self.data.clone(), self.shape.clone(), self.dtype); + Ok(Value::Ref(heap.allocate(HeapData::NdArray(arr))?)) + } + + /// Converts the array to a (possibly nested) Python list. + /// + /// For zero-dimensional arrays, returns the scalar value. For 1D arrays, + /// returns a flat list. For 2D+ arrays, returns nested lists matching the + /// shape, e.g. shape `(2, 3)` → `[[1, 2, 3], [4, 5, 6]]`. + pub fn tolist(&self, heap: &Heap) -> RunResult { + self.tolist_recursive(&self.shape, 0, heap) + } + + /// Recursively builds nested lists for `tolist()`. + fn tolist_recursive( + &self, + remaining_shape: &[usize], + offset: usize, + heap: &Heap, + ) -> RunResult { + if remaining_shape.is_empty() { + Ok(self.element_to_value(self.data[offset])) + } else if remaining_shape.len() == 1 { + // Leaf dimension: flat list of scalars + let len = remaining_shape[0]; + let values: Vec = (0..len).map(|i| self.element_to_value(self.data[offset + i])).collect(); + let list = List::new(values); + Ok(Value::Ref(heap.allocate(HeapData::List(list))?)) + } else { + // Build nested list: each element is a sub-list + let sub_size: usize = remaining_shape[1..].iter().product(); + let mut values = Vec::with_capacity(remaining_shape[0]); + for i in 0..remaining_shape[0] { + values.push(self.tolist_recursive(&remaining_shape[1..], offset + i * sub_size, heap)?); + } + let list = List::new(values); + Ok(Value::Ref(heap.allocate(HeapData::List(list))?)) + } + } +} + +/// Maps NumPy dtype names and compact aliases onto Monty's storage dtypes. +/// +/// Monty stores the supported numeric subset as bool, int64, or float64. Narrow +/// integer and float aliases are accepted for compatibility but intentionally +/// collapse to the closest storage dtype rather than introducing unplanned +/// memory-layout semantics. +fn ndarray_dtype_from_numpy_name(dtype_str: &str) -> Option { + match dtype_str { + "bool" | "bool_" | "?" => Some(NdArrayDtype::Bool), + "int8" | "int16" | "int32" | "int64" | "int_" | "intc" | "intp" | "long" | "longlong" | "byte" | "short" + | "uint8" | "uint16" | "uint32" | "uint64" | "uint" | "uintc" | "uintp" | "ubyte" | "ushort" | "ulong" + | "ulonglong" | "int" | "i" | "l" | "q" | "b" | "h" | "B" | "H" | "I" | "L" | "Q" => Some(NdArrayDtype::Int64), + "float16" | "float32" | "float64" | "half" | "single" | "double" | "longdouble" | "float" | "f" | "e" | "d" + | "g" => Some(NdArrayDtype::Float64), + _ => None, + } +} + +// =========================== +// Attribute accessors +// =========================== + +impl NdArray { + /// Returns the shape as a Python tuple of ints. + pub fn shape_tuple(&self, heap: &Heap) -> Result { + #[expect(clippy::cast_possible_wrap, reason = "shape dimensions won't exceed i64::MAX")] + let values: SmallVec<[Value; 3]> = self.shape.iter().map(|&d| Value::Int(d as i64)).collect(); + allocate_tuple(values, heap) + } + + /// Returns the dtype as a string Value. + pub fn dtype_str(&self, heap: &Heap) -> Result { + let s = Str::new(self.dtype.to_string()); + Ok(Value::Ref(heap.allocate(HeapData::Str(s))?)) + } +} + +// =========================== +// Repr formatting +// =========================== + +impl NdArray { + /// Writes the repr format to the given formatter. + /// + /// Produces output like `array([1, 2, 3])` for int arrays + /// or `array([1., 2., 3.])` for float arrays. + pub fn py_repr_fmt_inner(&self, f: &mut impl Write) -> fmt::Result { + f.write_str("array(")?; + self.write_recursive(f, &self.shape, 0)?; + // NumPy includes dtype suffix for empty arrays since element format can't convey it + if self.data.is_empty() { + write!(f, ", dtype={}", self.dtype)?; + } + f.write_char(')') + } + + /// Writes NumPy's bare array string format without the `array(...)` wrapper. + /// + /// This is used by module-level display helpers such as `numpy.array2string` + /// and `numpy.array_str`. It intentionally mirrors Monty's existing compact + /// ndarray subset and avoids introducing global print-option state. + pub fn array_str_fmt_inner(&self, f: &mut impl Write) -> fmt::Result { + if self.shape.is_empty() { + if let Some(value) = self.data.first() { + self.write_array_string_value(f, *value) + } else { + f.write_str("[]") + } + } else { + self.write_array_string_recursive(f, &self.shape, 0, 0) + } + } + + /// Recursively writes nested list representation for multi-dimensional arrays. + fn write_recursive(&self, f: &mut impl Write, remaining_shape: &[usize], offset: usize) -> fmt::Result { + if remaining_shape.len() == 1 { + f.write_char('[')?; + let len = remaining_shape[0]; + for i in 0..len { + if i > 0 { + f.write_str(", ")?; + } + let val = self.data[offset + i]; + match self.dtype { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is the intended int display" + )] + NdArrayDtype::Int64 => write!(f, "{}", val as i64)?, + NdArrayDtype::Float64 => { + if val.is_nan() { + f.write_str("nan")?; + } else if val.is_infinite() { + if val.is_sign_negative() { + f.write_str("-inf")?; + } else { + f.write_str("inf")?; + } + } else if val.fract() == 0.0 { + // NumPy displays whole floats as "1." not "1.0" + write!(f, "{val:.0}.")?; + } else { + write!(f, "{val}")?; + } + } + NdArrayDtype::Bool => { + if val == 0.0 { + f.write_str("False")?; + } else { + f.write_str(" True")?; + } + } + } + } + f.write_char(']') + } else { + f.write_char('[')?; + let sub_size: usize = remaining_shape[1..].iter().product(); + for i in 0..remaining_shape[0] { + if i > 0 { + f.write_str(", ")?; + } + self.write_recursive(f, &remaining_shape[1..], offset + i * sub_size)?; + } + f.write_char(']') + } + } + + /// Recursively writes NumPy's comma-free display format. + fn write_array_string_recursive( + &self, + f: &mut impl Write, + remaining_shape: &[usize], + offset: usize, + depth: usize, + ) -> fmt::Result { + if remaining_shape.len() == 1 { + self.write_array_string_leaf(f, remaining_shape[0], offset) + } else { + f.write_char('[')?; + let sub_size: usize = remaining_shape[1..].iter().product(); + for i in 0..remaining_shape[0] { + if i > 0 { + f.write_char('\n')?; + for _ in 0..=depth { + f.write_char(' ')?; + } + } + self.write_array_string_recursive(f, &remaining_shape[1..], offset + i * sub_size, depth + 1)?; + } + f.write_char(']') + } + } + + /// Writes one flat row for NumPy's bare display format. + fn write_array_string_leaf(&self, f: &mut impl Write, len: usize, offset: usize) -> fmt::Result { + f.write_char('[')?; + let values = (0..len) + .map(|i| self.array_string_value(self.data[offset + i])) + .collect::>(); + let width = values.iter().map(String::len).max().unwrap_or(0); + for (index, text) in values.iter().enumerate() { + if index > 0 { + f.write_char(' ')?; + } + match self.dtype { + NdArrayDtype::Int64 => write!(f, "{text:>width$}")?, + NdArrayDtype::Float64 => write!(f, "{text: f.write_str(text)?, + } + } + f.write_char(']') + } + + /// Writes a scalar value in NumPy's bare array display format. + fn write_array_string_value(&self, f: &mut impl Write, value: f64) -> fmt::Result { + f.write_str(&self.array_string_value(value)) + } + + /// Formats a scalar value for NumPy's bare array display format. + fn array_string_value(&self, value: f64) -> String { + match self.dtype { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is the intended int display" + )] + NdArrayDtype::Int64 => (value as i64).to_string(), + NdArrayDtype::Float64 => { + if value.is_nan() { + "nan".to_string() + } else if value.is_infinite() { + if value.is_sign_negative() { + "-inf".to_string() + } else { + "inf".to_string() + } + } else if value.fract() == 0.0 { + format!("{value:.0}.") + } else { + value.to_string() + } + } + NdArrayDtype::Bool => { + if value == 0.0 { + "False".to_string() + } else { + " True".to_string() + } + } + } + } +} + +// =========================== +// PyTrait implementation via HeapRead +// =========================== + +impl<'h> PyTrait<'h> for HeapRead<'h, NdArray> { + fn py_type(&self, vm: &VM<'h, impl ResourceTracker>) -> Type { + if self.get(vm.heap).is_flatiter { + Type::FlatIter + } else { + Type::NdArray + } + } + + fn py_len(&self, vm: &VM<'h, impl ResourceTracker>) -> Option { + // NumPy's len() returns the size of the first dimension, not total elements. + let arr = self.get(vm.heap); + Some(arr.shape().first().copied().unwrap_or(0)) + } + + fn py_eq(&self, other: &Self, vm: &mut VM<'h, impl ResourceTracker>) -> Result { + let a = self.get(vm.heap); + let b = other.get(vm.heap); + Ok(a.shape == b.shape && a.data == b.data && a.dtype == b.dtype) + } + + fn py_bool(&self, vm: &mut VM<'h, impl ResourceTracker>) -> bool { + let arr = self.get(vm.heap); + // NumPy only allows bool() on single-element arrays. + // For 0 or >1 elements, NumPy raises ValueError — but the py_bool trait + // returns bool, not Result, so we fall back to non-empty check. + // TODO: propagate ValueError when py_bool returns RunResult. + if arr.len() == 1 { + arr.data[0] != 0.0 + } else { + !arr.data.is_empty() + } + } + + fn py_repr_fmt( + &self, + f: &mut impl Write, + vm: &mut VM<'h, impl ResourceTracker>, + _heap_ids: &mut AHashSet, + ) -> RunResult<()> { + Ok(self.get(vm.heap).py_repr_fmt_inner(f)?) + } + + fn py_getitem(&self, key: &Value, vm: &mut VM<'h, impl ResourceTracker>) -> RunResult { + let arr = self.get(vm.heap); + match key { + Value::Int(n) => arr.getitem_int(*n, vm.heap), + Value::Bool(b) => arr.getitem_int(i64::from(*b), vm.heap), + Value::Ref(key_id) => { + match vm.heap.get(*key_id) { + HeapData::NdArray(mask_or_idx) => { + if mask_or_idx.dtype() == NdArrayDtype::Bool { + arr.getitem_bool_mask(mask_or_idx, vm.heap) + } else { + // Integer array indexing (fancy indexing) + arr.getitem_int_array(mask_or_idx, vm.heap) + } + } + HeapData::Slice(slice) => arr.getitem_slice(slice, vm.heap), + _ => Err(ExcType::type_error( + "ndarray indices must be integers, slices, or boolean/integer arrays", + )), + } + } + _ => Err(ExcType::type_error( + "ndarray indices must be integers, slices, or boolean/integer arrays", + )), + } + } + + fn py_setitem(&mut self, key: Value, value: Value, vm: &mut VM<'h, impl ResourceTracker>) -> RunResult<()> { + defer_drop!(key, vm); + defer_drop!(value, vm); + + match *key { + // arr[int] = val — set a single element by integer index + Value::Int(idx) => { + let scalar = extract_f64(value)?; + let arr = self.get_mut(vm.heap); + if arr.ndim() != 1 { + return Err(ExcType::type_error("only 1D array integer assignment is supported")); + } + let resolved = resolve_index(idx, arr.shape[0])?; + arr.data[resolved] = scalar; + Ok(()) + } + Value::Bool(b) => { + let scalar = extract_f64(value)?; + let arr = self.get_mut(vm.heap); + if arr.ndim() != 1 { + return Err(ExcType::type_error("only 1D array integer assignment is supported")); + } + let resolved = resolve_index(i64::from(b), arr.shape[0])?; + arr.data[resolved] = scalar; + Ok(()) + } + Value::Ref(key_id) => { + match vm.heap.get(key_id) { + // arr[bool_mask] = val — set elements where mask is True + HeapData::NdArray(mask) if mask.dtype() == NdArrayDtype::Bool => { + let mask_data: Vec = mask.data().iter().map(|&v| v != 0.0).collect(); + let scalar = extract_f64(value)?; + let arr = self.get_mut(vm.heap); + if mask_data.len() != arr.data.len() { + return Err(SimpleException::new_msg( + ExcType::IndexError, + "boolean index did not match indexed array", + ) + .into()); + } + for (i, &m) in mask_data.iter().enumerate() { + if m { + arr.data[i] = scalar; + } + } + Ok(()) + } + // arr[slice] = val — set slice of elements (scalar or array) + HeapData::Slice(slice) => { + // Extract RHS values: scalar broadcasts, array assigns element-wise + let rhs_data: Option> = match *value { + Value::Ref(val_id) => match vm.heap.get(val_id) { + HeapData::NdArray(rhs_arr) => Some(rhs_arr.data().to_vec()), + _ => None, + }, + _ => None, + }; + let len = self.get(vm.heap).data.len(); + let (start, stop, step) = slice.indices(len)?; + let target_len = slice_assignment_len(start, stop, step); + if let Some(ref rhs) = rhs_data { + validate_slice_assignment_length(rhs.len(), target_len)?; + } + let scalar = if rhs_data.is_none() { extract_f64(value)? } else { 0.0 }; + let arr = self.get_mut(vm.heap); + if step > 0 { + let mut i = start; + let mut rhs_idx = 0usize; + while i < stop { + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "positive-step slice indices are clamped to the array bounds" + )] + { + arr.data[i as usize] = if let Some(ref rhs) = rhs_data { + rhs[if rhs.len() == 1 { 0 } else { rhs_idx }] + } else { + scalar + }; + } + rhs_idx += 1; + i += step; + } + } else { + let mut i = start; + let mut rhs_idx = 0usize; + while i > stop { + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "negative-step slice indices visited here are in bounds" + )] + { + arr.data[i as usize] = if let Some(ref rhs) = rhs_data { + rhs[if rhs.len() == 1 { 0 } else { rhs_idx }] + } else { + scalar + }; + } + rhs_idx += 1; + i += step; + } + } + Ok(()) + } + _ => Err(ExcType::type_error( + "ndarray indices must be integers, slices, or boolean arrays", + )), + } + } + _ => Err(ExcType::type_error( + "ndarray indices must be integers, slices, or boolean arrays", + )), + } + } + + fn py_getattr(&self, attr: &EitherStr, vm: &mut VM<'h, impl ResourceTracker>) -> RunResult> { + let arr = self.get(vm.heap); + let result = match attr.static_string() { + Some(StaticStrings::NpShape) => arr.shape_tuple(vm.heap)?, + Some(StaticStrings::Dtype) => arr.dtype_str(vm.heap)?, + #[expect(clippy::cast_possible_wrap, reason = "array length won't exceed i64::MAX")] + Some(StaticStrings::NpSize) => Value::Int(arr.len() as i64), + #[expect(clippy::cast_possible_wrap, reason = "ndim is always small")] + Some(StaticStrings::NpNdim) => Value::Int(arr.ndim() as i64), + #[expect(clippy::cast_possible_wrap, reason = "nbytes won't exceed i64::MAX")] + Some(StaticStrings::NpNbytes) => Value::Int((arr.len() * 8) as i64), + Some(StaticStrings::NpItemsize) => Value::Int(8), + Some(StaticStrings::NpFlat) => { + let flat = NdArray::new(arr.data.clone(), vec![arr.data.len()], arr.dtype).into_flatiter(); + Value::Ref(vm.heap.allocate(HeapData::NdArray(flat))?) + } + Some(StaticStrings::NpT) => arr.transpose(vm.heap)?, + _ => { + // "T" is a single ASCII character so it is interned as an ASCII StringId, + // not as a StaticStrings variant — handle it in the fallback arm. + if attr.as_str(vm.interns) == "T" { + arr.transpose(vm.heap)? + } else { + return Err(ExcType::attribute_error(Type::NdArray, attr.as_str(vm.interns))); + } + } + }; + Ok(Some(CallResult::Value(result))) + } + + fn py_call_attr( + &mut self, + _self_id: HeapId, + vm: &mut VM<'h, impl ResourceTracker>, + attr: &EitherStr, + args: ArgValues, + ) -> RunResult { + let result = match attr.static_string() { + Some(StaticStrings::NpSum) => { + args.check_zero_args("ndarray.sum", vm.heap)?; + Ok(call_sum(self.get(vm.heap))) + } + Some(StaticStrings::Mean) => { + args.check_zero_args("ndarray.mean", vm.heap)?; + Ok(Value::Float(self.get(vm.heap).mean())) + } + Some(StaticStrings::NpMin) => { + args.check_zero_args("ndarray.min", vm.heap)?; + call_min(self.get(vm.heap)) + } + Some(StaticStrings::NpMax) => { + args.check_zero_args("ndarray.max", vm.heap)?; + call_max(self.get(vm.heap)) + } + Some(StaticStrings::Std) => { + args.check_zero_args("ndarray.std", vm.heap)?; + Ok(Value::Float(self.get(vm.heap).std_dev())) + } + Some(StaticStrings::Flatten) => { + args.check_zero_args("ndarray.flatten", vm.heap)?; + self.get(vm.heap).flatten(vm.heap) + } + Some(StaticStrings::Tolist) => { + args.check_zero_args("ndarray.tolist", vm.heap)?; + self.get(vm.heap).tolist(vm.heap) + } + Some(StaticStrings::Copy) => { + args.check_zero_args("ndarray.copy", vm.heap)?; + self.get(vm.heap).copy_array(vm.heap) + } + Some(StaticStrings::Sort) => { + args.check_zero_args("ndarray.sort", vm.heap)?; + self.get_mut(vm.heap).sort_in_place(); + Ok(Value::None) + } + Some(StaticStrings::NpArgsort) => { + args.check_zero_args("ndarray.argsort", vm.heap)?; + self.get(vm.heap).argsort(vm.heap) + } + Some(StaticStrings::Argmin) => { + args.check_zero_args("ndarray.argmin", vm.heap)?; + #[expect(clippy::cast_possible_wrap, reason = "array index won't exceed i64::MAX")] + Ok(Value::Int(self.get(vm.heap).argmin()? as i64)) + } + Some(StaticStrings::Argmax) => { + args.check_zero_args("ndarray.argmax", vm.heap)?; + #[expect(clippy::cast_possible_wrap, reason = "array index won't exceed i64::MAX")] + Ok(Value::Int(self.get(vm.heap).argmax()? as i64)) + } + Some(StaticStrings::NpAll) => { + args.check_zero_args("ndarray.all", vm.heap)?; + Ok(Value::Bool(self.get(vm.heap).all())) + } + Some(StaticStrings::NpAny) => { + args.check_zero_args("ndarray.any", vm.heap)?; + Ok(Value::Bool(self.get(vm.heap).any())) + } + Some(StaticStrings::Cumsum) => { + args.check_zero_args("ndarray.cumsum", vm.heap)?; + self.get(vm.heap).cumsum(vm.heap) + } + Some(StaticStrings::Reshape) => { + let pos = args.into_pos_only("ndarray.reshape", vm.heap)?; + let result = call_reshape(self.get(vm.heap), pos.as_slice(), vm.heap); + pos.drop_with_heap(vm); + result + } + Some(StaticStrings::Round) => { + let opt = args.get_zero_one_arg("ndarray.round", vm.heap)?; + #[expect(clippy::cast_possible_truncation, reason = "decimals value from user input")] + let decimals = match opt { + Some(Value::Int(n)) => n as i32, + Some(other) => { + other.drop_with_heap(vm); + return Err(ExcType::type_error("decimals must be an integer")); + } + None => 0, + }; + self.get(vm.heap).round_array(decimals, vm.heap) + } + Some(StaticStrings::Clip) => { + let pos = args.into_pos_only("ndarray.clip", vm.heap)?; + let result = if pos.as_slice().len() >= 2 { + match (extract_f64(&pos.as_slice()[0]), extract_f64(&pos.as_slice()[1])) { + (Ok(min_val), Ok(max_val)) => self.get(vm.heap).clip_array(min_val, max_val, vm.heap), + (Err(err), _) | (_, Err(err)) => Err(err), + } + } else { + Err(ExcType::type_error("clip() requires min and max arguments")) + }; + pos.drop_with_heap(vm); + result + } + Some(StaticStrings::Dot) => { + let other_val = args.get_one_arg("ndarray.dot", vm.heap)?; + let result = match &other_val { + Value::Ref(other_id) => { + if let HeapData::NdArray(other) = vm.heap.get(*other_id) { + self.get(vm.heap).dot(other) + } else { + Err(ExcType::type_error("dot() requires an ndarray argument")) + } + } + _ => Err(ExcType::type_error("dot() requires an ndarray argument")), + }; + other_val.drop_with_heap(vm); + result + } + Some(StaticStrings::NpAstype) => { + let arg = args.get_one_arg("ndarray.astype", vm.heap)?; + let result = match &arg { + Value::InternString(id) => { + let name = vm.interns.get_str(*id); + self.get(vm.heap).astype(name, vm.heap) + } + Value::Ref(id) => { + if let HeapData::Str(s) = vm.heap.get(*id) { + let name = s.as_str().to_owned(); + self.get(vm.heap).astype(&name, vm.heap) + } else { + Err(ExcType::type_error("astype() requires a string argument")) + } + } + _ => Err(ExcType::type_error("astype() requires a string argument")), + }; + arg.drop_with_heap(vm); + result + } + Some(StaticStrings::NpTranspose) => { + args.check_zero_args("ndarray.transpose", vm.heap)?; + self.get(vm.heap).transpose(vm.heap) + } + Some(StaticStrings::NpProd) => { + args.check_zero_args("ndarray.prod", vm.heap)?; + Ok(call_prod_method(self.get(vm.heap))) + } + Some(StaticStrings::NpVar) => { + args.check_zero_args("ndarray.var", vm.heap)?; + Ok(Value::Float(self.get(vm.heap).var())) + } + Some(StaticStrings::NpRavel) => { + args.check_zero_args("ndarray.ravel", vm.heap)?; + self.get(vm.heap).flatten(vm.heap) + } + Some(StaticStrings::NpItem) => { + args.check_zero_args("ndarray.item", vm.heap)?; + self.get(vm.heap).item() + } + Some(StaticStrings::NpCumprod) => { + args.check_zero_args("ndarray.cumprod", vm.heap)?; + self.get(vm.heap).cumprod(vm.heap) + } + Some(StaticStrings::NpSqueeze) => { + args.check_zero_args("ndarray.squeeze", vm.heap)?; + self.get(vm.heap).squeeze(vm.heap) + } + Some(StaticStrings::NpTake) => { + let idx_val = args.get_one_arg("ndarray.take", vm.heap)?; + let result = match &idx_val { + Value::Ref(other_id) => { + if let HeapData::NdArray(other) = vm.heap.get(*other_id) { + self.get(vm.heap).take_indices(other, vm.heap) + } else { + Err(ExcType::type_error("take() requires an ndarray of indices")) + } + } + _ => Err(ExcType::type_error("take() requires an ndarray of indices")), + }; + idx_val.drop_with_heap(vm); + result + } + Some(StaticStrings::NpDiagonal) => { + args.check_zero_args("ndarray.diagonal", vm.heap)?; + self.get(vm.heap).diagonal(vm.heap) + } + Some(StaticStrings::NpTrace) => { + args.check_zero_args("ndarray.trace", vm.heap)?; + self.get(vm.heap).trace() + } + Some(StaticStrings::NpFill) => { + let arg = args.get_one_arg("ndarray.fill", vm.heap)?; + let result = extract_f64(&arg); + arg.drop_with_heap(vm); + let val = result?; + self.get_mut(vm.heap).fill(val); + Ok(Value::None) + } + Some(StaticStrings::NpCompress) => { + let cond_val = args.get_one_arg("ndarray.compress", vm.heap)?; + let result = match &cond_val { + Value::Ref(other_id) => { + if let HeapData::NdArray(cond) = vm.heap.get(*other_id) { + self.get(vm.heap).compress(cond, vm.heap) + } else { + Err(ExcType::type_error("compress() requires a boolean ndarray condition")) + } + } + _ => Err(ExcType::type_error("compress() requires a boolean ndarray condition")), + }; + cond_val.drop_with_heap(vm); + result + } + Some(StaticStrings::NpRepeat) => { + let arg = args.get_one_arg("ndarray.repeat", vm.heap)?; + let result = extract_f64(&arg); + arg.drop_with_heap(vm); + #[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "repeat count from user" + )] + let n = result? as usize; + self.get(vm.heap).repeat_array(n, vm.heap) + } + Some(StaticStrings::NpNonzero) => { + args.check_zero_args("ndarray.nonzero", vm.heap)?; + self.get(vm.heap).nonzero_method(vm.heap) + } + Some(StaticStrings::NpSwapaxes) => { + let pos = args.into_pos_only("ndarray.swapaxes", vm.heap)?; + let result = if pos.as_slice().len() >= 2 { + match (extract_f64(&pos.as_slice()[0]), extract_f64(&pos.as_slice()[1])) { + (Ok(a), Ok(b)) => + { + #[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "axis from user" + )] + self.get(vm.heap).swapaxes(a as usize, b as usize, vm.heap) + } + (Err(err), _) | (_, Err(err)) => Err(err), + } + } else { + Err(ExcType::type_error("swapaxes() requires two arguments")) + }; + pos.drop_with_heap(vm); + result + } + _ => { + args.drop_with_heap(vm); + return Err(ExcType::attribute_error(Type::NdArray, attr.as_str(vm.interns))); + } + }; + result.map(CallResult::Value) + } +} + +// =========================== +// HeapItem implementation +// =========================== + +impl HeapItem for NdArray { + fn py_estimate_size(&self) -> usize { + size_of::() + self.data.capacity() * size_of::() + self.shape.capacity() * size_of::() + } + + fn py_dec_ref_ids(&mut self, _stack: &mut Vec) { + // NdArray is a leaf type — stores only f64 data, no heap references. + } +} + +// =========================== +// Helper functions +// =========================== + +/// Returns `sum()` with dtype-appropriate return type. +fn call_sum(arr: &NdArray) -> Value { + let s = arr.sum(); + match arr.dtype() { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is intended for int sum" + )] + NdArrayDtype::Int64 => Value::Int(s as i64), + NdArrayDtype::Float64 | NdArrayDtype::Bool => Value::Float(s), + } +} + +/// Returns `prod()` with dtype-appropriate return type. +fn call_prod_method(arr: &NdArray) -> Value { + let p = arr.prod(); + match arr.dtype() { + #[expect( + clippy::cast_possible_truncation, + reason = "f64 to i64 truncation is intended for int prod" + )] + NdArrayDtype::Int64 => Value::Int(p as i64), + NdArrayDtype::Float64 | NdArrayDtype::Bool => Value::Float(p), + } +} + +/// Returns `min()` as a scalar matching the array's dtype. +fn call_min(arr: &NdArray) -> RunResult { + let m = arr.min_val()?; + Ok(arr.element_to_value(m)) +} + +/// Returns `max()` as a scalar matching the array's dtype. +fn call_max(arr: &NdArray) -> RunResult { + let m = arr.max_val()?; + Ok(arr.element_to_value(m)) +} + +/// Handles `reshape(*shape)` — takes shape as positional args. +fn call_reshape(arr: &NdArray, args: &[Value], heap: &Heap) -> RunResult { + let mut new_shape = Vec::with_capacity(args.len()); + for arg in args { + match arg { + #[expect( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + reason = "shape dimensions from user input" + )] + Value::Int(n) => new_shape.push(*n as usize), + _ => { + return Err(ExcType::type_error("an integer is required for reshape dimensions")); + } + } + } + arr.reshape(new_shape, heap) +} + +/// Extracts an f64 from a Python numeric value. +/// +/// Used by ndarray methods (like `clip`) that accept numeric arguments from Python. +fn extract_f64(value: &Value) -> RunResult { + match value { + #[expect( + clippy::cast_precision_loss, + reason = "i64 to f64 precision loss acceptable for numeric args" + )] + Value::Int(n) => Ok(*n as f64), + Value::Float(f) => Ok(*f), + Value::Bool(true) => Ok(1.0), + Value::Bool(false) => Ok(0.0), + _ => Err(ExcType::type_error( + "ndarray numeric argument must be int, float, or bool", + )), + } +} + +/// Comparison function that sorts NaN values to the end, matching NumPy's sort behavior. +/// +/// Non-NaN values are compared normally. NaN is treated as greater than any non-NaN value, +/// and two NaN values are considered equal. +/// +/// Takes `&f64` so it can be passed directly to `[f64]::sort_by`. +#[expect( + clippy::trivially_copy_pass_by_ref, + reason = "signature required by sort_by(Fn(&T, &T) -> Ordering)" +)] +pub(crate) fn nan_last_cmp(a: &f64, b: &f64) -> Ordering { + match (a.is_nan(), b.is_nan()) { + (true, true) => Ordering::Equal, + (true, false) => Ordering::Greater, + (false, true) => Ordering::Less, + (false, false) => a.partial_cmp(b).unwrap_or(Ordering::Equal), + } +} + +/// Converts a potentially negative index to a positive one, or returns an error. +/// +/// Supports Python-style negative indexing: `-1` is the last element, etc. +fn resolve_index(index: i64, axis_len: usize) -> RunResult { + #[expect(clippy::cast_possible_wrap, reason = "axis_len won't exceed i64::MAX")] + let resolved = if index < 0 { + let pos = axis_len as i64 + index; + if pos < 0 { + return Err(SimpleException::new_msg(ExcType::IndexError, "index out of range").into()); + } + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "pos is guaranteed non-negative above" + )] + let r = pos as usize; + r + } else { + #[expect( + clippy::cast_sign_loss, + clippy::cast_possible_truncation, + reason = "index is guaranteed non-negative" + )] + let r = index as usize; + r + }; + if resolved >= axis_len { + return Err(SimpleException::new_msg(ExcType::IndexError, "index out of range").into()); + } + Ok(resolved) +} + +/// Returns how many positions a normalized slice assignment will touch. +/// +/// `Slice::indices()` already clamps `start` and `stop` to the array bounds. +/// This helper mirrors the iteration loops in `py_setitem()` so RHS validation +/// rejects only assignments that would actually write a non-broadcastable shape. +fn slice_assignment_len(start: i64, stop: i64, step: i64) -> usize { + let len = if step > 0 { + if start >= stop { + 0 + } else { + ((stop - start - 1) / step) + 1 + } + } else if start <= stop { + 0 + } else { + ((start - stop - 1) / -step) + 1 + }; + usize::try_from(len).expect("normalized slice length cannot be negative") +} + +/// Validates ndarray RHS length for assignment into a one-dimensional slice. +/// +/// NumPy allows an ndarray RHS when it exactly matches the target length, when a +/// single element can broadcast across a non-empty target, or when the target is +/// empty and the assignment is therefore a no-op. +fn validate_slice_assignment_length(rhs_len: usize, target_len: usize) -> RunResult<()> { + if target_len == 0 || rhs_len == target_len || rhs_len == 1 { + Ok(()) + } else { + Err(SimpleException::new_msg( + ExcType::ValueError, + format!( + "could not broadcast input array from shape {} into shape {}", + format_broadcast_shape(&[rhs_len]), + format_broadcast_shape(&[target_len]) + ), + ) + .into()) + } +} + +/// Creates an ndarray from a Python list Value (potentially nested). +/// +/// Recursively traverses nested lists to determine shape and flatten data. +pub(crate) fn ndarray_from_list(value: &Value, heap: &Heap) -> RunResult { + let mut data = Vec::new(); + let mut shape = Vec::new(); + let mut has_float = false; + let mut has_int = false; + let mut has_bool = false; + collect_from_value( + value, + heap, + &mut data, + &mut shape, + 0, + &mut has_float, + &mut has_int, + &mut has_bool, + )?; + + let dtype = if has_float { + NdArrayDtype::Float64 + } else if has_int { + NdArrayDtype::Int64 + } else if has_bool { + NdArrayDtype::Bool + } else { + // Empty array defaults to float64, matching NumPy's behavior + NdArrayDtype::Float64 + }; + + Ok(NdArray::new(data, shape, dtype)) +} + +/// Recursively collects numeric data from a nested list/value structure. +/// +/// Tracks which scalar types are present (`has_float`, `has_int`, `has_bool`) so the +/// caller can determine the correct dtype: float > int > bool, matching NumPy's +/// type promotion rules. +#[expect(clippy::too_many_arguments)] +fn collect_from_value( + value: &Value, + heap: &Heap, + data: &mut Vec, + shape: &mut Vec, + depth: usize, + has_float: &mut bool, + has_int: &mut bool, + has_bool: &mut bool, +) -> RunResult<()> { + match value { + Value::Int(n) => { + *has_int = true; + data.push(*n as f64); + Ok(()) + } + Value::Float(f) => { + *has_float = true; + data.push(*f); + Ok(()) + } + Value::Bool(b) => { + *has_bool = true; + data.push(if *b { 1.0 } else { 0.0 }); + Ok(()) + } + Value::Ref(heap_id) => match heap.get(*heap_id) { + HeapData::List(list) => { + let items = list.as_slice(); + let len = items.len(); + + if depth >= shape.len() { + shape.push(len); + } else if shape[depth] != len { + return Err(SimpleException::new_msg( + ExcType::ValueError, + "setting an array element with a sequence", + ) + .into()); + } + + for item in items { + collect_from_value(item, heap, data, shape, depth + 1, has_float, has_int, has_bool)?; + } + Ok(()) + } + _ => Err(ExcType::type_error("cannot create ndarray from this type")), + }, + _ => Err(ExcType::type_error("cannot create ndarray from this type")), + } +} + +/// Determines the result dtype when combining two dtypes. +/// +/// Follows NumPy's type promotion: if either operand is float, result is float. +pub(crate) fn promote_dtype(a: NdArrayDtype, b: NdArrayDtype) -> NdArrayDtype { + match (a, b) { + (NdArrayDtype::Float64, _) | (_, NdArrayDtype::Float64) => NdArrayDtype::Float64, + _ => NdArrayDtype::Int64, + } +} + +/// Determines result dtype when combining an array dtype with a scalar. +/// +/// `scalar_is_float` indicates whether the Python value was a `float` (as opposed to `int`). +/// This is necessary because `1.0` and `1` are both `f64` internally, but NumPy promotes +/// `int_arr * 1.0` to float64 while `int_arr * 1` stays int64. +pub(crate) fn promote_dtype_with_scalar(arr_dtype: NdArrayDtype, scalar_is_float: bool) -> NdArrayDtype { + if arr_dtype == NdArrayDtype::Float64 || scalar_is_float { + NdArrayDtype::Float64 + } else { + arr_dtype + } +} + +/// Validates that the dtype supports bitwise operations. +/// +/// NumPy raises `TypeError` for bitwise ops on float arrays. Bool and Int64 are supported. +fn check_bitwise_dtype(dtype: NdArrayDtype, op_symbol: &str) -> RunResult<()> { + if dtype == NdArrayDtype::Float64 { + return Err(SimpleException::new_msg( + ExcType::TypeError, + format!("ufunc 'bitwise_{op_symbol}' not supported for the input types"), + ) + .into()); + } + Ok(()) +} + +/// Python-compatible modulo: result has the same sign as the divisor. +fn py_mod(a: f64, b: f64) -> f64 { + let r = a % b; + if r != 0.0 && ((r > 0.0) != (b > 0.0)) { r + b } else { r } +} diff --git a/crates/monty/src/types/type.rs b/crates/monty/src/types/type.rs index 496516189..1374776f2 100644 --- a/crates/monty/src/types/type.rs +++ b/crates/monty/src/types/type.rs @@ -71,6 +71,12 @@ pub enum Type { RePattern, /// A regex match result from `re.match()` / `re.search()` etc. - displays as "re.Match" ReMatch, + /// A NumPy ndarray - displays as "numpy.ndarray" while `__name__` remains "ndarray". + NdArray, + /// NumPy's public flat iterator type object for `ndarray.flat` results. + FlatIter, + /// NumPy's public ufunc type object for implemented ufunc-like callables. + Ufunc, } impl fmt::Display for Type { @@ -113,11 +119,27 @@ impl fmt::Display for Type { Self::Property => f.write_str("property"), Self::RePattern => f.write_str("re.Pattern"), Self::ReMatch => f.write_str("re.Match"), + Self::NdArray => f.write_str("numpy.ndarray"), + Self::FlatIter => f.write_str("flatiter"), + Self::Ufunc => f.write_str("ufunc"), } } } impl Type { + /// Returns the value exposed by a type object's `__name__` attribute. + /// + /// Most Monty type displays are already unqualified names, but NumPy's + /// ndarray type needs a qualified display string for repr/error messages + /// while keeping the CPython-compatible `ndarray` type name. + #[must_use] + pub(crate) fn dunder_name(self) -> String { + match self { + Self::NdArray => "ndarray".to_string(), + _ => self.to_string(), + } + } + /// Returns the Python source-level name for builtin types that can be called directly. /// /// This differs from `Display` for internal representation-only names such as diff --git a/crates/monty/src/value.rs b/crates/monty/src/value.rs index 254c0f9bf..a47aa46df 100644 --- a/crates/monty/src/value.rs +++ b/crates/monty/src/value.rs @@ -22,7 +22,7 @@ use crate::{ hash::HashValue, heap::{ContainsHeap, DropWithHeap, Heap, HeapData, HeapGuard, HeapId, HeapReadOutput}, intern::{BytesId, FunctionId, Interns, LongIntId, StaticStrings, StringId}, - modules::ModuleFunctions, + modules::{ModuleFunctions, numpy::numpy_marker_getitem}, resource::{ ResourceError, ResourceTracker, check_div_size, check_lshift_size, check_mult_size, check_pow_size, check_repeat_size, @@ -31,6 +31,7 @@ use crate::{ Bytes, List, LongInt, Property, PyTrait, Str, Type, allocate_tuple, bytes::{bytes_repr_fmt, get_byte_at_index}, long_int::check_bits_str_digits_limit, + ndarray::NdArrayDtype, path, slice::slice_collect_iterator, str::{allocate_char, get_char_at_index, string_repr_fmt}, @@ -143,6 +144,7 @@ impl PyTrait<'_> for Value { Self::InternString(_) => Type::Str, Self::InternBytes(_) => Type::Bytes, Self::Builtin(c) => c.py_type(), + Self::ModuleFunction(ModuleFunctions::Numpy(function)) if function.is_ufunc_like() => Type::Ufunc, Self::ModuleFunction(_) => Type::BuiltinFunction, Self::DefFunction(_) | Self::ExtFunction(_) => Type::Function, Self::Marker(m) => m.py_type(), @@ -220,7 +222,6 @@ impl PyTrait<'_> for Value { let right = vm.heap.read(*id2); left.py_eq(&right, vm) } - // Builtins equality - just check the enums are equal (Self::Builtin(b1), Self::Builtin(b2)) => Ok(b1 == b2), // Module functions equality @@ -1337,6 +1338,13 @@ impl PyTrait<'_> for Value { let interns = vm.interns; match self { Self::Ref(id) => vm.heap.read(*id).py_getitem(key, vm), + Self::Marker(marker) => { + if let Some(value) = numpy_marker_getitem(marker.0, key, vm)? { + Ok(value) + } else { + Err(ExcType::type_error_not_sub(self.py_type(vm))) + } + } Self::InternString(string_id) => { // Check for slice first if let Self::Ref(key_id) = key @@ -1649,6 +1657,12 @@ impl Value { } HeapReadOutput::Set(set) => set.contains(item, vm), HeapReadOutput::FrozenSet(fset) => fset.contains(item, vm), + HeapReadOutput::NdArray(arr) => { + let arr = arr.get(vm.heap); + let dtype = arr.dtype(); + let needle = ndarray_contains_needle(item, dtype, vm); + Ok(needle.is_some_and(|n| arr.data().contains(&n))) + } HeapReadOutput::Str(s) => { let s_str = s.get(vm.heap).as_str(); str_contains(s_str, item, vm.heap, vm.interns) @@ -1716,7 +1730,7 @@ impl Value { |ss| ss == StaticStrings::DunderName, ); if is_dunder_name { - let name_str = t.to_string(); + let name_str = t.dunder_name(); let str_id = vm.heap.allocate(HeapData::Str(Str::from(name_str)))?; return Ok(CallResult::Value(Self::Ref(str_id))); } @@ -2148,6 +2162,8 @@ impl BitwiseOp { /// provide functionality in the sandboxed environment /// - Typing constructs from the `typing` module that are imported for type hints but /// don't need runtime functionality +/// - NumPy index-trick sentinels that provide subscription behavior without adding +/// heap object kinds /// /// Wraps a `StaticStrings` variant to leverage its string conversion capabilities. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] @@ -2158,7 +2174,7 @@ impl Marker { /// /// System markers (stdout, stderr) are `TextIOWrapper`. /// `typing.Union` has type `type` (matching CPython). - /// Other typing markers (Any, Optional, etc.) are `_SpecialForm`. + /// Other typing markers (Any, Optional, etc.) and NumPy index-trick markers are `_SpecialForm`. pub(crate) fn py_type(self) -> Type { match self.0 { StaticStrings::Stdout | StaticStrings::Stderr => Type::TextIOWrapper, @@ -2171,13 +2187,20 @@ impl Marker { /// /// System markers have special repr formats ("", ""). /// `typing.Union` uses `` format (matching CPython). - /// Other typing markers are prefixed with "typing." (e.g., "typing.Any"). + /// Other typing markers are prefixed with "typing." (e.g., "typing.Any"); + /// NumPy index-trick markers are prefixed with "numpy.". pub(crate) fn py_repr_fmt(self, f: &mut impl Write) -> fmt::Result { let s: &'static str = self.0.into(); match self.0 { StaticStrings::Stdout => f.write_str("")?, StaticStrings::Stderr => f.write_str("")?, StaticStrings::UnionType => f.write_str("")?, + StaticStrings::NpIndexExp + | StaticStrings::NpSIndex + | StaticStrings::NpMgrid + | StaticStrings::NpOgrid + | StaticStrings::NpRIndex + | StaticStrings::NpCIndex => write!(f, "numpy.{s}")?, _ => write!(f, "typing.{s}")?, } Ok(()) @@ -2406,6 +2429,34 @@ fn extract_bigint(value: &Value, heap: &Heap) -> Option) -> Option { + match item { + Value::Int(i) => Some(*i as f64), + Value::Float(f) => Some(*f), + Value::Bool(b) => Some(if *b { 1.0 } else { 0.0 }), + Value::Ref(id) => { + let HeapData::LongInt(li) = vm.heap.get(*id) else { + return None; + }; + match dtype { + NdArrayDtype::Float64 => li.inner().to_f64(), + NdArrayDtype::Int64 => li.to_i64().map(|i| i as f64), + NdArrayDtype::Bool => li.to_i64().and_then(|i| match i { + 0 | 1 => Some(i as f64), + _ => None, + }), + } + } + _ => None, + } +} + /// Extracts and clones the `(key, value)` probe accepted by `dict_items.__contains__`. /// /// CPython treats only 2-tuples as valid probes for items-view membership. Monty diff --git a/crates/monty/test_cases/lambda__all.py b/crates/monty/test_cases/lambda__all.py index 3de54a20d..41069f2b4 100644 --- a/crates/monty/test_cases/lambda__all.py +++ b/crates/monty/test_cases/lambda__all.py @@ -127,7 +127,7 @@ def make_multiplier(factor): def make_shadowing_lambda(): x = 10 # inner lambda has param x, so outer lambda should NOT capture x from make_shadowing_lambda - return lambda: (lambda x: x + 1) + return lambda: lambda x: x + 1 outer_fn = make_shadowing_lambda() @@ -138,7 +138,7 @@ def make_shadowing_lambda(): def test_inner_lambda_capture(): y = 5 # outer lambda binds y as param, inner lambda captures from outer lambda, not test_inner_lambda_capture - g = lambda y: (lambda: y) + g = lambda y: lambda: y return g(7)() diff --git a/crates/monty/test_cases/numpy__aliases.py b/crates/monty/test_cases/numpy__aliases.py new file mode 100644 index 000000000..4bbfbbedf --- /dev/null +++ b/crates/monty/test_cases/numpy__aliases.py @@ -0,0 +1,844 @@ +# skip-cpython +import numpy as np + + +# === np.absolute -> np.abs === +assert np.absolute(-3) == 3, 'absolute scalar int' +assert np.absolute(-2.5) == 2.5, 'absolute scalar float' +assert np.absolute([-1.5, 0.0, 2.5]).tolist() == [1.5, 0.0, 2.5], 'absolute float list' +assert np.absolute(np.array([-3, 0, 5])).tolist() == [3, 0, 5], 'absolute 1d array' +assert np.absolute(np.array([[-1, 2], [-3, 4]])).tolist() == [[1, 2], [3, 4]], 'absolute 2d array' +assert np.absolute(np.array([])).tolist() == [], 'absolute empty array' + + +# === np.amax / np.amin -> np.max / np.min === +assert np.amax(5) == 5, 'amax scalar' +assert np.amin(-2) == -2, 'amin scalar' +assert np.amax([1.0, 5.0, 3.0]) == 5.0, 'amax float list' +assert np.amin([1.0, -5.0, 3.0]) == -5.0, 'amin float list' +assert np.amax(np.array([1, 5, 3])) == 5, 'amax 1d array' +assert np.amin(np.array([1, -5, 3])) == -5, 'amin 1d array' +assert np.amax(np.array([[1, 4], [2, 3]])) == 4, 'amax 2d array' +assert np.amin(np.array([[1, 4], [2, 3]])) == 1, 'amin 2d array' + + +# === np.asin / np.acos / np.atan aliases === +assert np.asin(0.0) == 0.0, 'asin scalar zero' +assert abs(np.asin(1.0) - np.pi / 2) < 1e-12, 'asin scalar one' +assert np.acos(1.0) == 0.0, 'acos scalar one' +assert abs(np.acos(0.0) - np.pi / 2) < 1e-12, 'acos scalar zero' +assert np.atan(0.0) == 0.0, 'atan scalar zero' +assert abs(np.atan(1.0) - np.pi / 4) < 1e-12, 'atan scalar one' +assert np.asin([0.0, 1.0]).tolist()[0] == 0.0, 'asin list first' +assert abs(np.asin([0.0, 1.0]).tolist()[1] - np.pi / 2) < 1e-12, 'asin list second' +acos_result = np.acos(np.array([[1.0, 0.0], [0.0, 1.0]])).tolist() +assert acos_result[0][0] == 0.0, 'acos 2d first' +assert abs(acos_result[0][1] - np.pi / 2) < 1e-12, 'acos 2d second' +atan_result = np.atan(np.array([0.0, 1.0])).tolist() +assert atan_result[0] == 0.0, 'atan 1d first' +assert abs(atan_result[1] - np.pi / 4) < 1e-12, 'atan 1d second' +assert np.atan(np.array([])).tolist() == [], 'atan empty array' + + +# === additional inverse aliases === +assert np.asinh(0.0) == 0.0, 'asinh scalar zero' +assert abs(np.asinh(1.0) - 0.881373587019543) < 1e-12, 'asinh scalar one' +assert np.acosh(1.0) == 0.0, 'acosh scalar one' +assert abs(np.acosh(2.0) - 1.3169578969248166) < 1e-12, 'acosh scalar two' +assert np.atanh(0.0) == 0.0, 'atanh scalar zero' +assert abs(np.atanh(0.5) - 0.5493061443340548) < 1e-12, 'atanh scalar half' +assert np.atan2(0.0, 1.0) == 0.0, 'atan2 scalar zero' +assert abs(np.atan2(np.array([1.0]), np.array([1.0])).tolist()[0] - np.pi / 4) < 1e-12, 'atan2 array' +assert np.angle(1.0) == 0.0, 'angle positive real scalar' +assert np.angle(-1.0) == np.pi, 'angle negative real scalar' +assert np.angle(-0.0) == np.pi, 'angle negative zero scalar' +assert np.angle([1.0, -1.0, 0.0, -0.0]).tolist() == [0.0, np.pi, 0.0, np.pi], 'angle real list' +assert np.angle([-1.0], True).tolist() == [180.0], 'angle degrees' + + +# === np.around -> np.round === +assert np.around(1.234, 2) == 1.23, 'around scalar' +assert np.around([1.234, 5.678], 1).tolist() == [1.2, 5.7], 'around list' +assert np.around(np.array([1.234, 5.678]), 1).tolist() == [1.2, 5.7], 'around 1d array' +assert np.around(np.array([[1.234, 5.678], [9.012, 3.456]]), 1).tolist() == [ + [1.2, 5.7], + [9.0, 3.5], +], 'around 2d array' +assert np.around(np.array([]), 1).tolist() == [], 'around empty array' + + +# === np.asanyarray -> np.asarray === +a = np.asanyarray([1, 2, 3]) +assert a.tolist() == [1, 2, 3], 'asanyarray list' +assert a.shape == (3,), 'asanyarray list shape' +b = np.array([[1, 2], [3, 4]]) +c = np.asanyarray(b) +assert c.tolist() == [[1, 2], [3, 4]], 'asanyarray ndarray values' +assert c.shape == (2, 2), 'asanyarray ndarray shape' +assert np.asanyarray([]).tolist() == [], 'asanyarray empty list' + + +# === common binary ufuncs === +assert np.add(1, 2) == 3, 'add scalar' +assert np.add(1, 2.5) == 3.5, 'add mixed scalar promotes to float' +assert np.add([1, 2], [3, 4]).tolist() == [4, 6], 'add lists' +assert np.add(np.array([1, 2]), [3, 4]).tolist() == [4, 6], 'add array and list' +assert np.add(np.array([[1, 2], [3, 4]]), 10).tolist() == [[11, 12], [13, 14]], 'add scalar broadcast' +assert np.subtract(10, np.array([1, 2])).tolist() == [9, 8], 'subtract scalar left broadcast' +assert np.multiply(2, [3, 4]).tolist() == [6, 8], 'multiply scalar left list' +assert np.divide(9, np.array([3, 2])).tolist() == [3.0, 4.5], 'divide scalar left array' +assert np.add(np.array([]), np.array([])).tolist() == [], 'add empty arrays' +assert np.subtract(np.array([5, 7]), np.array([2, 3])).tolist() == [3, 4], 'subtract arrays' +assert np.multiply(np.array([2, 3]), np.array([4, 5])).tolist() == [8, 15], 'multiply arrays' +assert np.divide(np.array([5, 9]), np.array([2, 3])).tolist() == [2.5, 3.0], 'divide arrays' +assert np.true_divide(5, 2) == 2.5, 'true_divide scalar' +assert np.floor_divide(np.array([5, 9]), np.array([2, 3])).tolist() == [2, 3], 'floor_divide arrays' +assert np.floor_divide(9, np.array([2, 4])).tolist() == [4, 2], 'floor_divide scalar left array' +assert np.mod(np.array([-3, 4]), np.array([2, 3])).tolist() == [1, 1], 'mod arrays' +assert np.mod(3, np.array([-2, 2])).tolist() == [-1, 1], 'mod scalar left signed divisors' +assert np.remainder(-3, 2) == 1, 'remainder scalar' +assert np.pow([2, 3], [3, 2]).tolist() == [8, 9], 'pow alias lists' + + +# === comparison ufuncs === +assert np.equal(1, 1) == True, 'equal scalar true' +assert np.equal(float('nan'), float('nan')) == False, 'equal scalar nan' +assert np.not_equal(1, 2) == True, 'not_equal scalar true' +assert np.not_equal(float('nan'), float('nan')) == True, 'not_equal scalar nan' +assert np.greater(3, 2) == True, 'greater scalar true' +assert np.greater_equal(2, 2) == True, 'greater_equal scalar true' +assert np.less(1, 2) == True, 'less scalar true' +assert np.less_equal(2, 2) == True, 'less_equal scalar true' +assert np.equal([1, 2, 3], [1, 0, 3]).tolist() == [True, False, True], 'equal lists' +assert np.not_equal(np.array([1, 2]), np.array([1, 3])).tolist() == [False, True], 'not_equal arrays' +assert np.greater(np.array([[1, 4], [5, 2]]), 3).tolist() == [[False, True], [True, False]], 'greater 2d' +assert np.greater_equal(3, np.array([2, 3, 4])).tolist() == [True, True, False], 'greater_equal scalar left' +assert np.less(3, np.array([2, 3, 4])).tolist() == [False, False, True], 'less scalar left' +assert np.less_equal(np.array([]), np.array([])).tolist() == [], 'less_equal empty arrays' + + +# === function aliases and shape helpers === +assert np.concat([np.array([1, 2]), np.array([3, 4])]).tolist() == [1, 2, 3, 4], 'concat alias' +assert np.cumulative_sum(np.array([1, 2, 3])).tolist() == [1, 3, 6], 'cumulative_sum alias' +assert np.cumulative_prod(np.array([2, 3, 4])).tolist() == [2, 6, 24], 'cumulative_prod alias' +assert np.shape(5) == (), 'shape scalar' +assert np.shape([1, 2, 3]) == (3,), 'shape list' +assert np.shape(np.array([[1, 2], [3, 4]])) == (2, 2), 'shape 2d array' +assert np.size(5) == 1, 'size scalar' +assert np.size(np.array([[1, 2], [3, 4]])) == 4, 'size 2d array' +assert np.size(np.array([])) == 0, 'size empty array' +assert np.ndim(5) == 0, 'ndim scalar' +assert np.ndim([1, 2, 3]) == 1, 'ndim list' +assert np.ndim(np.array([[1, 2], [3, 4]])) == 2, 'ndim 2d array' + + +# === shape and index helpers === +assert np.atleast_1d(5).tolist() == [5], 'atleast_1d scalar' +assert np.atleast_2d([1, 2, 3]).tolist() == [[1, 2, 3]], 'atleast_2d list' +assert np.atleast_3d([1, 2, 3]).tolist() == [[[1], [2], [3]]], 'atleast_3d list' +atleast_a, atleast_b = np.atleast_1d(1, [2, 3]) +assert atleast_a.tolist() == [1], 'atleast_1d multi scalar' +assert atleast_b.tolist() == [2, 3], 'atleast_1d multi list' + +diag_row, diag_col = np.diag_indices(3) +assert diag_row.tolist() == [0, 1, 2], 'diag_indices first axis' +assert diag_col.tolist() == [0, 1, 2], 'diag_indices second axis' +diag_from_row, diag_from_col = np.diag_indices_from(np.ones((3, 3))) +assert diag_from_row.tolist() == [0, 1, 2], 'diag_indices_from first axis' +assert diag_from_col.tolist() == [0, 1, 2], 'diag_indices_from second axis' + +tril_row, tril_col = np.tril_indices(3) +assert tril_row.tolist() == [0, 1, 1, 2, 2, 2], 'tril_indices rows' +assert tril_col.tolist() == [0, 0, 1, 0, 1, 2], 'tril_indices cols' +tril_from_row, tril_from_col = np.tril_indices_from(np.ones((2, 3)), 1) +assert tril_from_row.tolist() == [0, 0, 1, 1, 1], 'tril_indices_from rows' +assert tril_from_col.tolist() == [0, 1, 0, 1, 2], 'tril_indices_from cols' + +triu_row, triu_col = np.triu_indices(3) +assert triu_row.tolist() == [0, 0, 0, 1, 1, 2], 'triu_indices rows' +assert triu_col.tolist() == [0, 1, 2, 1, 2, 2], 'triu_indices cols' +triu_from_row, triu_from_col = np.triu_indices_from(np.ones((2, 3)), -1) +assert triu_from_row.tolist() == [0, 0, 0, 1, 1, 1], 'triu_indices_from rows' +assert triu_from_col.tolist() == [0, 1, 2, 0, 1, 2], 'triu_indices_from cols' + +grid = np.indices((2, 3)) +assert grid.shape == (2, 2, 3), 'indices shape' +assert grid.tolist() == [[[0, 0, 0], [1, 1, 1]], [[0, 1, 2], [0, 1, 2]]], 'indices values' + +unravel_row, unravel_col = np.unravel_index([5, 6], (3, 4)) +assert unravel_row.tolist() == [1, 1], 'unravel_index rows' +assert unravel_col.tolist() == [1, 2], 'unravel_index cols' +scalar_row, scalar_col = np.unravel_index(5, (3, 4)) +assert scalar_row == 1, 'unravel_index scalar row' +assert scalar_col == 1, 'unravel_index scalar col' +assert np.ravel_multi_index(([1, 2], [1, 3]), (3, 4)).tolist() == [5, 11], 'ravel_multi_index arrays' +assert np.ravel_multi_index((1, 1), (3, 4)) == 5, 'ravel_multi_index scalar' +assert list(np.ndindex()) == [()], 'ndindex empty shape' +assert list(np.ndindex(2, 3)) == [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)], 'ndindex two dims' +assert list(np.ndindex((2, 1))) == [(0, 0), (1, 0)], 'ndindex tuple shape' +nd_iter_arr = np.array([[10, 20], [30, 40]]) +assert list(np.ndenumerate(nd_iter_arr)) == [ + ((0, 0), 10), + ((0, 1), 20), + ((1, 0), 30), + ((1, 1), 40), +], 'ndenumerate matrix' +assert list(np.ndenumerate(5)) == [((), 5)], 'ndenumerate scalar' +assert list(np.nditer(nd_iter_arr)) == [10, 20, 30, 40], 'nditer matrix order' +assert list(np.nditer(5)) == [5], 'nditer scalar' + +diagflat = np.diagflat([[1, 2], [3, 4]]) +assert diagflat.shape == (4, 4), 'diagflat flattened shape' +assert diagflat.tolist() == [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], 'diagflat values' +assert np.diagflat([1, 2], 1).tolist() == [[0, 1, 0], [0, 0, 2], [0, 0, 0]], 'diagflat positive k' +assert np.diagflat([1, 2], -1).tolist() == [[0, 0, 0], [1, 0, 0], [0, 2, 0]], 'diagflat negative k' + +ix_row, ix_col = np.ix_([0, 2], [1, 3, 4]) +assert ix_row.shape == (2, 1), 'ix_ first shape' +assert ix_col.shape == (1, 3), 'ix_ second shape' +assert ix_row.tolist() == [[0], [2]], 'ix_ first values' +assert ix_col.tolist() == [[1, 3, 4]], 'ix_ second values' + +mask_row, mask_col = np.mask_indices(3, np.triu, 1) +assert mask_row.tolist() == [0, 0, 1], 'mask_indices upper rows' +assert mask_col.tolist() == [1, 2, 2], 'mask_indices upper cols' +lower_row, lower_col = np.mask_indices(3, np.tril, -1) +assert lower_row.tolist() == [1, 2, 2], 'mask_indices lower rows' +assert lower_col.tolist() == [0, 0, 1], 'mask_indices lower cols' + +memory_arr = np.array([[1, 2], [3, 4]]) +memory_alias = memory_arr +memory_copy = memory_arr.copy() +assert np.isfortran(memory_arr) == False, 'isfortran row-major array' +assert np.shares_memory(memory_arr, memory_alias) == True, 'shares_memory same ndarray ref' +assert np.shares_memory(memory_arr, memory_copy) == False, 'shares_memory copied ndarray' +assert np.may_share_memory(memory_arr, memory_alias) == True, 'may_share_memory same ndarray ref' +assert np.may_share_memory(memory_arr, memory_copy) == False, 'may_share_memory copied ndarray' + + +# === module-level manipulation wrappers === +matrix = np.array([[1, 2, 3], [4, 5, 6]]) +assert np.take(matrix, [0, -1, 2]).tolist() == [1, 6, 3], 'take flattened indices' +assert np.take(matrix, 2) == 3, 'take flattened scalar index' +assert np.compress([1, 0, 1, 0, 0, 1], matrix).tolist() == [1, 3, 6], 'compress flattened condition' +assert np.swapaxes(matrix, 0, 1).tolist() == [[1, 4], [2, 5], [3, 6]], 'swapaxes 2d' +assert np.swapaxes(matrix, -1, -2).tolist() == [[1, 4], [2, 5], [3, 6]], 'swapaxes negative axes' +assert np.permute_dims(matrix).tolist() == [[1, 4], [2, 5], [3, 6]], 'permute_dims default' +assert np.permute_dims(matrix, (0, 1)).tolist() == [[1, 2, 3], [4, 5, 6]], 'permute_dims identity' +assert np.matrix_transpose(matrix).tolist() == [[1, 4], [2, 5], [3, 6]], 'matrix_transpose 2d' +try: + np.matrix_transpose(np.array([1, 2, 3])) + assert False, 'expected matrix_transpose to reject 1d input' +except ValueError as exc: + assert str(exc) == 'Input array must be at least 2-dimensional, but it is 1', 'matrix_transpose 1d error' + +assert np.rot90(matrix).tolist() == [[3, 6], [2, 5], [1, 4]], 'rot90 one turn' +assert np.rot90(matrix, 2).tolist() == [[6, 5, 4], [3, 2, 1]], 'rot90 two turns' +assert np.rot90(matrix, -1).tolist() == [[4, 1], [5, 2], [6, 3]], 'rot90 negative turn' + +cube = np.arange(24).reshape(2, 3, 4) +moved = np.moveaxis(cube, 0, 2) +assert moved.shape == (3, 4, 2), 'moveaxis shape' +assert moved.tolist()[0][0] == [0, 12], 'moveaxis first vector' +assert moved.tolist()[2][3] == [11, 23], 'moveaxis last vector' +rolled = np.rollaxis(cube, 2, 1) +assert rolled.shape == (2, 4, 3), 'rollaxis shape' +assert rolled.tolist()[0][0] == [0, 4, 8], 'rollaxis first vector' +assert rolled.tolist()[1][3] == [15, 19, 23], 'rollaxis last vector' + +dstack_1d = np.dstack(([1, 2], [3, 4])) +assert dstack_1d.shape == (1, 2, 2), 'dstack 1d shape' +assert dstack_1d.tolist() == [[[1, 3], [2, 4]]], 'dstack 1d values' +dstack_2d = np.dstack(([[1, 2], [3, 4]], [[5, 6], [7, 8]])) +assert dstack_2d.shape == (2, 2, 2), 'dstack 2d shape' +assert dstack_2d.tolist() == [[[1, 5], [2, 6]], [[3, 7], [4, 8]]], 'dstack 2d values' +block_scalar = np.block(1) +assert block_scalar.shape == (), 'block scalar shape' +assert block_scalar.tolist() == 1, 'block scalar value' +assert np.block(np.array([1, 2])).tolist() == [1, 2], 'block array leaf' +assert np.block([1, 2, 3]).tolist() == [1, 2, 3], 'block flat scalars' +assert np.block([[1, 2], [3, 4]]).tolist() == [[1, 2], [3, 4]], 'block nested scalars' +assert np.block([np.array([1, 2]), np.array([3, 4])]).tolist() == [1, 2, 3, 4], 'block flat arrays' +assert np.block([np.array([[1, 2]]), np.array([[3, 4]])]).tolist() == [[1, 2, 3, 4]], 'block flat matrices' +assert np.block( + [ + [np.array([[1, 2]]), np.array([[3]])], + [np.array([[4, 5]]), np.array([[6]])], + ] +).tolist() == [[1, 2, 3], [4, 5, 6]], 'block matrix assembly' +try: + np.block((1, 2)) + assert False, 'expected tuple block layout to fail' +except TypeError as exc: + assert str(exc) == ( + 'arrays is a tuple. Only lists can be used to arrange blocks, and np.block does not allow implicit conversion ' + 'from tuple to ndarray.' + ), 'block tuple layout error' +try: + np.block([1, [2, 3]]) + assert False, 'expected mismatched block depth to fail' +except ValueError as exc: + assert str(exc) == ( + 'List depths are mismatched. First element was at depth 1, but there is an element at depth 2 (arrays[1][0])' + ), 'block depth mismatch error' + +depth_parts = np.dsplit(cube, 2) +assert len(depth_parts) == 2, 'dsplit equal section count' +assert depth_parts[0].shape == (2, 3, 2), 'dsplit first equal shape' +assert depth_parts[0].tolist()[0][0] == [0, 1], 'dsplit first equal values' +assert depth_parts[1].tolist()[1][2] == [22, 23], 'dsplit second equal values' +depth_index_parts = np.dsplit(cube, [1, 3]) +assert len(depth_index_parts) == 3, 'dsplit index section count' +assert depth_index_parts[0].shape == (2, 3, 1), 'dsplit first index shape' +assert depth_index_parts[1].tolist()[0][1] == [5, 6], 'dsplit middle index values' +assert depth_index_parts[2].tolist()[1][2] == [23], 'dsplit final index values' + +unstack_row0, unstack_row1 = np.unstack(matrix) +assert unstack_row0.tolist() == [1, 2, 3], 'unstack axis0 first row' +assert unstack_row1.tolist() == [4, 5, 6], 'unstack axis0 second row' +unstack_col0, unstack_col1, unstack_col2 = np.unstack(matrix, 1) +assert unstack_col0.tolist() == [1, 4], 'unstack axis1 first column' +assert unstack_col1.tolist() == [2, 5], 'unstack axis1 second column' +assert unstack_col2.tolist() == [3, 6], 'unstack axis1 third column' +unstack_scalar0, unstack_scalar1, unstack_scalar2 = np.unstack(np.array([1, 2, 3])) +assert unstack_scalar0 == 1, 'unstack 1d first scalar' +assert unstack_scalar1 == 2, 'unstack 1d second scalar' +assert unstack_scalar2 == 3, 'unstack 1d third scalar' + +diag_mut = np.arange(9).reshape(3, 3) +assert np.fill_diagonal(diag_mut, 7) is None, 'fill_diagonal return' +assert diag_mut.tolist() == [[7, 1, 2], [3, 7, 5], [6, 7, 7]], 'fill_diagonal 2d values' + +put_mut = np.array([0, 1, 2, 3, 4]) +assert np.put(put_mut, [0, -1, 2], [10, 20]) is None, 'put return' +assert put_mut.tolist() == [10, 1, 10, 3, 20], 'put cycles values by index list' + +copy_mut = np.array([0, 1, 2]) +assert np.copyto(copy_mut, 5) is None, 'copyto scalar return' +assert copy_mut.tolist() == [5, 5, 5], 'copyto scalar broadcast' +copy_where = np.array([0, 1, 2]) +assert np.copyto(copy_where, [7, 8, 9], where=[True, False, True]) is None, 'copyto where return' +assert copy_where.tolist() == [7, 1, 9], 'copyto where mask' + +putmask_mut = np.array([0, 1, 2, 3, 4]) +assert np.putmask(putmask_mut, [True, False, True, False, True], [9, 8]) is None, 'putmask return' +assert putmask_mut.tolist() == [9, 1, 9, 3, 9], 'putmask uses flat index cycling' + +place_mut = np.array([0, 1, 2, 3, 4]) +assert np.place(place_mut, [True, False, True, False, True], [5, 6]) is None, 'place return' +assert place_mut.tolist() == [5, 1, 6, 3, 5], 'place uses selected-position cycling' + + +# === linear algebra and numeric wrappers === +assert np.vecdot(np.array([1, 2, 3]), np.array([4, 5, 6])) == 32, 'vecdot 1d' +assert np.matvec(np.array([[1, 2], [3, 4]]), np.array([10, 20])).tolist() == [50, 110], 'matvec 2d 1d' +assert np.vecmat(np.array([10, 20]), np.array([[1, 2], [3, 4]])).tolist() == [70, 100], 'vecmat 1d 2d' +assert np.tensordot(np.array([1, 2]), np.array([3, 4]), axes=0).tolist() == [[3, 4], [6, 8]], 'tensordot outer' +assert np.tensordot(np.array([1, 2]), np.array([3, 4]), axes=1) == 11, 'tensordot vector scalar' +assert np.tensordot(np.array([[1, 2], [3, 4]]), np.array([10, 20]), axes=1).tolist() == [ + 50, + 110, +], 'tensordot matrix vector' +assert np.tensordot( + np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), + np.array([[10, 20], [30, 40]]), + ([1, 2], [0, 1]), +).tolist() == [300, 700], 'tensordot explicit axes' +assert np.einsum('i,i->', np.array([1, 2]), np.array([3, 4])) == 11, 'einsum dot' +assert np.einsum('ij,j->i', np.array([[1, 2], [3, 4]]), np.array([10, 20])).tolist() == [ + 50, + 110, +], 'einsum matrix vector' +assert np.einsum('ij,jk->ik', np.array([[1, 2], [3, 4]]), np.array([[10, 20], [30, 40]])).tolist() == [ + [70, 100], + [150, 220], +], 'einsum matrix multiply' +assert np.einsum('ij->ji', np.array([[1, 2], [3, 4]])).tolist() == [[1, 3], [2, 4]], 'einsum transpose' +assert np.einsum('ii->i', np.array([[1, 2], [3, 4]])).tolist() == [1, 4], 'einsum diagonal' +assert np.einsum('i,j->ij', np.array([1, 2]), np.array([3, 4])).tolist() == [[3, 4], [6, 8]], 'einsum outer' +einsum_path, einsum_path_details = np.einsum_path('ij,jk->ik', np.array([[1, 2]]), np.array([[3], [4]])) +assert einsum_path == ['einsum_path', (0, 1)], 'einsum_path simple path' +assert len(einsum_path_details) > 0, 'einsum_path details are nonempty' +assert np.trapezoid(np.array([1, 2, 3])) == 4.0, 'trapezoid unit spacing' +assert np.trapezoid(np.array([1, 2, 3]), np.array([0, 1, 3])) == 6.5, 'trapezoid x coordinates' +assert np.trapezoid(np.array([1, 2, 3]), None, 2.0) == 8.0, 'trapezoid dx' +assert np.vander(np.array([1, 2, 3])).tolist() == [[1, 1, 1], [4, 2, 1], [9, 3, 1]], 'vander default' +assert np.vander(np.array([1, 2, 3]), 2).tolist() == [[1, 1], [2, 1], [3, 1]], 'vander n' +assert np.vander(np.array([1, 2, 3]), 3, True).tolist() == [ + [1, 1, 1], + [1, 2, 4], + [1, 3, 9], +], 'vander increasing' +assert np.polyadd([1, 2, 3], [10, 20]).tolist() == [1, 12, 23], 'polyadd aligns coefficients' +assert np.polysub([1, 2, 3], [10, 20]).tolist() == [1, -8, -17], 'polysub aligns coefficients' +assert np.polymul([1, 2], [3, 4]).tolist() == [3, 10, 8], 'polymul convolution' +assert np.poly([1, 2, 3]).tolist() == [1.0, -6.0, 11.0, -6.0], 'poly roots to coefficients' +polydiv_exact_q, polydiv_exact_r = np.polydiv([1, 0, -1], [1, -1]) +assert polydiv_exact_q.tolist() == [1.0, 1.0], 'polydiv exact quotient' +assert polydiv_exact_r.tolist() == [0.0], 'polydiv exact remainder' +polydiv_rem_q, polydiv_rem_r = np.polydiv([1, 2, 3], [1, 1]) +assert polydiv_rem_q.tolist() == [1.0, 1.0], 'polydiv quotient with remainder' +assert polydiv_rem_r.tolist() == [2.0], 'polydiv nonzero remainder' +assert np.polyval([1, 0, -1], 2) == 3, 'polyval scalar' +assert np.polyval([1, 0, -1], [0, 1, 2]).tolist() == [-1, 0, 3], 'polyval array' +assert np.polyder([1, 2, 3]).tolist() == [2, 2], 'polyder first derivative' +assert np.polyder([1, 2, 3], 2).tolist() == [2], 'polyder second derivative' +assert np.polyint([2, 2]).tolist() == [1.0, 2.0, 0.0], 'polyint first integral' +assert np.polyint([2, 2], 2).tolist() == [1 / 3, 1.0, 0.0, 0.0], 'polyint second integral' +assert np.kron([1, 2], [10, 20, 30]).tolist() == [10, 20, 30, 20, 40, 60], 'kron 1d' +assert np.kron([[1, 2], [3, 4]], [[0, 5], [6, 7]]).tolist() == [ + [0, 5, 0, 10], + [6, 7, 12, 14], + [0, 15, 0, 20], + [18, 21, 24, 28], +], 'kron 2d' +assert np.cov([1, 2, 3]) == 1.0, 'cov 1d' +assert np.cov([[1, 2, 3], [2, 4, 6]]).tolist() == [[1.0, 2.0], [2.0, 4.0]], 'cov 2d rows' +assert np.corrcoef([1, 2, 3]) == 1.0, 'corrcoef 1d' +assert np.corrcoef([[1, 2, 3], [2, 4, 6]]).tolist() == [[1.0, 1.0], [1.0, 1.0]], 'corrcoef 2d rows' + +unique_input = np.array([3, 1, 3, 2, 1, 3]) +assert np.sort(np.unique_values(unique_input)).tolist() == [1, 2, 3], 'unique_values sorted contents' +unique_counts = np.unique_counts(unique_input) +assert unique_counts.values.tolist() == [1, 2, 3], 'unique_counts values' +assert unique_counts.counts.tolist() == [2, 1, 3], 'unique_counts counts' +unique_inverse = np.unique_inverse(unique_input) +assert unique_inverse.values.tolist() == [1, 2, 3], 'unique_inverse values' +assert unique_inverse.inverse_indices.tolist() == [2, 0, 2, 1, 0, 2], 'unique_inverse indices' +unique_all = np.unique_all(unique_input) +assert unique_all.values.tolist() == [1, 2, 3], 'unique_all values' +assert unique_all.indices.tolist() == [1, 3, 0], 'unique_all first indices' +assert unique_all.inverse_indices.tolist() == [2, 0, 2, 1, 0, 2], 'unique_all inverse indices' +assert unique_all.counts.tolist() == [2, 1, 3], 'unique_all counts' + +partition_input = np.array([3, 1, 2]) +assert np.partition(partition_input, 1).tolist() == [1, 2, 3], 'partition 1d deterministic sorted subset' +assert np.argpartition(partition_input, 1).tolist() == [1, 2, 0], 'argpartition 1d deterministic argsort subset' +partition_neg = np.partition(partition_input, -1).tolist() +assert partition_neg[-1] == 3, 'partition negative kth places max' +assert sorted(partition_neg) == [1, 2, 3], 'partition negative kth preserves values' +lex_row = np.lexsort(([2, 1, 2, 1], [0, 1, 0, 0])) +assert lex_row.tolist() == [3, 0, 2, 1], 'lexsort two keys' +assert np.lexsort(([3, 1, 2],)).tolist() == [1, 2, 0], 'lexsort one key' + + +# === integer and boolean bitwise helpers === +assert np.bitwise_and(6, 3) == 2, 'bitwise_and scalar' +assert np.bitwise_and(True, False) == False, 'bitwise_and bool scalar' +assert np.bitwise_and([1, 2, 3], [3, 1, 2]).tolist() == [1, 0, 2], 'bitwise_and lists' +bool_and = np.bitwise_and(np.array([True, False]), True) +assert bool_and.tolist() == [True, False], 'bitwise_and bool array' +assert str(bool_and.dtype) == 'bool', 'bitwise_and bool dtype' + +assert np.bitwise_or([1, 2, 4], 1).tolist() == [1, 3, 5], 'bitwise_or list scalar' +assert np.bitwise_xor(7, [1, 2, 4]).tolist() == [6, 5, 3], 'bitwise_xor scalar list' +assert np.bitwise_not([0, 1, -2]).tolist() == [-1, -2, 1], 'bitwise_not list' +assert np.bitwise_invert([0, 1]).tolist() == [-1, -2], 'bitwise_invert alias' +inverted_bools = np.invert(np.array([True, False])) +assert inverted_bools.tolist() == [False, True], 'invert bool array' +assert str(inverted_bools.dtype) == 'bool', 'invert bool dtype' + +assert np.left_shift([1, 2, 3], 2).tolist() == [4, 8, 12], 'left_shift list scalar' +assert np.bitwise_left_shift(1, 3) == 8, 'bitwise_left_shift alias' +assert np.right_shift([8, -8], 1).tolist() == [4, -4], 'right_shift list scalar' +assert np.bitwise_right_shift(-8, 1) == -4, 'bitwise_right_shift alias' +assert np.bitwise_count(7) == 3, 'bitwise_count scalar' +assert np.bitwise_count([-1, -2, -3]).tolist() == [1, 1, 2], 'bitwise_count negative list' + +packed = np.packbits([1, 0, 1, 1, 0, 0, 1, 0]) +assert packed.tolist() == [178], 'packbits byte' +assert np.unpackbits(packed).tolist() == [1, 0, 1, 1, 0, 0, 1, 0], 'unpackbits roundtrip' + + +# === integer representation helpers === +assert np.base_repr(10) == '1010', 'base_repr default base' +assert np.base_repr(-10) == '-1010', 'base_repr negative' +assert np.base_repr(10, 16) == 'A', 'base_repr hex' +assert np.base_repr(10, 2, 5) == '000001010', 'base_repr padding' +assert np.base_repr(0, 2, 5) == '00000', 'base_repr zero padding' + +assert np.binary_repr(3) == '11', 'binary_repr positive' +assert np.binary_repr(-3) == '-11', 'binary_repr negative no width' +assert np.binary_repr(3, 5) == '00011', 'binary_repr positive width' +assert np.binary_repr(-3, 5) == '11101', 'binary_repr negative width' + + +# === finite conversion, predicates, and simple 1d helpers === +assert np.isposinf([np.inf, -np.inf, 1.0]).tolist() == [True, False, False], 'isposinf values' +assert np.isneginf([np.inf, -np.inf, 1.0]).tolist() == [False, True, False], 'isneginf values' +assert np.asarray_chkfinite([1, 2, 3]).tolist() == [1, 2, 3], 'asarray_chkfinite finite' +try: + np.asarray_chkfinite([1.0, np.inf]) + assert False, 'expected asarray_chkfinite to reject infinity' +except ValueError as exc: + assert str(exc) == 'array must not contain infs or NaNs', 'asarray_chkfinite error message' + +assert np.ascontiguousarray([1, 2]).tolist() == [1, 2], 'ascontiguousarray list' +assert np.asfortranarray([1, 2]).tolist() == [1, 2], 'asfortranarray list' +assert np.require([1, 2]).tolist() == [1, 2], 'require list' +assert np.real_if_close([1, 2]).tolist() == [1, 2], 'real_if_close list' + +assert np.array_equiv([1, 2], [1, 2]) == True, 'array_equiv equal arrays' +assert np.array_equiv([1, 1], 1) == True, 'array_equiv scalar broadcast' +assert np.array_equiv([1, 2], 1) == False, 'array_equiv scalar mismatch' +assert np.ediff1d([[1, 2], [4, 7]]).tolist() == [1, 2, 3], 'ediff1d flattened' +assert np.trim_zeros([0, 0, 1, 0, 2, 0]).tolist() == [1, 0, 2], 'trim_zeros both' +assert np.trim_zeros([0, 0, 1, 0, 2, 0], 'f').tolist() == [1, 0, 2, 0], 'trim_zeros front' +assert np.unwrap([0.0, 1.0, 2.0]).tolist() == [0.0, 1.0, 2.0], 'unwrap no jump' +unwrapped_pos = np.unwrap([0.0, 3.5, 6.0]).tolist() +assert abs(unwrapped_pos[1] + 2.7831853071795862) < 1e-12, 'unwrap positive jump first' +assert abs(unwrapped_pos[2] + 0.28318530717958623) < 1e-12, 'unwrap positive jump second' +unwrapped_neg = np.unwrap([0.0, -3.5, -6.0]).tolist() +assert abs(unwrapped_neg[1] - 2.7831853071795862) < 1e-12, 'unwrap negative jump first' +assert abs(unwrapped_neg[2] - 0.28318530717958623) < 1e-12, 'unwrap negative jump second' + + +# === real-only aliases and introspection helpers === +real_values = np.array([-2, 0, 3]) +assert np.conj(-5) == -5, 'conj scalar keeps real value' +assert np.conj(real_values).tolist() == [-2, 0, 3], 'conj array keeps real values' +assert np.conjugate([1.5, -2.5]).tolist() == [1.5, -2.5], 'conjugate list converts to array' +assert np.real(-4.5) == -4.5, 'real scalar keeps value' +assert np.real(real_values).tolist() == [-2, 0, 3], 'real array keeps values' +assert np.imag(7) == 0, 'imag int scalar is zero' +assert np.imag(1.25) == 0.0, 'imag float scalar is zero' +assert np.imag(real_values).tolist() == [0, 0, 0], 'imag int array is zeros' +assert np.imag([1.5, -2.5]).tolist() == [0.0, 0.0], 'imag float list is zeros' +assert np.isreal(3) == True, 'isreal scalar true' +assert np.isreal(real_values).tolist() == [True, True, True], 'isreal array all true' +assert np.iscomplex(3) == False, 'iscomplex scalar false' +assert np.iscomplex(real_values).tolist() == [False, False, False], 'iscomplex array all false' +assert np.isrealobj(real_values) == True, 'isrealobj array true' +assert np.isrealobj('text') == True, 'isrealobj string true' +assert np.iscomplexobj(real_values) == False, 'iscomplexobj array false' +assert np.iscomplexobj('text') == False, 'iscomplexobj string false' +assert np.isscalar(1) == True, 'isscalar int true' +assert np.isscalar('text') == True, 'isscalar string true' +assert np.isscalar(np.array([1])) == False, 'isscalar ndarray false' +assert np.isscalar([1]) == False, 'isscalar list false' +assert np.iterable([1, 2]) == True, 'iterable list true' +assert np.iterable((1, 2)) == True, 'iterable tuple true' +assert np.iterable('text') == True, 'iterable string true' +assert np.iterable(np.array([1, 2])) == True, 'iterable ndarray true' +assert np.iterable(1) == False, 'iterable int false' + + +# === dtype aliases and scalar constants === +assert np.array([1.2, -2.8]).astype(np.int_).tolist() == [1, -2], 'int_ dtype alias' +assert np.array([1.2, -2.8]).astype(np.intc).tolist() == [1, -2], 'intc dtype alias' +assert np.array([1.2, 2.8]).astype(np.uint8).tolist() == [1, 2], 'uint8 dtype alias' +assert np.array([1, 0, -2]).astype(np.bool).tolist() == [True, False, True], 'bool dtype alias' +assert np.array([1, 2]).astype(np.double).tolist() == [1.0, 2.0], 'double dtype alias' +assert np.array([1, 2]).astype(np.float16).tolist() == [1.0, 2.0], 'float16 dtype alias' +assert np.array([1, 2]).astype(np.longdouble).tolist() == [1.0, 2.0], 'longdouble dtype alias' +assert np.array([1, 2]).astype('uint8').tolist() == [1, 2], 'astype string uint alias' +assert np.dtype('float64') == np.float64, 'dtype normalizes float64 string' +assert np.dtype(np.float32) == np.float32, 'dtype preserves compact float32 marker' +assert np.dtype(int) == np.int64, 'dtype normalizes Python int type' +assert np.dtype(float) == np.float64, 'dtype normalizes Python float type' +assert np.dtype(bool) == np.bool_, 'dtype normalizes Python bool type' +assert str(np.dtype('int64')) == 'int64', 'dtype string display' +assert np.dtype(np.complex64) == np.complex64, 'dtype preserves complex64 metadata marker' +assert np.dtype(np.csingle) == np.complex64, 'dtype normalizes csingle metadata alias' +assert np.dtype(np.cdouble) == np.complex128, 'dtype normalizes cdouble metadata alias' +assert np.dtype(np.clongdouble) == np.complex128, 'dtype normalizes clongdouble metadata alias' +assert np.dtype(np.str_) == np.str_, 'dtype preserves str_ metadata marker' +assert np.dtype(np.bytes_) == np.bytes_, 'dtype preserves bytes_ metadata marker' +assert np.dtype(np.void) == np.void, 'dtype preserves void metadata marker' +assert np.dtype(np.object_) == np.object_, 'dtype preserves object_ metadata marker' +assert np.dtype(np.datetime64) == np.datetime64, 'dtype preserves datetime64 metadata marker' +assert np.dtype(np.timedelta64) == np.timedelta64, 'dtype preserves timedelta64 metadata marker' +try: + np.astype(np.array([1]), np.complex64) + assert False, 'expected module astype complex64 to fail' +except TypeError as exc: + assert str(exc) == 'numpy.astype() unsupported dtype: complex64', 'module astype rejects complex storage' +try: + np.array([1]).astype(np.str_) + assert False, 'expected ndarray astype str_ to fail' +except TypeError as exc: + assert str(exc) == 'unsupported dtype: str_', 'ndarray astype rejects string storage' +assert np.astype(np.array([1.2, -2.8]), np.int64).tolist() == [1, -2], 'module astype int dtype' +assert np.astype(np.array([1, 0, -2]), bool).tolist() == [True, False, True], 'module astype bool dtype' +assert np.astype(np.array([1, 2]), np.float64).tolist() == [1.0, 2.0], 'module astype float dtype' +assert isinstance(np.array([1, 2]), np.ndarray) == True, 'ndarray type object matches arrays' +assert np.ndarray.__name__ == 'ndarray', 'ndarray type object name' +assert repr(np.ndarray) == "", 'ndarray type object repr' +flat_marker = np.array([[1, 2], [3, 4]]).flat +assert isinstance(flat_marker, np.flatiter) == True, 'flatiter marker matches ndarray flat result' +assert isinstance(np.array([1, 2, 3]), np.flatiter) == False, 'plain ndarray is not flatiter' +assert np.flatiter.__name__ == 'flatiter', 'flatiter type object name' +assert type(flat_marker).__name__ == 'flatiter', 'flatiter result type name' +assert flat_marker.tolist() == [1, 2, 3, 4], 'flatiter marker backed by flat data' +assert list(flat_marker) == [1, 2, 3, 4], 'flatiter iteration follows flat data' +assert isinstance(np.add, np.ufunc) == True, 'add is public ufunc-like callable' +assert isinstance(np.sin, np.ufunc) == True, 'sin is public ufunc-like callable' +assert isinstance(np.array, np.ufunc) == False, 'array constructor is not a ufunc' +assert type(np.add).__name__ == 'ufunc', 'ufunc callable type name' +assert np.ufunc.__name__ == 'ufunc', 'ufunc type object name' +assert np.can_cast(np.int8, np.int16) == True, 'can_cast integer aliases' +assert np.can_cast(np.float64, np.int64) == False, 'can_cast float to int is unsafe' +assert np.promote_types(np.int8, np.float32) == np.float32, 'promote_types int float32' +assert np.result_type(np.array([1, 2]), 1.5) == np.float64, 'result_type array scalar' +assert np.common_type(np.array([1, 2]), np.array([3, 4])) == np.float64, 'common_type int arrays' +assert np.min_scalar_type(-3) == np.int8, 'min_scalar_type negative int alias' +assert np.min_scalar_type(3) == np.uint8, 'min_scalar_type positive int alias' +assert np.issubdtype(np.int64, np.integer) == True, 'issubdtype integer category' +assert np.issubdtype(np.float64, np.floating) == True, 'issubdtype floating category' +assert np.issubdtype(np.float64, np.inexact) == True, 'issubdtype inexact category' +assert np.issubdtype(np.int64, np.inexact) == False, 'issubdtype int not inexact' +assert np.issubdtype(np.bool_, np.integer) == False, 'issubdtype bool not integer' +assert np.issubdtype(np.complex64, np.complexfloating) == True, 'complex64 is complexfloating' +assert np.issubdtype(np.complex128, np.inexact) == True, 'complex128 is inexact' +assert np.issubdtype(np.complex128, np.number) == True, 'complex128 is number' +assert np.issubdtype(np.str_, np.character) == True, 'str_ is character' +assert np.issubdtype(np.bytes_, np.character) == True, 'bytes_ is character' +assert np.issubdtype(np.void, np.flexible) == True, 'void is flexible' +assert np.issubdtype(np.object_, np.generic) == True, 'object_ is generic' +assert np.issubdtype(np.datetime64, np.generic) == True, 'datetime64 is generic' +assert np.issubdtype(np.timedelta64, np.number) == True, 'timedelta64 is number' +assert np.isdtype(np.float64, 'real floating') == True, 'isdtype real floating' +assert np.isdtype(np.int64, 'integral') == True, 'isdtype integral' +assert np.isdtype(np.int64, 'numeric') == True, 'isdtype numeric int' +assert np.isdtype(np.bool_, 'numeric') == False, 'isdtype numeric excludes bool' +assert np.isdtype(np.complex64, 'complex floating') == True, 'isdtype complex floating' +assert np.isdtype(np.complex64, 'numeric') == True, 'isdtype numeric complex' +assert np.isdtype(np.str_, 'numeric') == False, 'isdtype string not numeric' +assert np.can_cast(np.complex64, np.complex128) == True, 'can_cast complex64 to complex128' +assert np.can_cast(np.complex64, np.float64) == False, 'can_cast complex to float is unsafe' +assert np.can_cast(np.str_, np.object_) == True, 'can_cast str to object' +assert np.can_cast(np.void, np.str_) == False, 'can_cast void to str false' +assert np.promote_types(np.complex64, np.float32) == np.complex64, 'promote complex64 float32' +assert np.promote_types(np.complex64, np.int64) == np.complex128, 'promote complex64 int64' +assert np.result_type(np.complex64, np.float32) == np.complex64, 'result_type complex64 float32' +assert np.result_type(np.object_, np.int64) == np.object_, 'result_type object int64' +assert isinstance(1, np.ScalarType) == True, 'ScalarType includes int' +assert isinstance(1.5, np.ScalarType) == True, 'ScalarType includes float' +assert isinstance('x', np.ScalarType) == True, 'ScalarType includes str' +assert isinstance([], np.ScalarType) == False, 'ScalarType excludes list' +float64_info = np.finfo(np.float64) +assert float64_info.bits == 64, 'finfo float64 bits' +assert float64_info.eps == 2.220446049250313e-16, 'finfo float64 eps' +assert float64_info.tiny == 2.2250738585072014e-308, 'finfo float64 tiny' +assert str(float64_info.dtype) == 'float64', 'finfo float64 dtype' +float32_info = np.finfo('float32') +assert float32_info.bits == 32, 'finfo float32 bits' +assert float32_info.precision == 6, 'finfo float32 precision' +assert float32_info.resolution == 1e-06, 'finfo float32 resolution' +assert np.finfo('float16').tiny == 6.103515625e-05, 'finfo float16 tiny' +int16_info = np.iinfo('int16') +assert int16_info.bits == 16, 'iinfo int16 bits' +assert int16_info.min == -32768, 'iinfo int16 min' +assert int16_info.max == 32767, 'iinfo int16 max' +assert str(int16_info.dtype) == 'int16', 'iinfo int16 dtype' +assert np.iinfo('uint8').max == 255, 'iinfo uint8 max' +assert np.iinfo('uint64').max == 18446744073709551615, 'iinfo uint64 max' +assert np.iinfo(1).bits == 64, 'iinfo scalar int bits' + + +# === float formatting helpers === +assert np.format_float_positional(1.0) == '1.', 'format_float_positional whole float' +assert np.format_float_positional(1.2, precision=2, unique=False) == '1.20', 'format_float_positional fixed precision' +assert np.format_float_positional(1.2, sign=True) == '+1.2', 'format_float_positional sign' +assert np.format_float_positional(1.2, min_digits=4) == '1.2000', 'format_float_positional min digits' +assert np.format_float_scientific(1.0) == '1.e+00', 'format_float_scientific whole float' +assert np.format_float_scientific(1000.0) == '1.e+03', 'format_float_scientific positive exponent' +assert np.format_float_scientific(0.0001234) == '1.234e-04', 'format_float_scientific negative exponent' +assert np.format_float_scientific(1.2, precision=2, unique=False) == '1.20e+00', ( + 'format_float_scientific fixed precision' +) + + +def fromfunction_sum(row, col): + return row + col + + +def fromfunction_linear(row, col): + return row * 10 + col + + +def fromfunction_scale(index, scale=1): + return index * scale + + +assert np.fromfunction(fromfunction_sum, (2, 3)).tolist() == [ + [0.0, 1.0, 2.0], + [1.0, 2.0, 3.0], +], 'fromfunction float coordinate sum' +assert np.fromfunction(fromfunction_linear, (2, 3), dtype=int).tolist() == [ + [0, 1, 2], + [10, 11, 12], +], 'fromfunction integer coordinate expression' +assert np.fromfunction(fromfunction_scale, (4,), dtype=int, scale=3).tolist() == [ + 0, + 3, + 6, + 9, +], 'fromfunction forwards callable kwargs' +assert np.fromfunction(fromfunction_scale, (0,), dtype=int).tolist() == [], 'fromfunction empty shape' +assert np.fromiter([1, 2, 3], int).tolist() == [1, 2, 3], 'fromiter int list' +assert np.fromiter([1, 2.5, 3], float).tolist() == [1.0, 2.5, 3.0], 'fromiter float list' +assert np.fromiter([1, 2, 3], np.int64, count=2).tolist() == [1, 2], 'fromiter count keyword' +assert np.fromiter([True, False], bool).tolist() == [True, False], 'fromiter bool dtype' +assert np.fromiter([1, 2], None).tolist() == [1.0, 2.0], 'fromiter dtype none default' +assert np.fromstring('1 2 3', dtype=int, sep=' ').tolist() == [1, 2, 3], 'fromstring int whitespace' +assert np.fromstring('1, 2, 3', dtype=float, sep=',').tolist() == [ + 1.0, + 2.0, + 3.0, +], 'fromstring float comma' +assert np.fromstring('1,2,3', dtype=np.int64, count=2, sep=',').tolist() == [1, 2], 'fromstring count' +assert np.fromstring('1 0 2', dtype=bool, sep=' ').tolist() == [True, False, True], 'fromstring bool' +assert np.mintypecode(['i', 'f']) == 'f', 'mintypecode int float' +assert np.typename('i') == 'integer', 'typename integer code' +assert np.typename('d') == 'double precision', 'typename double code' +assert np.typecodes['Float'] == 'efdg', 'typecodes float family' +assert np.sctypeDict['float64'] == np.float64, 'sctypeDict float64 alias' +assert np.sctypeDict['int32'] == np.int32, 'sctypeDict int32 alias' +assert np.sctypeDict['bool'] == np.bool_, 'sctypeDict bool alias' +assert np.sctypeDict['complex64'] == np.complex64, 'sctypeDict complex64 alias' +assert np.sctypeDict['str_'] == np.str_, 'sctypeDict str_ alias' +assert np.sctypeDict['object_'] == np.object_, 'sctypeDict object_ alias' +assert np.info is not None, 'info export present' +err_policy = np.geterr() +assert err_policy['divide'] == 'warn', 'geterr divide policy' +assert err_policy['under'] == 'ignore', 'geterr under policy' +seterr_previous = np.seterr(divide='ignore') +assert seterr_previous['divide'] == 'warn', 'seterr returns previous policy' +np.seterr(divide='warn') +print_options = np.get_printoptions() +assert print_options['threshold'] == 1000, 'get_printoptions threshold' +assert print_options['precision'] == 8, 'get_printoptions precision' +assert np.set_printoptions(threshold=10) is None, 'set_printoptions return' +np.set_printoptions(threshold=1000) +assert np.getbufsize() == 8192, 'getbufsize default' +assert np.setbufsize(8192) == 8192, 'setbufsize previous size' +assert np.errstate(divide='ignore') is not None, 'errstate placeholder' +assert np.printoptions(threshold=10) is not None, 'printoptions placeholder' +assert np.geterrcall() is None, 'geterrcall default' +assert np.seterrcall(None) is None, 'seterrcall previous callback' +assert np.geterrcall() is None, 'geterrcall after reset' +assert np.show_runtime is not None, 'show_runtime export present' +assert np.test is not None, 'test export present' +display_int = np.array([1, 2, 3]) +assert np.array2string(display_int) == '[1 2 3]', 'array2string int vector' +assert np.array_str(display_int) == '[1 2 3]', 'array_str int vector' +assert np.array_repr(display_int) == 'array([1, 2, 3])', 'array_repr int vector' +display_float = np.array([1.0, 2.0, 3.0]) +assert np.array2string(display_float) == '[1. 2. 3.]', 'array2string float vector' +assert np.array_str(display_float) == '[1. 2. 3.]', 'array_str float vector' +display_bool = np.array([True, False]) +assert np.array2string(display_bool) == '[ True False]', 'array2string bool vector' +display_matrix = np.array([[1, 2], [3, 4]]) +assert np.array2string(display_matrix) == '[[1 2]\n [3 4]]', 'array2string matrix' +assert np.array_str(display_matrix) == '[[1 2]\n [3 4]]', 'array_str matrix' +display_empty = np.array([]) +assert np.array2string(display_empty) == '[]', 'array2string empty' +assert np.array_str(display_empty) == '[]', 'array_str empty' +assert np.array_repr(display_empty) == 'array([], dtype=float64)', 'array_repr empty' +choose_idx = np.array([0, 1, 0, 1]) +assert np.choose(choose_idx, [[10, 20, 30, 40], [50, 60, 70, 80]]).tolist() == [ + 10, + 60, + 30, + 80, +], 'choose vector' +assert np.resize([1, 2, 3], (2, 4)).tolist() == [[1, 2, 3, 1], [2, 3, 1, 2]], 'resize repeats' +take_axis_arr = np.array([[10, 20, 30], [40, 50, 60]]) +take_axis_idx = np.array([[2, 1], [0, 2]]) +assert np.take_along_axis(take_axis_arr, take_axis_idx, axis=1).tolist() == [ + [30, 20], + [40, 60], +], 'take_along_axis axis 1' +assert np.take_along_axis(take_axis_arr, np.array([[1, 0, 1]]), axis=0).tolist() == [[40, 20, 60]], ( + 'take_along_axis axis 0' +) +put_axis_arr = np.array([[10, 20, 30], [40, 50, 60]]) +assert np.put_along_axis(put_axis_arr, take_axis_idx, [[99, 88], [77, 66]], axis=1) is None, 'put_along_axis return' +assert put_axis_arr.tolist() == [[10, 88, 99], [77, 50, 66]], 'put_along_axis axis 1' + + +# === axis and subarray helpers === +axis_arr = np.array([[1, 2, 3], [4, 5, 6]]) + + +def apply_row_summary(row): + return np.array([row[0], row[-1], row.sum()]) + + +def apply_column_product(col): + return col[0] * col[-1] + + +def apply_offset(row, offset=0): + return row + offset + + +assert np.apply_along_axis(apply_row_summary, 1, axis_arr).tolist() == [ + [1, 3, 6], + [4, 6, 15], +], 'apply_along_axis inserts vector result at iterated axis' +assert np.apply_along_axis(apply_column_product, 0, axis_arr).tolist() == [4, 10, 18], 'apply_along_axis scalar result' +assert np.apply_along_axis(apply_offset, -1, axis_arr, offset=10).tolist() == [ + [11, 12, 13], + [14, 15, 16], +], 'apply_along_axis forwards callable kwargs' +assert np.apply_over_axes(np.sum, axis_arr, [0]).tolist() == [[5, 7, 9]], 'apply_over_axes sum axis 0' +assert np.apply_over_axes(np.sum, axis_arr, [0, 1]).tolist() == [[21]], 'apply_over_axes sum axes 0 and 1' +assert np.apply_over_axes(np.prod, axis_arr, (1,)).tolist() == [[6], [120]], 'apply_over_axes prod axis 1' +piecewise_x = np.array([0, 1, 2, 3]) +assert np.piecewise( + piecewise_x, + [np.array([True, False, False, False]), np.array([False, True, True, False])], + [10, lambda x: x + 20, 99], +).tolist() == [10, 21, 22, 99], 'piecewise scalar callable and default' +assert np.piecewise(axis_arr, [axis_arr > 3], [lambda x: x * 10, -1]).tolist() == [ + [-1, -1, -1], + [40, 50, 60], +], 'piecewise broadcasts condition over array' +assert np.piecewise( + np.array([1, 2, 3]), + [np.array([True, True, False]), np.array([True, False, True])], + [10, 20, 0], +).tolist() == [20, 10, 20], 'piecewise later conditions overwrite earlier matches' +assert np.pad(axis_arr, 1, mode='constant', constant_values=9).tolist() == [ + [9, 9, 9, 9, 9], + [9, 1, 2, 3, 9], + [9, 4, 5, 6, 9], + [9, 9, 9, 9, 9], +], 'pad constant scalar width and value' +assert np.pad(axis_arr, ((1, 0), (2, 1)), mode='constant', constant_values=((7, 8), (9, 10))).tolist() == [ + [9, 9, 7, 7, 7, 10], + [9, 9, 1, 2, 3, 10], + [9, 9, 4, 5, 6, 10], +], 'pad constant per-axis widths and values' +assert np.pad(axis_arr, ((1, 1), (2, 1)), mode='edge').tolist() == [ + [1, 1, 1, 2, 3, 3], + [1, 1, 1, 2, 3, 3], + [4, 4, 4, 5, 6, 6], + [4, 4, 4, 5, 6, 6], +], 'pad edge mode' +assert np.pad([1, 2, 3], 2, mode='reflect').tolist() == [3, 2, 1, 2, 3, 2, 1], 'pad reflect mode' +assert np.nanquantile([1.0, float('nan'), 3.0], 0.5) == 2.0, 'nanquantile median' +assert np.nanpercentile([1.0, float('nan'), 3.0], 50) == 2.0, 'nanpercentile median' +hist_counts, hist_edges = np.histogram([0, 1, 1, 2, 3], bins=3) +assert hist_counts.tolist() == [1, 2, 2], 'histogram counts' +assert hist_edges.tolist() == [0.0, 1.0, 2.0, 3.0], 'histogram edges' +assert np.histogram_bin_edges([0, 1, 1, 2, 3], bins=3).tolist() == [ + 0.0, + 1.0, + 2.0, + 3.0, +], 'histogram_bin_edges' +hist2d_counts, hist2d_xedges, hist2d_yedges = np.histogram2d([0, 1, 1, 2], [0, 0, 1, 1], bins=2) +assert hist2d_counts.tolist() == [[1.0, 0.0], [1.0, 2.0]], 'histogram2d counts' +assert hist2d_xedges.tolist() == [0.0, 1.0, 2.0], 'histogram2d xedges' +assert hist2d_yedges.tolist() == [0.0, 0.5, 1.0], 'histogram2d yedges' +histdd_counts, histdd_edges = np.histogramdd(np.array([[0, 0], [1, 0], [1, 1], [2, 1]]), bins=2) +assert histdd_counts.tolist() == [[1.0, 0.0], [1.0, 2.0]], 'histogramdd counts' +assert [edge.tolist() for edge in histdd_edges] == [ + [0.0, 1.0, 2.0], + [0.0, 0.5, 1.0], +], 'histogramdd edges' +assert np.little_endian == True, 'little_endian constant' +assert abs(np.euler_gamma - 0.5772156649015329) < 1e-15, 'euler_gamma constant' diff --git a/crates/monty/test_cases/numpy__arithmetic.py b/crates/monty/test_cases/numpy__arithmetic.py new file mode 100644 index 000000000..b0001623c --- /dev/null +++ b/crates/monty/test_cases/numpy__arithmetic.py @@ -0,0 +1,76 @@ +# skip-cpython +# === Element-wise arithmetic === +import numpy as np + +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) + +# === Addition === +c = a + b +assert c[0] == 5, 'add first' +assert c[1] == 7, 'add second' +assert c[2] == 9, 'add third' + +# === Subtraction === +d = b - a +assert d[0] == 3, 'sub first' +assert d[1] == 3, 'sub second' +assert d[2] == 3, 'sub third' + +# === Multiplication === +e = a * b +assert e[0] == 4, 'mul first' +assert e[1] == 10, 'mul second' +assert e[2] == 18, 'mul third' + +# === Division === +f = b / a +assert f[0] == 4.0, 'div first' +assert f[1] == 2.5, 'div second' +assert f[2] == 2.0, 'div third' + +# === Scalar operations === +g = a * 2 +assert g[0] == 2, 'scalar mul first' +assert g[1] == 4, 'scalar mul second' +assert g[2] == 6, 'scalar mul third' + +h = a + 10 +assert h[0] == 11, 'scalar add first' +assert h[1] == 12, 'scalar add second' +assert h[2] == 13, 'scalar add third' + +# === Power === +p = a**2 +assert p[0] == 1, 'pow first' +assert p[1] == 4, 'pow second' +assert p[2] == 9, 'pow third' + +# === Floor division === +fd = np.array([7, 8, 9]) // np.array([2, 3, 4]) +assert fd[0] == 3, 'floordiv first' +assert fd[1] == 2, 'floordiv second' +assert fd[2] == 2, 'floordiv third' + +# === Modulo === +m = np.array([7, 8, 9]) % np.array([2, 3, 4]) +assert m[0] == 1, 'mod first' +assert m[1] == 2, 'mod second' +assert m[2] == 1, 'mod third' + +# === Negation === +neg = -a +assert neg[0] == -1, 'neg first' +assert neg[1] == -2, 'neg second' +assert neg[2] == -3, 'neg third' + +# === Scalar subtraction === +sub_scalar = a - 1 +assert sub_scalar[0] == 0, 'scalar sub first' +assert sub_scalar[1] == 1, 'scalar sub second' +assert sub_scalar[2] == 2, 'scalar sub third' + +# === In-place addition (+=) === +iadd = np.array([1, 2, 3]) +iadd += np.array([10, 20, 30]) +assert iadd.tolist() == [11, 22, 33], 'iadd arrays' diff --git a/crates/monty/test_cases/numpy__array_creation.py b/crates/monty/test_cases/numpy__array_creation.py new file mode 100644 index 000000000..546405ae8 --- /dev/null +++ b/crates/monty/test_cases/numpy__array_creation.py @@ -0,0 +1,82 @@ +# skip-cpython +# === np.array from list === +import numpy as np + +a = np.array([1, 2, 3]) +assert len(a) == 3, 'array length' +assert a[0] == 1, 'first element' +assert a[1] == 2, 'second element' +assert a[2] == 3, 'third element' + +# === np.array from nested list (2D) === +b = np.array([[1, 2], [3, 4]]) +assert b[0][0] == 1, '2d first element' +assert b[0][1] == 2, '2d second element' +assert b[1][0] == 3, '2d third element' +assert b[1][1] == 4, '2d fourth element' + +# === np.zeros === +z = np.zeros(3) +assert len(z) == 3, 'zeros length' +assert z[0] == 0.0, 'zeros first' +assert z[1] == 0.0, 'zeros second' +assert z[2] == 0.0, 'zeros third' + +# === np.ones === +o = np.ones(4) +assert len(o) == 4, 'ones length' +assert o[0] == 1.0, 'ones first' +assert o[3] == 1.0, 'ones last' + +# === np.arange === +r = np.arange(5) +assert len(r) == 5, 'arange length' +assert r[0] == 0, 'arange first' +assert r[4] == 4, 'arange last' + +r2 = np.arange(2, 7) +assert len(r2) == 5, 'arange start stop length' +assert r2[0] == 2, 'arange start' +assert r2[4] == 6, 'arange last' + +r3 = np.arange(0, 10, 2) +assert len(r3) == 5, 'arange step length' +assert r3[0] == 0, 'arange step first' +assert r3[4] == 8, 'arange step last' + +# === np.linspace === +ls = np.linspace(0, 1, 5) +assert len(ls) == 5, 'linspace length' +assert ls[0] == 0.0, 'linspace first' +assert ls[4] == 1.0, 'linspace last' + +# === shape attribute === +a1 = np.array([1, 2, 3]) +assert a1.shape == (3,), '1d shape' + +a2 = np.array([[1, 2], [3, 4], [5, 6]]) +assert a2.shape == (3, 2), '2d shape' + +# === dtype attribute === +int_arr = np.array([1, 2, 3]) +assert str(int_arr.dtype) == 'int64', 'int dtype' + +float_arr = np.array([1.0, 2.0, 3.0]) +assert str(float_arr.dtype) == 'float64', 'float dtype' + +# === np.linspace edge cases === +ls2 = np.linspace(0, 10, 3) +assert ls2.tolist() == [0.0, 5.0, 10.0], 'linspace 3 points' + +ls_single = np.linspace(5, 5, 1) +assert ls_single.tolist() == [5.0], 'linspace single point' + +# === np.arange with float step === +r_float = np.arange(0, 1, 0.5) +assert len(r_float) == 2, 'arange float step length' +assert r_float[0] == 0.0, 'arange float step first' +assert r_float[1] == 0.5, 'arange float step second' + +# === Float array from mixed list === +mixed = np.array([1, 2.5, 3]) +assert str(mixed.dtype) == 'float64', 'mixed promotes to float' diff --git a/crates/monty/test_cases/numpy__broadcasting.py b/crates/monty/test_cases/numpy__broadcasting.py new file mode 100644 index 000000000..56ce19f4d --- /dev/null +++ b/crates/monty/test_cases/numpy__broadcasting.py @@ -0,0 +1,80 @@ +# skip-cpython +import numpy as np + + +# === Broadcast shape helpers === +assert np.broadcast_shapes((3, 1), (1, 4)) == (3, 4), 'broadcast_shapes should combine singleton axes' +assert np.broadcast_to(np.array([1, 2, 3]), (2, 3)).tolist() == [[1, 2, 3], [1, 2, 3]], ( + 'broadcast_to should materialize leading singleton axes' +) +assert np.broadcast_to(5, (2, 2)).tolist() == [[5, 5], [5, 5]], 'broadcast_to should accept scalar input' + +broadcasted = np.broadcast_arrays(np.array([[1], [2]]), np.array([10, 20, 30])) +assert len(broadcasted) == 2, 'broadcast_arrays should return one result per input' +assert [array.tolist() for array in broadcasted] == [ + [[1, 1, 1], [2, 2, 2]], + [[10, 20, 30], [10, 20, 30]], +], 'broadcast_arrays should materialize all inputs to the shared shape' +assert list(np.broadcast(np.array([1, 2]), 10)) == [(1, 10), (2, 10)], ( + 'broadcast subset should provide NumPy-compatible iteration payloads' +) + + +# === Ufunc and operator broadcasting === +column = np.array([[1], [2]]) +row = np.array([10, 20, 30]) +assert (column + row).tolist() == [[11, 21, 31], [12, 22, 32]], 'ndarray operators should broadcast' +assert np.add(column, row).tolist() == [[11, 21, 31], [12, 22, 32]], 'np.add should broadcast arrays' +assert np.maximum(column, row).tolist() == [[10, 20, 30], [10, 20, 30]], 'pairwise math should broadcast' +assert (column < np.array([2, 2, 2])).tolist() == [[True, True, True], [False, False, False]], ( + 'comparisons should broadcast' +) +assert np.logical_and(np.array([[True], [False]]), np.array([True, False, True])).tolist() == [ + [True, False, True], + [False, False, False], +], 'logical ufuncs should broadcast' + + +# === Selection and testing helpers === +assert np.where(np.array([[True], [False]]), np.array([1, 2, 3]), 0).tolist() == [ + [1, 2, 3], + [0, 0, 0], +], 'where should broadcast condition and choices' +assert np.isclose(np.array([[1.0], [2.0]]), np.array([1.0, 3.0, 2.0])).tolist() == [ + [True, False, False], + [False, False, True], +], 'isclose should broadcast and preserve result shape' +assert not np.allclose(np.array([[1.0], [2.0]]), np.array([1.0, 2.0, 2.0])), ( + 'allclose should compare the broadcasted result' +) +assert np.array_equiv(np.array([[1], [1]]), np.array([1, 1, 1])), 'array_equiv should use broadcast equality' + + +# === Integer and bitwise broadcasting === +assert np.gcd(np.array([[6], [10]]), np.array([4, 5, 6])).tolist() == [ + [2, 1, 6], + [2, 5, 2], +], 'integer ufuncs should broadcast' +assert np.bitwise_and(np.array([[True], [False]]), np.array([True, False])).tolist() == [ + [True, False], + [False, False], +], 'bitwise boolean ufuncs should broadcast' + + +# === Broadcast errors === +try: + np.add(np.ones((2, 3)), np.ones((2,))) + assert False, 'expected incompatible broadcast to fail' +except ValueError as exc: + assert str(exc) == 'operands could not be broadcast together with shapes (2,3) (2,) ', ( + 'broadcast errors should match NumPy ufunc shape messages' + ) + +try: + np.broadcast_shapes((2, 3), (2,)) + assert False, 'expected incompatible broadcast_shapes input to fail' +except ValueError as exc: + assert str(exc) == ( + 'shape mismatch: objects cannot be broadcast to a single shape. ' + 'Mismatch is between arg 0 with shape (2, 3) and arg 1 with shape (2,).' + ), 'broadcast_shapes should match NumPy public helper shape messages' diff --git a/crates/monty/test_cases/numpy__comparison.py b/crates/monty/test_cases/numpy__comparison.py new file mode 100644 index 000000000..60aeeb847 --- /dev/null +++ b/crates/monty/test_cases/numpy__comparison.py @@ -0,0 +1,60 @@ +# skip-cpython +# === Boolean comparisons === +import numpy as np + +a = np.array([1, 2, 3, 4, 5]) + +# === Greater than === +mask = a > 3 +assert mask.tolist() == [False, False, False, True, True], 'gt mask' + +# === Less than === +mask2 = a < 3 +assert mask2.tolist() == [True, True, False, False, False], 'lt mask' + +# === Equal === +mask3 = a == 3 +assert mask3.tolist() == [False, False, True, False, False], 'eq mask' + +# === Greater than or equal === +mask4 = a >= 3 +assert mask4.tolist() == [False, False, True, True, True], 'gte mask' + +# === Less than or equal === +mask5 = a <= 3 +assert mask5.tolist() == [True, True, True, False, False], 'lte mask' + +# === Not equal === +mask6 = a != 3 +assert mask6.tolist() == [True, True, False, True, True], 'ne mask' + +# === Boolean indexing === +filtered = a[a > 3] +assert filtered.tolist() == [4, 5], 'boolean indexing' + +filtered2 = a[a <= 2] +assert filtered2.tolist() == [1, 2], 'boolean indexing lte' + +# === any / all === +assert (a > 0).all(), 'all positive' +assert not (a > 3).all(), 'not all gt 3' +assert (a > 3).any(), 'any gt 3' +assert not (a > 10).any(), 'none gt 10' + +# === Array-to-array comparisons === +x = np.array([1, 5, 3, 7, 2]) +y = np.array([2, 4, 3, 8, 1]) + +assert (x == y).tolist() == [False, False, True, False, False], 'arr == arr' +assert (x != y).tolist() == [True, True, False, True, True], 'arr != arr' +assert (x > y).tolist() == [False, True, False, False, True], 'arr > arr' +assert (x < y).tolist() == [True, False, False, True, False], 'arr < arr' +assert (x >= y).tolist() == [False, True, True, False, True], 'arr >= arr' +assert (x <= y).tolist() == [True, False, True, True, False], 'arr <= arr' + +# === Float comparisons === +fa = np.array([1.5, 2.5, 3.5]) +assert (fa > 2.0).tolist() == [False, True, True], 'float gt scalar' +assert (fa <= 2.5).tolist() == [True, True, False], 'float lte scalar' +assert (fa == 2.5).tolist() == [False, True, False], 'float eq scalar' +assert (fa != 2.5).tolist() == [True, False, True], 'float ne scalar' diff --git a/crates/monty/test_cases/numpy__math_functions.py b/crates/monty/test_cases/numpy__math_functions.py new file mode 100644 index 000000000..e4da59c19 --- /dev/null +++ b/crates/monty/test_cases/numpy__math_functions.py @@ -0,0 +1,180 @@ +# skip-cpython +# === NumPy math functions === +import numpy as np + +# === np.abs === +a = np.array([-1, -2, 3, -4, 5]) +result = np.abs(a) +assert result.tolist() == [1, 2, 3, 4, 5], 'np.abs' + +# === np.sqrt === +b = np.array([4, 9, 16, 25]) +result = np.sqrt(b) +assert result.tolist() == [2.0, 3.0, 4.0, 5.0], 'np.sqrt' + +# === np.log (natural log) === +c = np.array([1, 10, 100]) +log_result = np.log(c) +assert abs(log_result[0] - 0.0) < 0.001, 'np.log(1)' +assert abs(log_result[1] - 2.302585) < 0.001, 'np.log(10)' + +# === np.exp === +d = np.array([0, 1, 2]) +exp_result = np.exp(d) +assert abs(exp_result[0] - 1.0) < 0.001, 'np.exp(0)' +assert abs(exp_result[1] - 2.71828) < 0.001, 'np.exp(1)' + +# === np.round === +e = np.array([1.234, 2.567, 3.891]) +rounded = np.round(e, 1) +assert rounded.tolist() == [1.2, 2.6, 3.9], 'np.round' + +# === np.clip === +f = np.array([1, 5, 10, 15, 20]) +clipped = np.clip(f, 5, 15) +assert clipped.tolist() == [5, 5, 10, 15, 15], 'np.clip' + +# === np.where === +g = np.array([1, 2, 3, 4, 5]) +result = np.where(g > 3, g, 0) +assert result.tolist() == [0, 0, 0, 4, 5], 'np.where with arrays' + +result2 = np.where(g > 3, 1, 0) +assert result2.tolist() == [0, 0, 0, 1, 1], 'np.where with scalars' + +# === np.maximum / np.minimum (element-wise) === +x = np.array([1, 5, 3]) +y = np.array([2, 4, 6]) +assert np.maximum(x, y).tolist() == [2, 5, 6], 'np.maximum' +assert np.minimum(x, y).tolist() == [1, 4, 3], 'np.minimum' + +# === np.sort === +unsorted = np.array([3, 1, 4, 1, 5]) +assert np.sort(unsorted).tolist() == [1, 1, 3, 4, 5], 'np.sort' + +# === np.unique === +repeated = np.array([3, 1, 2, 1, 3, 2]) +assert np.unique(repeated).tolist() == [1, 2, 3], 'np.unique' + +# === np.concatenate === +arr1 = np.array([1, 2, 3]) +arr2 = np.array([4, 5, 6]) +combined = np.concatenate([arr1, arr2]) +assert combined.tolist() == [1, 2, 3, 4, 5, 6], 'np.concatenate' + +# === np.cumsum === +h = np.array([1, 2, 3, 4]) +assert np.cumsum(h).tolist() == [1, 3, 6, 10], 'np.cumsum' + +# === np.dot === +a1 = np.array([1, 2, 3]) +a2 = np.array([4, 5, 6]) +assert np.dot(a1, a2) == 32, 'np.dot' + +# === np.ceil / np.floor === +vals = np.array([1.2, 2.7, 3.5]) +assert np.ceil(vals).tolist() == [2.0, 3.0, 4.0], 'np.ceil' +assert np.floor(vals).tolist() == [1.0, 2.0, 3.0], 'np.floor' + +# === np.log10 === +log10_result = np.log10(np.array([1, 10, 100, 1000])) +assert log10_result[0] == 0.0, 'np.log10(1)' +assert log10_result[1] == 1.0, 'np.log10(10)' +assert log10_result[2] == 2.0, 'np.log10(100)' +assert log10_result[3] == 3.0, 'np.log10(1000)' + + +# === low-risk real and integer math ufuncs === +assert np.copysign(-2, 3) == 2.0, 'copysign scalar result' +assert np.copysign([-2, 3], [1, -1]).tolist() == [2.0, -3.0], 'copysign lists' + +mantissa, exponent = np.frexp(np.array([0.0, 8.0, -6.0])) +assert mantissa.tolist() == [0.0, 0.5, -0.75], 'frexp mantissas' +assert exponent.tolist() == [0, 4, 3], 'frexp exponents' +scalar_mantissa, scalar_exponent = np.frexp(8.0) +assert scalar_mantissa == 0.5, 'frexp scalar mantissa' +assert scalar_exponent == 4, 'frexp scalar exponent' + +fractional, integral = np.modf([-2.75, 3.25]) +assert fractional.tolist() == [-0.75, 0.25], 'modf fractional parts' +assert integral.tolist() == [-2.0, 3.0], 'modf integral parts' +scalar_fractional, scalar_integral = np.modf(-2.75) +assert scalar_fractional == -0.75, 'modf scalar fractional' +assert scalar_integral == -2.0, 'modf scalar integral' + +assert np.ldexp(0.5, 3) == 4.0, 'ldexp scalar' +assert np.ldexp([0.5, -1.5], [3, 2]).tolist() == [4.0, -6.0], 'ldexp lists' +assert np.gcd(-12, 18) == 6, 'gcd scalar' +assert np.gcd([12, -18], [8, 12]).tolist() == [4, 6], 'gcd lists' +assert np.gcd(True, 4) == 1, 'gcd bool scalar' +assert np.lcm(-4, 6) == 12, 'lcm scalar' +assert np.lcm([-4, 6], [6, 8]).tolist() == [12, 24], 'lcm lists' + +logadd = np.logaddexp([0.0, 1.0], [0.0, 2.0]) +assert abs(logadd[0] - 0.6931471805599453) < 1e-12, 'logaddexp equal inputs' +assert abs(logadd[1] - 2.313261687518223) < 1e-12, 'logaddexp offset inputs' +logadd2 = np.logaddexp2([0.0, 1.0], [0.0, 2.0]) +assert logadd2[0] == 1.0, 'logaddexp2 equal inputs' +assert abs(logadd2[1] - 2.584962500721156) < 1e-12, 'logaddexp2 offset inputs' + +assert np.nextafter(0.0, 1.0) == 5e-324, 'nextafter smallest subnormal' +assert np.nextafter([1.0], [2.0]).tolist() == [1.0000000000000002], 'nextafter lists' +assert np.spacing([0.0, 1.0, -1.0]).tolist() == [ + 5e-324, + 2.220446049250313e-16, + -2.220446049250313e-16, +], 'spacing signs' +assert np.signbit(np.array([0.0, -0.0, -2.0, 3.0])).tolist() == [ + False, + True, + True, + False, +], 'signbit array' + +sinc_result = np.sinc([0.0, 0.5, 1.0]) +assert sinc_result[0] == 1.0, 'sinc zero' +assert abs(sinc_result[1] - 0.6366197723675814) < 1e-12, 'sinc half' +assert abs(sinc_result[2]) < 1e-12, 'sinc one' +assert np.heaviside([-2.0, 0.0, 3.0], 0.5).tolist() == [0.0, 0.5, 1.0], 'heaviside list' +assert np.trunc([-2.75, 3.25]).tolist() == [-2.0, 3.0], 'trunc list' +assert np.fix([-2.75, 3.25]).tolist() == [-2.0, 3.0], 'fix list' +assert np.float_power([2, 4], [-1, 0.5]).tolist() == [0.5, 2.0], 'float_power lists' + +quotient, remainder = np.divmod(np.array([-3, 4]), np.array([2, 3])) +assert quotient.tolist() == [-2, 1], 'divmod quotient array' +assert remainder.tolist() == [1, 1], 'divmod remainder array' +scalar_quotient, scalar_remainder = np.divmod(7, 3) +assert scalar_quotient == 2, 'divmod scalar quotient' +assert scalar_remainder == 1, 'divmod scalar remainder' + + +# === window generators and Bessel i0 === +assert np.bartlett(0).tolist() == [], 'bartlett zero length' +assert np.bartlett(-3).tolist() == [], 'bartlett negative length' +assert np.bartlett(1).tolist() == [1.0], 'bartlett singleton' +assert np.bartlett(5).tolist() == [0.0, 0.5, 1.0, 0.5, 0.0], 'bartlett values' + +blackman = np.blackman(5) +assert abs(blackman[0]) < 1e-12, 'blackman first' +assert abs(blackman[1] - 0.34) < 1e-12, 'blackman second' +assert abs(blackman[2] - 1.0) < 1e-12, 'blackman center' + +hamming = np.hamming(5) +assert abs(hamming[0] - 0.08) < 1e-12, 'hamming first' +assert abs(hamming[1] - 0.54) < 1e-12, 'hamming second' +assert hamming[2] == 1.0, 'hamming center' +hanning = np.hanning(5) +assert abs(hanning[0]) < 1e-12, 'hanning first' +assert abs(hanning[1] - 0.5) < 1e-12, 'hanning second' +assert hanning[2] == 1.0, 'hanning center' +assert abs(hanning[3] - 0.5) < 1e-12, 'hanning fourth' +assert abs(hanning[4]) < 1e-12, 'hanning last' + +kaiser = np.kaiser(5, 2.0) +assert abs(kaiser[0] - 0.4386762798370488) < 1e-7, 'kaiser first' +assert abs(kaiser[1] - 0.8347614334011666) < 1e-7, 'kaiser second' +assert kaiser[2] == 1.0, 'kaiser center' + +assert np.i0(0.0) == 1.0, 'i0 zero' +assert abs(np.i0(1.0) - 1.2660658777520082) < 1e-7, 'i0 scalar' +assert abs(np.i0([0.0, 2.0])[1] - 2.279585302336067) < 1e-7, 'i0 list' diff --git a/crates/monty/test_cases/numpy__methods.py b/crates/monty/test_cases/numpy__methods.py new file mode 100644 index 000000000..41a02c375 --- /dev/null +++ b/crates/monty/test_cases/numpy__methods.py @@ -0,0 +1,157 @@ +# skip-cpython +# === Aggregation methods === +import numpy as np + +a = np.array([1, 2, 3, 4, 5]) + +# === sum === +assert a.sum() == 15, 'sum' +assert np.sum(a) == 15, 'np.sum' + +# === mean === +assert a.mean() == 3.0, 'mean' +assert np.mean(a) == 3.0, 'np.mean' + +# === min / max === +assert a.min() == 1, 'min' +assert a.max() == 5, 'max' +assert np.min(a) == 1, 'np.min' +assert np.max(a) == 5, 'np.max' + +# === std === +b = np.array([2, 4, 4, 4, 5, 5, 7, 9]) +assert b.mean() == 5.0, 'mean for std' +assert b.std() == 2.0, 'std' + +# === reshape === +c = np.array([1, 2, 3, 4, 5, 6]) +d = c.reshape(2, 3) +assert d.shape == (2, 3), 'reshape shape' +assert d[0][0] == 1, 'reshape [0][0]' +assert d[0][2] == 3, 'reshape [0][2]' +assert d[1][0] == 4, 'reshape [1][0]' +assert d[1][2] == 6, 'reshape [1][2]' + +# === flatten === +e = d.flatten() +assert e.shape == (6,), 'flatten shape' +assert e[0] == 1, 'flatten first' +assert e[5] == 6, 'flatten last' + +# === tolist === +f = np.array([1, 2, 3]) +result = f.tolist() +assert result == [1, 2, 3], 'tolist' +assert type(result) == list, 'tolist returns list' + +# === argmin / argmax === +g = np.array([3, 1, 4, 1, 5]) +assert g.argmin() == 1, 'argmin' +assert g.argmax() == 4, 'argmax' + +# === cumsum === +h = np.array([1, 2, 3, 4]) +cs = h.cumsum() +assert cs.tolist() == [1, 3, 6, 10], 'cumsum method' + +# === abs (via np.abs, not method) === +neg = np.array([-1, 2, -3]) +assert np.abs(neg).tolist() == [1, 2, 3], 'np.abs function' + +# === np.abs / np.sqrt / np.exp / np.ceil / np.floor on plain lists === +assert np.abs([-1, 2, -3]).tolist() == [1, 2, 3], 'np.abs(list)' +assert np.sqrt([1.0, 4.0, 9.0]).tolist() == [1.0, 2.0, 3.0], 'np.sqrt(list)' +assert np.exp([0.0]).tolist() == [1.0], 'np.exp(list)' +assert np.ceil([1.2, 2.7]).tolist() == [2.0, 3.0], 'np.ceil(list)' +assert np.floor([1.8, 2.3]).tolist() == [1.0, 2.0], 'np.floor(list)' + +# === round === +floats = np.array([1.234, 2.567, 3.891]) +assert floats.round(1).tolist() == [1.2, 2.6, 3.9], 'round method' + +# === clip === +arr = np.array([1, 5, 10, 15, 20]) +assert arr.clip(5, 15).tolist() == [5, 5, 10, 15, 15], 'clip method' + +# === sort method (returns new sorted array) === +unsorted = np.array([3, 1, 4, 1, 5]) +sorted_arr = np.sort(unsorted) +assert sorted_arr.tolist() == [1, 1, 3, 4, 5], 'sort' + +# === np.mean / np.sum / np.min / np.max on plain lists === +plain = [10, 20, 30, 40, 50] +assert np.mean(plain) == 30.0, 'np.mean(list)' +assert np.sum(plain) == 150, 'np.sum(list)' +assert np.min(plain) == 10, 'np.min(list)' +assert np.max(plain) == 50, 'np.max(list)' + +# Float list +flist = [1.5, 2.5, 3.5] +assert np.mean(flist) == 2.5, 'np.mean(float list)' +assert np.sum(flist) == 7.5, 'np.sum(float list)' + +# Single element +assert np.mean([42]) == 42.0, 'np.mean(single)' +assert np.sum([42]) == 42, 'np.sum(single)' + +# np.std on list +assert np.std([2, 4, 4, 4, 5, 5, 7, 9]) == 2.0, 'np.std(list)' + +# === ndarray attributes === +arr_attr = np.array([1, 2, 3, 4, 5]) +assert arr_attr.shape == (5,), 'ndarray shape 1d' +assert str(arr_attr.dtype) == 'int64', 'ndarray dtype int' + +arr_float = np.array([1.0, 2.0, 3.0]) +assert str(arr_float.dtype) == 'float64', 'ndarray dtype float' +assert arr_float.shape == (3,), 'ndarray float shape' + +# 2D array attributes +arr_2d = np.array([1, 2, 3, 4, 5, 6]).reshape(2, 3) +assert arr_2d.shape == (2, 3), 'ndarray shape 2d' + +# === unique === +dup = np.array([3, 1, 2, 1, 3, 2]) +u = np.unique(dup) +assert u.tolist() == [1, 2, 3], 'np.unique' + +# === concatenate === +c1 = np.array([1, 2, 3]) +c2 = np.array([4, 5, 6]) +cat = np.concatenate([c1, c2]) +assert cat.tolist() == [1, 2, 3, 4, 5, 6], 'np.concatenate' + +# === np.where === +cond = np.array([True, False, True, False]) +result_w = np.where(cond, 10, 20) +assert result_w.tolist() == [10, 20, 10, 20], 'np.where bool array' + +# === np.maximum / np.minimum === +m1 = np.array([1, 5, 3]) +m2 = np.array([4, 2, 6]) +assert np.maximum(m1, m2).tolist() == [4, 5, 6], 'np.maximum' +assert np.minimum(m1, m2).tolist() == [1, 2, 3], 'np.minimum' + +# === ndarray repr === +assert repr(np.array([1, 2, 3])) == 'array([1, 2, 3])', 'ndarray int repr' +assert repr(np.array([1.5, 2.5])) == 'array([1.5, 2.5])', 'ndarray float repr' + +# === ndarray bool (truthiness) === +# numpy only allows bool() on single-element arrays +assert bool(np.array([1])) == True, 'single-element truthy' +assert bool(np.array([0])) == False, 'single-element falsy' + +# === ndarray len === +assert len(np.array([1, 2, 3])) == 3, 'ndarray len' + +# === ndarray type === +assert type(np.array([1])).__name__ == 'ndarray', 'ndarray type name' + +# === np.where with array x and scalar y === +cond2 = np.array([True, False, True]) +arr_x = np.array([10, 20, 30]) +result_w2 = np.where(cond2, arr_x, 0) +assert result_w2.tolist() == [10, 0, 30], 'np.where array x scalar y' + +# === np.cumsum on array === +assert np.cumsum(np.array([1, 2, 3])).tolist() == [1, 3, 6], 'np.cumsum on array' diff --git a/crates/monty/test_cases/numpy__parity.py b/crates/monty/test_cases/numpy__parity.py new file mode 100644 index 000000000..312d6b718 --- /dev/null +++ b/crates/monty/test_cases/numpy__parity.py @@ -0,0 +1,2865 @@ +# skip-cpython +import numpy as np + +# ============================================================ +# 1. ARRAY CREATION FUNCTIONS +# ============================================================ + +# === np.array === +# int array from list +a = np.array([1, 2, 3]) +assert a.tolist() == [1, 2, 3], 'np.array int list' +assert a.dtype == 'int64', 'np.array int dtype' + +# float array from list +a = np.array([1.0, 2.0, 3.0]) +assert a.tolist() == [1.0, 2.0, 3.0], 'np.array float list' +assert a.dtype == 'float64', 'np.array float dtype' + +# mixed int/float promotes to float +a = np.array([1, 2.0, 3]) +assert a.tolist() == [1.0, 2.0, 3.0], 'np.array mixed promotes to float' +assert a.dtype == 'float64', 'np.array mixed dtype' + +# single element +a = np.array([42]) +assert a.tolist() == [42], 'np.array single int' + +a = np.array([3.14]) +assert a.tolist() == [3.14], 'np.array single float' + +# 2D array +a = np.array([[1, 2], [3, 4]]) +assert a.shape == (2, 2), 'np.array 2D shape' +assert a.dtype == 'int64', 'np.array 2D int dtype' + +# 2D float +a = np.array([[1.0, 2.0], [3.0, 4.0]]) +assert a.shape == (2, 2), 'np.array 2D float shape' +assert a.dtype == 'float64', 'np.array 2D float dtype' + +# === np.zeros === +a = np.zeros(3) +assert a.tolist() == [0.0, 0.0, 0.0], 'np.zeros values' +assert a.dtype == 'float64', 'np.zeros dtype' +assert len(a) == 3, 'np.zeros len' + +a = np.zeros(1) +assert a.tolist() == [0.0], 'np.zeros single' + +# === np.ones === +a = np.ones(3) +assert a.tolist() == [1.0, 1.0, 1.0], 'np.ones values' +assert a.dtype == 'float64', 'np.ones dtype' + +a = np.ones(1) +assert a.tolist() == [1.0], 'np.ones single' + +# === np.arange === +# single arg (stop) +a = np.arange(5) +assert a.tolist() == [0, 1, 2, 3, 4], 'np.arange(5)' +assert a.dtype == 'int64', 'np.arange int dtype' + +# two args (start, stop) +a = np.arange(2, 6) +assert a.tolist() == [2, 3, 4, 5], 'np.arange(2, 6)' + +# three args (start, stop, step) +a = np.arange(0, 10, 2) +assert a.tolist() == [0, 2, 4, 6, 8], 'np.arange(0, 10, 2)' + +# float step +a = np.arange(0, 1, 0.5) +assert a.dtype == 'float64', 'np.arange float step dtype' +assert len(a) == 2, 'np.arange float step len' + +# negative step +a = np.arange(5, 0, -1) +assert a.tolist() == [5, 4, 3, 2, 1], 'np.arange negative step' + +# empty result +a = np.arange(5, 0) +assert a.tolist() == [], 'np.arange empty result' +assert len(a) == 0, 'np.arange empty len' + +# === np.linspace === +a = np.linspace(0, 1, 5) +assert a.dtype == 'float64', 'np.linspace dtype' +assert len(a) == 5, 'np.linspace len' +assert a[0] == 0.0, 'np.linspace start' +assert a[-1] == 1.0, 'np.linspace end' +# check intermediate values with rounding +assert round(a[1], 2) == 0.25, 'np.linspace[1]' +assert round(a[2], 2) == 0.5, 'np.linspace[2]' + +# linspace single point +a = np.linspace(5, 5, 1) +assert a.tolist() == [5.0], 'np.linspace single point' + +# linspace two points +a = np.linspace(0, 10, 2) +assert a.tolist() == [0.0, 10.0], 'np.linspace two points' + + +# ============================================================ +# 2. MODULE-LEVEL AGGREGATE FUNCTIONS +# ============================================================ + +# === np.sum === +# Note: module-level np.sum returns float in our impl +a_int = np.array([1, 2, 3]) +a_float = np.array([1.0, 2.0, 3.0]) + +# Method-level sum preserves dtype +assert a_int.sum() == 6, 'arr.sum() int value' +assert a_float.sum() == 6.0, 'arr.sum() float value' + +# === np.mean === +assert np.mean(np.array([1, 2, 3])) == 2.0, 'np.mean int array' +assert np.mean(np.array([1.0, 2.0, 3.0])) == 2.0, 'np.mean float array' +assert np.array([10]).mean() == 10.0, 'mean single element' +assert np.array([2, 4]).mean() == 3.0, 'mean two elements' + +# === np.min === +assert np.array([3, 1, 2]).min() == 1, 'arr.min() int' +assert np.array([3.0, 1.0, 2.0]).min() == 1.0, 'arr.min() float' +assert np.array([42]).min() == 42, 'min single element' +assert np.array([-5, -1, -10]).min() == -10, 'min negative' + +# === np.max === +assert np.array([3, 1, 2]).max() == 3, 'arr.max() int' +assert np.array([3.0, 1.0, 2.0]).max() == 3.0, 'arr.max() float' +assert np.array([42]).max() == 42, 'max single element' +assert np.array([-5, -1, -10]).max() == -1, 'max negative' + +# === np.std === +a = np.array([1, 2, 3, 4, 5]) +s = np.std(a) +assert round(s, 10) == round(1.4142135623730951, 10), 'np.std value' +a = np.array([5, 5, 5]) +assert np.std(a) == 0.0, 'np.std uniform' +assert np.array([10]).std() == 0.0, 'std single element' + + +# ============================================================ +# 3. MODULE-LEVEL ELEMENT-WISE FUNCTIONS +# ============================================================ + +# === np.abs === +a = np.abs(np.array([-1, -2, 3])) +assert a.tolist() == [1, 2, 3], 'np.abs int values' + +a = np.abs(np.array([-1.5, 2.5, -3.5])) +assert a.tolist() == [1.5, 2.5, 3.5], 'np.abs float values' + +a = np.abs(np.array([0])) +assert a.tolist() == [0], 'np.abs zero' + +# === np.sqrt === +a = np.sqrt(np.array([1.0, 4.0, 9.0])) +assert a.tolist() == [1.0, 2.0, 3.0], 'np.sqrt values' +assert a.dtype == 'float64', 'np.sqrt dtype' + +a = np.sqrt(np.array([0.0])) +assert a.tolist() == [0.0], 'np.sqrt zero' + +# === np.log === +a = np.log(np.array([1.0])) +assert a.tolist() == [0.0], 'np.log(1) = 0' + +# === np.exp === +a = np.exp(np.array([0.0])) +assert a.tolist() == [1.0], 'np.exp(0) = 1' +assert a.dtype == 'float64', 'np.exp dtype' + +# === np.ceil === +a = np.ceil(np.array([1.2, 2.7, -0.5])) +assert a.tolist() == [2.0, 3.0, 0.0], 'np.ceil float values' +assert a.dtype == 'float64', 'np.ceil dtype' + +# === np.floor === +a = np.floor(np.array([1.2, 2.7, -0.5])) +assert a.tolist() == [1.0, 2.0, -1.0], 'np.floor float values' +assert a.dtype == 'float64', 'np.floor dtype' + +# === np.log10 === +a = np.log10(np.array([1.0, 10.0, 100.0])) +assert a.tolist() == [0.0, 1.0, 2.0], 'np.log10 values' +assert a.dtype == 'float64', 'np.log10 dtype' + +# === np.round === +a = np.round(np.array([1.5, 2.5, 3.5])) +# numpy uses banker's rounding +r = a.tolist() +assert r[0] == 2.0, 'np.round 1.5' +assert r[2] == 4.0, 'np.round 3.5' + +a = np.round(np.array([1.234, 5.678]), 2) +r = a.tolist() +assert r[0] == 1.23, 'np.round decimals=2 first' +assert r[1] == 5.68, 'np.round decimals=2 second' + +# round with 0 decimals +a = np.round(np.array([1.6, 2.4])) +assert a.tolist() == [2.0, 2.0], 'np.round default decimals' + +# === np.clip === +a = np.clip(np.array([1, 5, 10, 15, 20]), 5, 15) +assert a.tolist() == [5, 5, 10, 15, 15], 'np.clip int values' + +a = np.clip(np.array([1.0, 5.0, 10.0]), 2.0, 8.0) +assert a.tolist() == [2.0, 5.0, 8.0], 'np.clip float values' + +a = np.clip(np.array([-10, 0, 10]), -5, 5) +assert a.tolist() == [-5, 0, 5], 'np.clip with negatives' + + +# ============================================================ +# 4. MODULE-LEVEL BINARY/SELECTION FUNCTIONS +# ============================================================ + +# === np.where === +cond = np.array([1, 0, 1, 0, 1]) +x = np.array([10, 20, 30, 40, 50]) +y = np.array([1, 2, 3, 4, 5]) +result = np.where(cond, x, y) +assert result.tolist() == [10, 2, 30, 4, 50], 'np.where basic' + +# where with scalar x, y +cond = np.array([1, 0, 1]) +result = np.where(cond, 10, 20) +assert result.tolist() == [10, 20, 10], 'np.where scalar x, y' + +# where with boolean condition +cond = np.array([1, 0, 1]) +result = np.where(cond, np.array([1, 2, 3]), np.array([4, 5, 6])) +assert result.tolist() == [1, 5, 3], 'np.where from comparison' + +# === np.maximum === +a = np.array([1, 3, 2]) +b = np.array([3, 1, 4]) +assert np.maximum(a, b).tolist() == [3, 3, 4], 'np.maximum values' + +# int preserves dtype +a = np.array([1, 5]) +b = np.array([3, 2]) +assert np.maximum(a, b).dtype == 'int64', 'np.maximum int dtype' + +# float +a = np.array([1.0, 3.0]) +b = np.array([3.0, 1.0]) +assert np.maximum(a, b).tolist() == [3.0, 3.0], 'np.maximum float' + +# === np.minimum === +a = np.array([1, 3, 2]) +b = np.array([3, 1, 4]) +assert np.minimum(a, b).tolist() == [1, 1, 2], 'np.minimum values' + +a = np.array([1, 5]) +b = np.array([3, 2]) +assert np.minimum(a, b).dtype == 'int64', 'np.minimum int dtype' + +# === np.sort === +a = np.sort(np.array([3, 1, 2])) +assert a.tolist() == [1, 2, 3], 'np.sort int' +assert a.dtype == 'int64', 'np.sort int dtype' + +a = np.sort(np.array([3.0, 1.0, 2.0])) +assert a.tolist() == [1.0, 2.0, 3.0], 'np.sort float' + +# already sorted +a = np.sort(np.array([1, 2, 3])) +assert a.tolist() == [1, 2, 3], 'np.sort already sorted' + +# reverse sorted +a = np.sort(np.array([5, 4, 3, 2, 1])) +assert a.tolist() == [1, 2, 3, 4, 5], 'np.sort reverse' + +# single element +a = np.sort(np.array([42])) +assert a.tolist() == [42], 'np.sort single' + +# === np.unique === +a = np.unique(np.array([3, 1, 2, 1, 3])) +assert a.tolist() == [1, 2, 3], 'np.unique values' +assert a.dtype == 'int64', 'np.unique int dtype' + +a = np.unique(np.array([1, 1, 1])) +assert a.tolist() == [1], 'np.unique all same' + +a = np.unique(np.array([5])) +assert a.tolist() == [5], 'np.unique single' + +# === np.concatenate === +a = np.concatenate([np.array([1, 2]), np.array([3, 4])]) +assert a.tolist() == [1, 2, 3, 4], 'np.concatenate two arrays' +assert a.dtype == 'int64', 'np.concatenate int dtype' + +# three arrays +a = np.concatenate([np.array([1]), np.array([2]), np.array([3])]) +assert a.tolist() == [1, 2, 3], 'np.concatenate three arrays' + +# mixed dtypes +a = np.concatenate([np.array([1, 2]), np.array([3.0, 4.0])]) +assert a.tolist() == [1.0, 2.0, 3.0, 4.0], 'np.concatenate mixed dtype' +assert a.dtype == 'float64', 'np.concatenate mixed dtype result' + +# === np.cumsum === +a = np.cumsum(np.array([1, 2, 3])) +assert a.tolist() == [1, 3, 6], 'np.cumsum int' +assert a.dtype == 'int64', 'np.cumsum int dtype' + +a = np.cumsum(np.array([1.0, 2.0, 3.0])) +assert a.tolist() == [1.0, 3.0, 6.0], 'np.cumsum float' + +a = np.cumsum(np.array([10])) +assert a.tolist() == [10], 'np.cumsum single' + +# === np.dot === +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) +assert np.dot(a, b) == 32, 'np.dot int result' + +a = np.array([1.0, 2.0, 3.0]) +b = np.array([4.0, 5.0, 6.0]) +assert np.dot(a, b) == 32.0, 'np.dot float result' + +# single element +assert np.dot(np.array([5]), np.array([3])) == 15, 'np.dot single element' + + +# ============================================================ +# 5. NDARRAY METHODS +# ============================================================ + +# === .sum() === +assert np.array([1, 2, 3]).sum() == 6, 'sum int' +assert np.array([1.0, 2.0, 3.0]).sum() == 6.0, 'sum float' +assert np.array([100]).sum() == 100, 'sum single' + +# === .mean() === +assert np.array([2, 4, 6]).mean() == 4.0, 'mean int' +assert np.array([1.0, 3.0]).mean() == 2.0, 'mean float' +assert np.array([7]).mean() == 7.0, 'mean single' + +# === .min() === +assert np.array([3, 1, 2]).min() == 1, 'min int' +assert np.array([3.0, 1.0, 2.0]).min() == 1.0, 'min float' +assert np.array([-10]).min() == -10, 'min single negative' + +# === .max() === +assert np.array([3, 1, 2]).max() == 3, 'max int' +assert np.array([3.0, 1.0, 2.0]).max() == 3.0, 'max float' +assert np.array([0]).max() == 0, 'max single zero' + +# === .std() === +assert np.array([1, 1, 1]).std() == 0.0, 'std uniform' +s = np.array([1, 2, 3, 4, 5]).std() +assert round(s, 10) == round(1.4142135623730951, 10), 'std five elements' + +# === .flatten() === +a = np.array([[1, 2], [3, 4]]) +f = a.flatten() +assert f.tolist() == [1, 2, 3, 4], 'flatten 2D' +assert f.shape == (4,), 'flatten shape' +assert f.dtype == 'int64', 'flatten preserves dtype' + +# 1D flatten is identity +a = np.array([1, 2, 3]) +assert a.flatten().tolist() == [1, 2, 3], 'flatten 1D' + +# === .tolist() === +assert np.array([1, 2, 3]).tolist() == [1, 2, 3], 'tolist int' +assert np.array([1.0, 2.0]).tolist() == [1.0, 2.0], 'tolist float' +assert np.array([42]).tolist() == [42], 'tolist single' + +# === .copy() === +a = np.array([1, 2, 3]) +b = a.copy() +assert b.tolist() == [1, 2, 3], 'copy values' +assert b.dtype == 'int64', 'copy preserves dtype' + +a = np.array([1.0, 2.0]) +b = a.copy() +assert b.tolist() == [1.0, 2.0], 'copy float values' + +# === .sort() (method) === +# In numpy, sort() is in-place, returns None +# In our impl, sort() returns new sorted array (we test both ways) +original = np.array([3, 1, 2]) +sorted_arr = np.sort(original) # module-level returns new copy +assert sorted_arr.tolist() == [1, 2, 3], 'np.sort returns sorted' + +# === .argsort() === +a = np.array([3, 1, 2]) +idx = a.argsort() +assert idx.tolist() == [1, 2, 0], 'argsort values' +assert idx.dtype == 'int64', 'argsort dtype' + +a = np.array([10, 30, 20]) +assert a.argsort().tolist() == [0, 2, 1], 'argsort three elements' + +a = np.array([1]) +assert a.argsort().tolist() == [0], 'argsort single' + +# === .argmin() === +assert np.array([3, 1, 2]).argmin() == 1, 'argmin basic' +assert np.array([10]).argmin() == 0, 'argmin single' +assert np.array([5, 5, 5]).argmin() == 0, 'argmin ties' + +# === .argmax() === +assert np.array([3, 1, 2]).argmax() == 0, 'argmax basic' +assert np.array([10]).argmax() == 0, 'argmax single' +assert np.array([5, 5, 5]).argmax() == 0, 'argmax ties' + +# === .all() === +assert np.array([1, 2, 3]).all() == True, 'all truthy' +assert np.array([1, 0, 3]).all() == False, 'all with zero' +assert np.array([1]).all() == True, 'all single truthy' +assert np.array([0]).all() == False, 'all single falsy' + +# === .any() === +assert np.array([0, 0, 1]).any() == True, 'any with one truthy' +assert np.array([0, 0, 0]).any() == False, 'any all falsy' +assert np.array([1]).any() == True, 'any single truthy' +assert np.array([0]).any() == False, 'any single falsy' + +# === .cumsum() === +a = np.array([1, 2, 3]).cumsum() +assert a.tolist() == [1, 3, 6], 'cumsum int' +assert a.dtype == 'int64', 'cumsum int dtype' + +a = np.array([1.0, 2.0, 3.0]).cumsum() +assert a.tolist() == [1.0, 3.0, 6.0], 'cumsum float' + +assert np.array([42]).cumsum().tolist() == [42], 'cumsum single' + +# === .reshape() === +a = np.array([1, 2, 3, 4, 5, 6]) +b = a.reshape(2, 3) +assert b.shape == (2, 3), 'reshape shape' +assert b.dtype == 'int64', 'reshape preserves dtype' + +b = a.reshape(3, 2) +assert b.shape == (3, 2), 'reshape 3x2' + +b = a.reshape(6) +assert b.shape == (6,), 'reshape to 1D' + +# === .round() === +a = np.array([1.234, 5.678]).round(2) +assert a.tolist() == [1.23, 5.68], 'round method decimals=2' + +a = np.array([1.5, 2.5]).round() +r = a.tolist() +assert r[0] == 2.0, 'round method 1.5' + +# === .clip() === +a = np.array([1, 5, 10]).clip(3, 8) +assert a.tolist() == [3, 5, 8], 'clip method basic' + +a = np.array([1.0, 5.0, 10.0]).clip(2.0, 8.0) +assert a.tolist() == [2.0, 5.0, 8.0], 'clip method float' + +# === .dot() === +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) +assert a.dot(b) == 32, 'dot method int' + +a = np.array([1.0, 2.0]) +b = np.array([3.0, 4.0]) +assert a.dot(b) == 11.0, 'dot method float' + +# === .astype() === +a = np.array([1.5, 2.7, 3.1]).astype('int64') +assert a.tolist() == [1, 2, 3], 'astype to int64' +assert a.dtype == 'int64', 'astype int64 dtype' + +a = np.array([1, 2, 3]).astype('float64') +assert a.tolist() == [1.0, 2.0, 3.0], 'astype to float64' +assert a.dtype == 'float64', 'astype float64 dtype' + + +# ============================================================ +# 6. NDARRAY ATTRIBUTES +# ============================================================ + +# === .shape === +assert np.array([1, 2, 3]).shape == (3,), 'shape 1D' +assert np.array([[1, 2], [3, 4]]).shape == (2, 2), 'shape 2D' +assert np.array([[1, 2, 3]]).shape == (1, 3), 'shape 1x3' +assert np.array([42]).shape == (1,), 'shape single' + +# === .dtype === +assert np.array([1, 2]).dtype == 'int64', 'dtype int' +assert np.array([1.0, 2.0]).dtype == 'float64', 'dtype float' +assert np.zeros(2).dtype == 'float64', 'zeros dtype' +assert np.ones(2).dtype == 'float64', 'ones dtype' +assert np.arange(3).dtype == 'int64', 'arange dtype' +assert np.linspace(0, 1, 3).dtype == 'float64', 'linspace dtype' + +# === .size === +assert np.array([1, 2, 3]).size == 3, 'size 1D' +assert np.array([[1, 2], [3, 4]]).size == 4, 'size 2D' +assert np.array([42]).size == 1, 'size single' + +# === .ndim === +assert np.array([1, 2, 3]).ndim == 1, 'ndim 1D' +assert np.array([[1, 2], [3, 4]]).ndim == 2, 'ndim 2D' +assert np.array([42]).ndim == 1, 'ndim single element' + +# === .T === +a = np.array([[1, 2], [3, 4]]) +t = a.T +assert t.shape == (2, 2), 'T 2D shape' +# T[0] should be the first column: [1, 3] +assert t[0].tolist() == [1, 3], 'T first column' +assert t[1].tolist() == [2, 4], 'T second column' + +# 1D transpose is identity +a = np.array([1, 2, 3]) +assert a.T.tolist() == [1, 2, 3], 'T 1D identity' +assert a.T.shape == (3,), 'T 1D shape' + + +# ============================================================ +# 7. ELEMENT-WISE BINARY OPERATIONS +# ============================================================ + +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) + +# === array + array === +r = a + b +assert r.tolist() == [5, 7, 9], 'add array+array' +assert r.dtype == 'int64', 'add int+int dtype' + +# === array - array === +r = b - a +assert r.tolist() == [3, 3, 3], 'sub array-array' + +# === array * array === +r = a * b +assert r.tolist() == [4, 10, 18], 'mul array*array' + +# === array / array === +r = np.array([4, 6, 8]) / np.array([2, 3, 4]) +assert r.tolist() == [2.0, 2.0, 2.0], 'div array/array' +assert r.dtype == 'float64', 'div always float' + +# === array // array === +r = np.array([7, 8, 9]) // np.array([2, 3, 4]) +assert r.tolist() == [3, 2, 2], 'floordiv array//array' + +# === array % array === +r = np.array([7, 8, 9]) % np.array([3, 3, 4]) +assert r.tolist() == [1, 2, 1], 'mod array%array' + +# === array ** array === +r = np.array([2, 3, 4]) ** np.array([3, 2, 1]) +assert r.tolist() == [8, 9, 4], 'pow array**array' + +# === array + scalar === +r = np.array([1, 2, 3]) + 10 +assert r.tolist() == [11, 12, 13], 'add array+scalar' +assert r.dtype == 'int64', 'add int+int_scalar dtype' + +# === scalar + array === +r = 10 + np.array([1, 2, 3]) +assert r.tolist() == [11, 12, 13], 'add scalar+array' + +# === array - scalar === +r = np.array([10, 20, 30]) - 5 +assert r.tolist() == [5, 15, 25], 'sub array-scalar' + +# === scalar - array === +r = 10 - np.array([1, 2, 3]) +assert r.tolist() == [9, 8, 7], 'sub scalar-array' + +# === array * scalar === +r = np.array([1, 2, 3]) * 2 +assert r.tolist() == [2, 4, 6], 'mul array*scalar' + +# === scalar * array === +r = 2 * np.array([1, 2, 3]) +assert r.tolist() == [2, 4, 6], 'mul scalar*array' + +# === array / scalar === +r = np.array([2, 4, 6]) / 2 +assert r.tolist() == [1.0, 2.0, 3.0], 'div array/scalar' +assert r.dtype == 'float64', 'div result always float' + +# === scalar / array === +r = 12 / np.array([1, 2, 3]) +assert r.tolist() == [12.0, 6.0, 4.0], 'div scalar/array' + +# === array // scalar === +r = np.array([7, 8, 9]) // 3 +assert r.tolist() == [2, 2, 3], 'floordiv array//scalar' + +# === scalar // array === +r = 10 // np.array([3, 4, 5]) +assert r.tolist() == [3, 2, 2], 'floordiv scalar//array' + +# === array % scalar === +r = np.array([7, 8, 9]) % 3 +assert r.tolist() == [1, 2, 0], 'mod array%scalar' + +# === scalar % array === +r = 10 % np.array([3, 4, 7]) +assert r.tolist() == [1, 2, 3], 'mod scalar%array' + +# === array ** scalar === +r = np.array([2, 3, 4]) ** 2 +assert r.tolist() == [4, 9, 16], 'pow array**scalar' + +# === scalar ** array === +r = 2 ** np.array([1, 2, 3]) +assert r.tolist() == [2, 4, 8], 'pow scalar**array' + +# === mixed int/float arithmetic === +r = np.array([1, 2, 3]) + 0.5 +assert r.tolist() == [1.5, 2.5, 3.5], 'add int_array + float_scalar' +assert r.dtype == 'float64', 'int+float promotes to float' + +r = np.array([1, 2, 3]) + np.array([0.5, 0.5, 0.5]) +assert r.tolist() == [1.5, 2.5, 3.5], 'add int_array + float_array' +assert r.dtype == 'float64', 'int_arr+float_arr promotes to float' + +r = np.array([1.0, 2.0]) * 2 +assert r.tolist() == [2.0, 4.0], 'mul float_array * int_scalar' +assert r.dtype == 'float64', 'float*int stays float' + + +# ============================================================ +# 8. ELEMENT-WISE COMPARISONS +# ============================================================ + +a = np.array([1, 2, 3, 4, 5]) + +# === array > scalar === +r = a > 3 +assert r.tolist() == [False, False, False, True, True], 'gt scalar' + +# === array < scalar === +r = a < 3 +assert r.tolist() == [True, True, False, False, False], 'lt scalar' + +# === array >= scalar === +r = a >= 3 +assert r.tolist() == [False, False, True, True, True], 'gte scalar' + +# === array <= scalar === +r = a <= 3 +assert r.tolist() == [True, True, True, False, False], 'lte scalar' + +# === array == scalar === +r = a == 3 +assert r.tolist() == [False, False, True, False, False], 'eq scalar' + +# === array != scalar === +r = a != 3 +assert r.tolist() == [True, True, False, True, True], 'ne scalar' + +# === array vs array comparisons === +x = np.array([1, 3, 5]) +y = np.array([2, 3, 4]) + +assert (x > y).tolist() == [False, False, True], 'gt array' +assert (x < y).tolist() == [True, False, False], 'lt array' +assert (x >= y).tolist() == [False, True, True], 'gte array' +assert (x <= y).tolist() == [True, True, False], 'lte array' +assert (x == y).tolist() == [False, True, False], 'eq array' +assert (x != y).tolist() == [True, False, True], 'ne array' + +# comparison result dtype is bool +r = np.array([1, 2]) > np.array([0, 3]) +assert r.dtype == 'bool', 'comparison dtype is bool' + + +# ============================================================ +# 9. UNARY NEGATION +# ============================================================ + +# int negation +a = -np.array([1, 2, 3]) +assert a.tolist() == [-1, -2, -3], 'neg int' +assert a.dtype == 'int64', 'neg int preserves dtype' + +# float negation +a = -np.array([1.5, -2.5, 0.0]) +assert a.tolist() == [-1.5, 2.5, 0.0], 'neg float' +assert a.dtype == 'float64', 'neg float preserves dtype' + +# double negation +a = -(-np.array([1, 2, 3])) +assert a.tolist() == [1, 2, 3], 'double neg' + +# negation of zeros +a = -np.array([0, 0, 0]) +assert a.tolist() == [0, 0, 0], 'neg zeros' + +# === bitwise invert (~) === +# int invert: ~n = -(n+1) +a = ~np.array([0, 1, 2, -1]) +assert a.tolist() == [-1, -2, -3, 0], 'invert int' +assert a.dtype == 'int64', 'invert int preserves dtype' + +# bool invert: flips True/False +b = ~np.array([True, False, True]) +assert b.tolist() == [False, True, False], 'invert bool' +assert b.dtype == 'bool', 'invert bool preserves dtype' + +# === np.where shape validation === +# matching shapes should work +cond = np.array([True, False, True]) +x = np.array([10, 20, 30]) +y = np.array([0, 0, 0]) +assert np.where(cond, x, y).tolist() == [10, 0, 30], 'where matching shapes' + +# scalar x/y should broadcast +assert np.where(cond, 1, 0).tolist() == [1, 0, 1], 'where scalar broadcast' + + +# ============================================================ +# 10. REPR FORMAT +# ============================================================ + +# int array repr +assert repr(np.array([1, 2, 3])) == 'array([1, 2, 3])', 'repr int array' + +# float array repr +assert repr(np.array([1.0, 2.0, 3.0])) == 'array([1., 2., 3.])', 'repr float array' + +# single element +assert repr(np.array([42])) == 'array([42])', 'repr single int' +assert repr(np.array([3.14])) == 'array([3.14])', 'repr single float' + +# Note: 2D repr differs between our impl (single line) and real numpy (multi-line) +# so we skip 2D repr comparison here + + +# ============================================================ +# 11. TYPE AND TYPE NAME +# ============================================================ + +a = np.array([1, 2, 3]) +assert type(a).__name__ == 'ndarray', 'type name' + + +# ============================================================ +# 12. LEN +# ============================================================ + +assert len(np.array([1, 2, 3])) == 3, 'len 1D' +assert len(np.array([42])) == 1, 'len single' +assert len(np.array([[1, 2], [3, 4]])) == 2, 'len 2D (num rows)' +assert len(np.zeros(5)) == 5, 'len zeros' + + +# ============================================================ +# 13. INDEXING (GETITEM) +# ============================================================ + +a = np.array([10, 20, 30, 40, 50]) + +# positive int index +assert a[0] == 10, 'getitem [0]' +assert a[1] == 20, 'getitem [1]' +assert a[4] == 50, 'getitem [4]' + +# negative int index +assert a[-1] == 50, 'getitem [-1]' +assert a[-2] == 40, 'getitem [-2]' +assert a[-5] == 10, 'getitem [-5]' + +# float array indexing +b = np.array([1.5, 2.5, 3.5]) +assert b[0] == 1.5, 'getitem float [0]' +assert b[-1] == 3.5, 'getitem float [-1]' + +# 2D indexing (returns row) +a = np.array([[1, 2, 3], [4, 5, 6]]) +row0 = a[0] +assert row0.tolist() == [1, 2, 3], 'getitem 2D row 0' +row1 = a[1] +assert row1.tolist() == [4, 5, 6], 'getitem 2D row 1' +assert a[-1].tolist() == [4, 5, 6], 'getitem 2D negative index' + +# chained 2D indexing +assert a[0][1] == 2, 'getitem 2D chained' +assert a[1][2] == 6, 'getitem 2D chained last' + +# boolean mask indexing +a = np.array([10, 20, 30, 40, 50]) +mask = np.array([1, 0, 1, 0, 1]) +result = a[mask > 0] +assert result.tolist() == [10, 30, 50], 'boolean mask indexing' + +# comparison-based boolean indexing +a = np.array([1, 2, 3, 4, 5]) +result = a[a > 3] +assert result.tolist() == [4, 5], 'comparison boolean indexing' + +result = a[a <= 2] +assert result.tolist() == [1, 2], 'comparison boolean indexing lte' + + +# ============================================================ +# 14. EDGE CASES +# ============================================================ + +# Large values +a = np.array([1000000, 2000000, 3000000]) +assert a.sum() == 6000000, 'large values sum' + +# Negative values throughout +a = np.array([-1, -2, -3]) +assert a.sum() == -6, 'negative sum' +assert a.mean() == -2.0, 'negative mean' + +# Single element operations +a = np.array([42]) +assert a.sum() == 42, 'single sum' +assert a.mean() == 42.0, 'single mean' +assert a.min() == 42, 'single min' +assert a.max() == 42, 'single max' +assert a.std() == 0.0, 'single std' + +# Zeros array operations +a = np.zeros(3) +assert a.sum() == 0.0, 'zeros sum' +assert a.mean() == 0.0, 'zeros mean' +assert a.min() == 0.0, 'zeros min' +assert a.max() == 0.0, 'zeros max' +assert a.std() == 0.0, 'zeros std' + +# Ones array operations +a = np.ones(3) +assert a.sum() == 3.0, 'ones sum' +assert a.mean() == 1.0, 'ones mean' +assert a.min() == 1.0, 'ones min' +assert a.max() == 1.0, 'ones max' +assert a.std() == 0.0, 'ones std' + +# Mixed positive and negative +a = np.array([-2, -1, 0, 1, 2]) +assert a.sum() == 0, 'mixed sum zero' +assert a.mean() == 0.0, 'mixed mean zero' +assert a.min() == -2, 'mixed min' +assert a.max() == 2, 'mixed max' + + +# ============================================================ +# 15. NaN AND INF EDGE CASES +# ============================================================ +import math + +# === Division by zero === +a = np.array([1.0, 2.0, 3.0]) / 0 +assert math.isinf(a[0]) and a[0] > 0, 'float / 0 = inf' +assert math.isinf(a[1]) and a[1] > 0, 'float / 0 = inf (2)' + +b = np.array([0.0]) / 0 +assert math.isnan(b[0]), '0.0 / 0 = nan' + +# === NaN propagation in aggregation === +nan_arr = np.array([1.0, float('nan'), 3.0]) +assert math.isnan(nan_arr.sum()), 'sum propagates nan' +assert math.isnan(nan_arr.mean()), 'mean propagates nan' +assert math.isnan(nan_arr.min()), 'min propagates nan' +assert math.isnan(nan_arr.max()), 'max propagates nan' +assert math.isnan(nan_arr.std()), 'std propagates nan' + +# argmin/argmax with NaN — NumPy returns index of first NaN +assert np.array([float('nan')]).argmin() == 0, 'argmin single nan' +assert np.array([float('nan')]).argmax() == 0, 'argmax single nan' +assert np.array([float('nan'), 1.0, 2.0]).argmin() == 0, 'argmin nan first' + +# === Inf operations === +inf_arr = np.array([float('inf')]) +assert (inf_arr + 1)[0] == float('inf'), 'inf + 1 = inf' +assert (inf_arr * -1)[0] == float('-inf'), 'inf * -1 = -inf' +assert math.isnan((inf_arr - inf_arr)[0]), 'inf - inf = nan' +assert inf_arr.sum() == float('inf'), 'sum(inf) = inf' + +# === NaN/Inf repr === +assert repr(np.array([float('nan')])) == 'array([nan])', 'nan repr lowercase' +assert repr(np.array([float('inf')])) == 'array([inf])', 'inf repr' +assert repr(np.array([float('-inf')])) == 'array([-inf])', '-inf repr' + +# === NaN comparisons === +r = np.array([float('nan')]) == np.array([float('nan')]) +assert r[0] == False, 'nan != nan' +r2 = np.array([float('nan')]) > 0 +assert r2[0] == False, 'nan > 0 is False' + +# === NaN in sort === +s = np.sort(np.array([float('nan'), 1.0, 2.0])) +assert s[0] == 1.0, 'sort nan: first elem' +assert s[1] == 2.0, 'sort nan: second elem' +assert math.isnan(s[2]), 'sort nan: nan last' + +s2 = np.sort(np.array([3.0, float('nan'), 1.0])) +assert s2[0] == 1.0, 'sort nan mid: first' +assert s2[1] == 3.0, 'sort nan mid: second' +assert math.isnan(s2[2]), 'sort nan mid: nan last' + + +# ============================================================ +# 16. EMPTY ARRAY EDGE CASES +# ============================================================ + +# === Empty array creation and attributes === +empty = np.array([]) +assert empty.shape == (0,), 'empty shape' +assert empty.dtype == 'float64', 'empty dtype' +assert len(empty) == 0, 'empty len' +assert empty.size == 0, 'empty size' +assert empty.ndim == 1, 'empty ndim' + +# === Empty array operations === +assert empty.tolist() == [], 'empty tolist' +assert empty.flatten().tolist() == [], 'empty flatten' +assert empty.cumsum().tolist() == [], 'empty cumsum' +assert empty.sum() == 0.0, 'empty sum' + +# mean of empty is nan (0/0) +assert math.isnan(empty.mean()), 'empty mean is nan' + +# std of empty is nan +assert math.isnan(empty.std()), 'empty std is nan' + +# sort/unique of empty +assert np.sort(empty).tolist() == [], 'empty sort' +assert np.unique(empty).tolist() == [], 'empty unique' + +# concatenate with empty +assert np.concatenate([empty, np.array([1.0, 2.0])]).tolist() == [1.0, 2.0], 'concat empty' + +# zeros/ones with 0 +assert np.zeros(0).tolist() == [], 'zeros(0)' +assert np.ones(0).tolist() == [], 'ones(0)' + + +# ============================================================ +# 17. DTYPE CORRECTNESS +# ============================================================ + +# === Division always produces float64 === +a = np.array([4, 6, 8]) +b = np.array([2, 3, 4]) +assert (a / b).dtype == 'float64', 'int / int -> float64' +assert (a / b).tolist() == [2.0, 2.0, 2.0], 'int / int values' +assert (a / 2).dtype == 'float64', 'int / scalar -> float64' +assert (a // b).dtype == 'int64', 'int // int -> int64' + +# === Arithmetic dtype promotion === +assert (a + b).dtype == 'int64', 'int + int -> int64' +assert (a * 2).dtype == 'int64', 'int * int_scalar -> int64' +assert (a * 1.0).dtype == 'float64', 'int * float_scalar -> float64' +assert (a + np.array([1.0, 2.0, 3.0])).dtype == 'float64', 'int + float -> float64' + +# === Comparison always produces bool === +assert (a > 5).dtype == 'bool', 'int > scalar -> bool' +assert (a == b).dtype == 'bool', 'int == int -> bool' + + +# ============================================================ +# 18. 2D ARRAY OPERATIONS +# ============================================================ + +# === 2D binary operations === +m1 = np.array([[1, 2], [3, 4]]) +m2 = np.array([[10, 20], [30, 40]]) +assert (m1 + m2).tolist() == [[11, 22], [33, 44]], '2d add' +assert (m1 * 2).tolist() == [[2, 4], [6, 8]], '2d scalar mul' +assert (2 * m1).tolist() == [[2, 4], [6, 8]], '2d scalar mul left' + +# === 2D comparisons === +assert (m1 > 2).tolist() == [[False, False], [True, True]], '2d > scalar' + +# === 2D methods === +assert m1.sum() == 10, '2d sum' +assert m1.mean() == 2.5, '2d mean' +assert m1.flatten().tolist() == [1, 2, 3, 4], '2d flatten' + +# === 2D tolist preserves nesting === +assert m1.tolist() == [[1, 2], [3, 4]], '2d tolist nested' +r = np.array([[1, 2, 3], [4, 5, 6]]) +assert r.tolist() == [[1, 2, 3], [4, 5, 6]], '2x3 tolist' + +# === Transpose === +assert r.T.shape == (3, 2), 'transpose shape' +assert r.T.tolist() == [[1, 4], [2, 5], [3, 6]], 'transpose values' + +# === 2D indexing === +assert r[0].tolist() == [1, 2, 3], '2d row 0' +assert r[1].tolist() == [4, 5, 6], '2d row 1' +assert r[-1].tolist() == [4, 5, 6], '2d row -1' +assert r[0][2] == 3, '2d chained index' +assert r[1][0] == 4, '2d chained index 2' + +# ============================================================ +# 19. TRIGONOMETRIC & MATH FUNCTIONS +# ============================================================ + +# === np.sin === +assert repr(np.sin(np.array([0.0]))) == 'array([0.])', 'sin(0)' +assert repr(np.cos(np.array([0.0]))) == 'array([1.])', 'cos(0)' +assert repr(np.tan(np.array([0.0]))) == 'array([0.])', 'tan(0)' + +# sin(pi/2) ~ 1 +sin_result = np.sin(np.array([0.0, math.pi / 2, math.pi])) +assert abs(sin_result[0]) < 1e-10, 'sin(0) ~ 0' +assert abs(sin_result[1] - 1.0) < 1e-10, 'sin(pi/2) ~ 1' + +# cos(0) = 1, cos(pi/2) ~ 0 +cos_result = np.cos(np.array([0.0, math.pi / 2, math.pi])) +assert abs(cos_result[0] - 1.0) < 1e-10, 'cos(0) ~ 1' +assert abs(cos_result[1]) < 1e-10, 'cos(pi/2) ~ 0' + +# sin on plain list +assert repr(np.sin([0.0])) == 'array([0.])', 'sin on list' + +# === np.log2 === +assert repr(np.log2(np.array([1.0, 2.0, 4.0, 8.0]))) == 'array([0., 1., 2., 3.])', 'log2' + +# === np.power === +assert repr(np.power(np.array([2, 3, 4]), 2)) == 'array([4, 9, 16])', 'power arr-scalar' +assert repr(np.power(2, np.array([1, 2, 3]))) == 'array([2, 4, 8])', 'power scalar-arr' + +# === np.diff === +assert repr(np.diff(np.array([1, 3, 6, 10]))) == 'array([2, 3, 4])', 'diff int' +assert repr(np.diff(np.array([1.0, 2.5, 4.0]))) == 'array([1.5, 1.5])', 'diff float' + +# ============================================================ +# 20. ARRAY CREATION EXPANSION +# ============================================================ + +# === np.full === +assert repr(np.full(3, 7)) == 'array([7, 7, 7])', 'full int' +assert repr(np.full(3, 7.0)) == 'array([7., 7., 7.])', 'full float' +assert repr(np.full(4, True)) == 'array([ True, True, True, True])', 'full bool' +assert np.full(3, 5).dtype == 'int64', 'full int dtype' +assert np.full(3, 5.0).dtype == 'float64', 'full float dtype' + +# === np.eye === +e = np.eye(3) +assert e.shape == (3, 3), 'eye shape' +assert e.dtype == 'float64', 'eye dtype' +assert e[0].tolist() == [1.0, 0.0, 0.0], 'eye row 0' +assert e[1].tolist() == [0.0, 1.0, 0.0], 'eye row 1' +assert e[2].tolist() == [0.0, 0.0, 1.0], 'eye row 2' + +# === np.copy === +orig = np.array([1, 2, 3]) +c = np.copy(orig) +assert repr(c) == 'array([1, 2, 3])', 'copy array' +assert repr(np.copy([4, 5, 6])) == 'array([4, 5, 6])', 'copy list' + +# === np.empty === +e = np.empty(3) +assert e.shape == (3,), 'empty shape' +assert e.dtype == 'float64', 'empty dtype' +assert len(e) == 3, 'empty len' + +# === np.zeros with tuple shape === +z = np.zeros((2, 3)) +assert z.shape == (2, 3), 'zeros tuple shape' +assert z.dtype == 'float64', 'zeros tuple dtype' +assert z.tolist() == [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], 'zeros tuple values' + +# === np.ones with tuple shape === +o = np.ones((2, 3)) +assert o.shape == (2, 3), 'ones tuple shape' +assert o.tolist() == [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], 'ones tuple values' + +# === np.zeros_like / np.ones_like === +a = np.array([1, 2, 3]) +z = np.zeros_like(a) +assert repr(z) == 'array([0, 0, 0])', 'zeros_like int' +assert z.dtype == 'int64', 'zeros_like preserves dtype' +o = np.ones_like(a) +assert repr(o) == 'array([1, 1, 1])', 'ones_like int' +b = np.array([1.0, 2.0]) +assert np.zeros_like(b).dtype == 'float64', 'zeros_like float dtype' +assert repr(np.ones_like(b)) == 'array([1., 1.])', 'ones_like float' + +# ============================================================ +# 21. TESTING & INSPECTION FUNCTIONS +# ============================================================ + +# === np.isnan, np.isinf, np.isfinite === +a = np.array([1.0, float('nan'), float('inf'), float('-inf'), 0.0]) +assert repr(np.isnan(a)) == 'array([False, True, False, False, False])', 'isnan' +assert repr(np.isinf(a)) == 'array([False, False, True, True, False])', 'isinf' +assert repr(np.isfinite(a)) == 'array([ True, False, False, False, True])', 'isfinite' +# Works on int arrays (always finite, never NaN) +assert repr(np.isnan(np.array([1, 2, 3]))) == 'array([False, False, False])', 'isnan int' +assert repr(np.isfinite(np.array([1, 2, 3]))) == 'array([ True, True, True])', 'isfinite int' + +# === np.array_equal === +assert np.array_equal(np.array([1, 2, 3]), np.array([1, 2, 3])) == True, 'array_equal true' +assert np.array_equal(np.array([1, 2, 3]), np.array([1, 2, 4])) == False, 'array_equal false' +assert np.array_equal(np.array([1, 2]), np.array([1, 2, 3])) == False, 'array_equal diff shape' + +# === np.count_nonzero === +assert np.count_nonzero(np.array([0, 1, 2, 0, 3])) == 3, 'count_nonzero' +assert np.count_nonzero(np.array([0.0, 0.0])) == 0, 'count_nonzero zeros' +assert np.count_nonzero(np.array([True, False, True])) == 2, 'count_nonzero bool' + +# === np.all / np.any (module-level) === +assert np.all(np.array([True, True, True])) == True, 'all true' +assert np.all(np.array([True, False, True])) == False, 'all false' +assert np.any(np.array([False, False, True])) == True, 'any true' +assert np.any(np.array([False, False, False])) == False, 'any false' +assert np.all(np.array([1, 2, 3])) == True, 'all int truthy' +assert np.any(np.array([0, 0, 0])) == False, 'any int all zero' +assert np.all([1, 1, 1]) == True, 'all on list' + +# ============================================================ +# 22. AGGREGATION EXPANSION +# ============================================================ + +# === prod === +assert np.prod(np.array([1, 2, 3, 4])) == 24, 'np.prod int' +assert np.array([2.0, 3.0, 4.0]).prod() == 24.0, 'arr.prod float' +assert np.prod(np.array([1.0])) == 1.0, 'prod single' +assert np.prod(np.zeros(0)) == 1.0, 'prod empty = 1.0' + +# === var === +a = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) +assert a.var() == np.var(a), 'var method == module' +assert abs(a.var() - 2.0) < 1e-10, 'var value' +assert abs(a.std() ** 2 - a.var()) < 1e-10, 'std^2 == var' + +# === median === +assert np.median(np.array([3, 1, 2])) == 2.0, 'median odd' +assert np.median(np.array([1, 2, 3, 4])) == 2.5, 'median even' +assert np.median(np.array([5.0])) == 5.0, 'median single' + +# === np.argmin / np.argmax (module-level) === +assert np.argmin(np.array([3, 1, 2])) == 1, 'np.argmin' +assert np.argmax(np.array([3, 1, 2])) == 0, 'np.argmax' +assert np.argmin([5, 2, 8]) == 1, 'np.argmin on list' + +# ============================================================ +# 23. ARRAY MANIPULATION +# ============================================================ + +# === np.reshape (module-level) === +a = np.arange(6) +b = np.reshape(a, (2, 3)) +assert b.shape == (2, 3), 'reshape mod shape' +assert b.tolist() == [[0, 1, 2], [3, 4, 5]], 'reshape mod values' + +# === np.transpose (module-level) === +a = np.array([[1, 2], [3, 4]]) +t = np.transpose(a) +assert t.tolist() == [[1, 3], [2, 4]], 'transpose mod' + +# === np.append === +assert repr(np.append(np.array([1, 2, 3]), np.array([4, 5]))) == 'array([1, 2, 3, 4, 5])', 'append arr-arr' +assert repr(np.append(np.array([1, 2]), [3, 4])) == 'array([1, 2, 3, 4])', 'append arr-list' + +# === np.vstack === +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) +v = np.vstack([a, b]) +assert v.shape == (2, 3), 'vstack shape' +assert v.tolist() == [[1, 2, 3], [4, 5, 6]], 'vstack values' + +# === np.hstack === +h = np.hstack([a, b]) +assert repr(h) == 'array([1, 2, 3, 4, 5, 6])', 'hstack' + +# === np.stack === +s = np.stack([a, b]) +assert s.shape == (2, 3), 'stack shape' +assert s.tolist() == [[1, 2, 3], [4, 5, 6]], 'stack values' + +# === .ravel() === +a = np.array([[1, 2], [3, 4]]) +assert repr(a.ravel()) == 'array([1, 2, 3, 4])', 'ravel' + +# ============================================================ +# 24. SEARCH & INDEX FUNCTIONS +# ============================================================ + +# === np.nonzero === +idx = np.nonzero(np.array([0, 3, 0, 5, 0])) +assert len(idx) == 1, 'nonzero returns 1-tuple for 1d' +assert repr(idx[0]) == 'array([1, 3])', 'nonzero indices' + +# === np.argwhere === +result = np.argwhere(np.array([0, 3, 0, 5, 0])) +assert result.shape == (2, 1), 'argwhere shape' +assert result.flatten().tolist() == [1, 3], 'argwhere values' + +# === Fancy indexing with integer arrays === +a = np.array([10, 20, 30, 40, 50]) +idx = np.array([0, 2, 4]) +assert repr(a[idx]) == 'array([10, 30, 50])', 'fancy idx' +assert repr(a[np.array([4, 3, 2, 1, 0])]) == 'array([50, 40, 30, 20, 10])', 'fancy idx reverse' +matrix = np.array([[10, 11], [20, 21], [30, 31]]) +assert matrix[np.array([2, 0])].tolist() == [[30, 31], [10, 11]], 'fancy idx selects rows' +try: + a[np.array([0.0])] + assert False, 'expected float fancy index to fail' +except IndexError as exc: + assert str(exc) == 'arrays used as indices must be of integer (or boolean) type', 'float fancy idx error' + +# === Slice indexing === +a = np.array([10, 20, 30, 40, 50]) +assert repr(a[1:3]) == 'array([20, 30])', 'slice 1:3' +assert repr(a[::2]) == 'array([10, 30, 50])', 'slice ::2' +assert repr(a[::-1]) == 'array([50, 40, 30, 20, 10])', 'slice ::-1' +assert repr(a[:-1]) == 'array([10, 20, 30, 40])', 'slice :-1' +assert repr(a[2:]) == 'array([30, 40, 50])', 'slice 2:' + +# ============================================================ +# 25. REMAINING UTILITIES +# ============================================================ + +# === np.tile === +assert repr(np.tile(np.array([1, 2, 3]), 2)) == 'array([1, 2, 3, 1, 2, 3])', 'tile' + +# === np.repeat === +assert repr(np.repeat(np.array([1, 2, 3]), 2)) == 'array([1, 1, 2, 2, 3, 3])', 'repeat' + +# === np.split === +a = np.array([1, 2, 3, 4, 5, 6]) +parts = np.split(a, 3) +assert len(parts) == 3, 'split count' +assert repr(parts[0]) == 'array([1, 2])', 'split part 0' +assert repr(parts[1]) == 'array([3, 4])', 'split part 1' +assert repr(parts[2]) == 'array([5, 6])', 'split part 2' + +# split by indices +parts2 = np.split(a, [2, 4]) +assert repr(parts2[0]) == 'array([1, 2])', 'split idx part 0' +assert repr(parts2[1]) == 'array([3, 4])', 'split idx part 1' +assert repr(parts2[2]) == 'array([5, 6])', 'split idx part 2' + +# === .astype aliases === +a = np.array([1.5, 2.7, 3.1]) +assert repr(a.astype('int32')) == 'array([1, 2, 3])', 'astype int32' +assert repr(a.astype('float32')) == 'array([1.5, 2.7, 3.1])', 'astype float32' +assert repr(a.astype('int')) == 'array([1, 2, 3])', 'astype int' +assert repr(a.astype('float')) == 'array([1.5, 2.7, 3.1])', 'astype float' + +# ============================================================ +# 26. EDGE CASES FOR NEW FUNCTIONS +# ============================================================ + +# Empty array edge cases +assert repr(np.sin(np.zeros(0))) == 'array([], dtype=float64)', 'sin empty' +assert np.prod(np.array([1])) == 1, 'prod single' +assert np.count_nonzero(np.zeros(0)) == 0, 'count_nonzero empty' +assert repr(np.tile(np.array([1, 2]), 0)) == 'array([], dtype=int64)', 'tile 0 reps' +assert repr(np.repeat(np.zeros(0), 3)) == 'array([], dtype=float64)', 'repeat empty' +assert repr(np.diff(np.array([5]))) == 'array([], dtype=int64)', 'diff single element' +assert repr(np.full(0, 5)) == 'array([], dtype=int64)', 'full size 0' + +# ============================================================ +# 27. ADDITIONAL DTYPE AND OPERATION COVERAGE +# ============================================================ + +# === Bool array creation and dtype === +b = np.array([True, False, True]) +assert b.dtype == 'bool', 'bool array dtype' +assert b.tolist() == [True, False, True], 'bool array tolist' +assert repr(b) == 'array([ True, False, True])', 'bool array repr' +assert b.sum() == 2, 'bool sum' +assert b.any() == True, 'bool any' +assert b.all() == False, 'bool all' + +# Bool from comparison +c = np.array([1, 2, 3]) > 1 +assert c.dtype == 'bool', 'comparison produces bool' +assert c.sum() == 2, 'comparison bool sum' +assert c.tolist() == [False, True, True], 'comparison bool tolist' + +# === Additional math function edge cases === +# sin on int array +sin_int = np.sin(np.array([0, 1])) +assert sin_int.dtype == 'float64', 'sin int array -> float64' +assert sin_int[0] == 0.0, 'sin(0) exact' + +# cos on int array +cos_int = np.cos(np.array([0])) +assert cos_int[0] == 1.0, 'cos(0) exact' +assert cos_int.dtype == 'float64', 'cos int array -> float64' + +# log2 edge cases +assert np.log2(np.array([1.0]))[0] == 0.0, 'log2(1) = 0' +assert np.log2(np.array([2.0]))[0] == 1.0, 'log2(2) = 1' +assert np.log2(np.array([16.0]))[0] == 4.0, 'log2(16) = 4' + +# diff on larger array +assert np.diff(np.array([1, 1, 1, 1])).tolist() == [0, 0, 0], 'diff constant' +assert np.diff(np.array([0, 1, 4, 9])).tolist() == [1, 3, 5], 'diff quadratic' + +# power array-array +assert repr(np.power(np.array([2, 3]), np.array([3, 2]))) == 'array([8, 9])', 'power arr-arr' + +# === Additional aggregation edge cases === +# prod with negatives +assert np.prod(np.array([-1, 2, -3])) == 6, 'prod negatives' +assert np.prod(np.array([0, 1, 2])) == 0, 'prod with zero' + +# var of uniform array +assert np.var(np.array([5, 5, 5])) == 0.0, 'var uniform' +assert np.var(np.array([5.0])) == 0.0, 'var single' + +# median of larger arrays +assert np.median(np.array([1, 2, 3, 4, 5])) == 3.0, 'median 5 elements' +assert np.median(np.array([10, 20])) == 15.0, 'median 2 elements' + +# === count_nonzero more cases === +assert np.count_nonzero(np.array([1, 1, 1])) == 3, 'count_nonzero all nonzero' +assert np.count_nonzero(np.array([-1, 0, 1])) == 2, 'count_nonzero with neg' +assert np.count_nonzero(np.array([0.0, 0.1, 0.0])) == 1, 'count_nonzero float' + +# === Additional array_equal cases === +assert np.array_equal(np.array([1.0, 2.0]), np.array([1.0, 2.0])) == True, 'array_equal float' +assert np.array_equal(np.zeros(0), np.zeros(0)) == True, 'array_equal empty' +assert np.array_equal(np.array([1]), np.array([1.0])) == True, 'array_equal int vs float' + +# === Chained operations === +# Sort then slice +a = np.sort(np.array([5, 3, 1, 4, 2])) +assert a[0] == 1, 'sort then index first' +assert a[-1] == 5, 'sort then index last' +assert a[2] == 3, 'sort then index mid' + +# Arithmetic chains +a = np.array([1, 2, 3]) +assert ((a * 2) + 1).tolist() == [3, 5, 7], 'chain mul then add' +assert ((a + 1) * 2).tolist() == [4, 6, 8], 'chain add then mul' +assert (a * a).tolist() == [1, 4, 9], 'array self mul' +assert (a + a).tolist() == [2, 4, 6], 'array self add' + +# Comparison chains +assert (a > 1).sum() == 2, 'count elements > 1' +assert (a == 2).sum() == 1, 'count elements == 2' +assert (a < 4).all() == True, 'all < 4' +assert (a > 0).all() == True, 'all positive' +assert (a > 3).any() == False, 'none > 3' + +# === 2D operation coverage === +m = np.array([[1, 2, 3], [4, 5, 6]]) +assert m.min() == 1, '2d min' +assert m.max() == 6, '2d max' +assert m.size == 6, '2d size' +assert m.ndim == 2, '2d ndim' +assert m[0][0] == 1, '2d corner tl' +assert m[1][2] == 6, '2d corner br' +assert m[-1].tolist() == [4, 5, 6], '2d neg index row' + +# 2D arithmetic +m2 = m + 10 +assert m2.tolist() == [[11, 12, 13], [14, 15, 16]], '2d add scalar' +assert m2.shape == (2, 3), '2d add preserves shape' +m3 = m * m +assert m3.tolist() == [[1, 4, 9], [16, 25, 36]], '2d self mul' + +# === Repr edge cases === +assert repr(np.array([0])) == 'array([0])', 'repr zero int' +assert repr(np.array([0.0])) == 'array([0.])', 'repr zero float' +assert repr(np.array([-1])) == 'array([-1])', 'repr negative int' +assert repr(np.array([-1.5])) == 'array([-1.5])', 'repr negative float' +assert repr(np.array([True, True])) == 'array([ True, True])', 'repr bool all true' +assert repr(np.array([False])) == 'array([False])', 'repr bool single false' + +# === np.eye additional cases === +e1 = np.eye(1) +assert e1.tolist() == [[1.0]], 'eye 1x1' +e2 = np.eye(2) +assert e2.tolist() == [[1.0, 0.0], [0.0, 1.0]], 'eye 2x2' + +# === np.full additional cases === +assert np.full(1, 42).tolist() == [42], 'full single' +assert np.full(5, 0).tolist() == [0, 0, 0, 0, 0], 'full zeros' +assert np.full(3, -1).tolist() == [-1, -1, -1], 'full negative' +assert np.full(2, 3.14).tolist() == [3.14, 3.14], 'full pi' + +# === np.zeros_like / np.ones_like additional === +f = np.array([1.5, 2.5, 3.5]) +assert np.zeros_like(f).tolist() == [0.0, 0.0, 0.0], 'zeros_like float vals' +assert np.ones_like(f).tolist() == [1.0, 1.0, 1.0], 'ones_like float vals' + +# === np.isnan/isinf/isfinite on plain values === +assert np.isnan(np.array([0.0, 1.0])).tolist() == [False, False], 'isnan no nans' +assert np.isinf(np.array([0.0, 1.0])).tolist() == [False, False], 'isinf no infs' +assert np.isfinite(np.array([0.0, 1.0])).tolist() == [True, True], 'isfinite normal' + +# === np.split additional === +a = np.arange(10) +parts = np.split(a, 5) +assert len(parts) == 5, 'split into 5' +assert parts[0].tolist() == [0, 1], 'split 5 part 0' +assert parts[4].tolist() == [8, 9], 'split 5 part 4' + +# === np.tile additional === +assert np.tile(np.array([1, 2]), 3).tolist() == [1, 2, 1, 2, 1, 2], 'tile 3x' +assert np.tile(np.array([5]), 4).tolist() == [5, 5, 5, 5], 'tile single 4x' +assert np.tile(np.array([1, 2]), 1).tolist() == [1, 2], 'tile 1x identity' + +# === np.repeat additional === +assert np.repeat(np.array([1, 2, 3]), 1).tolist() == [1, 2, 3], 'repeat 1x identity' +assert np.repeat(np.array([5]), 3).tolist() == [5, 5, 5], 'repeat single 3x' +assert np.repeat(np.array([1, 2]), 3).tolist() == [1, 1, 1, 2, 2, 2], 'repeat 3x' + +# === Slice indexing additional === +a = np.arange(10) +assert a[2:5].tolist() == [2, 3, 4], 'slice 2:5' +assert a[:3].tolist() == [0, 1, 2], 'slice :3' +assert a[7:].tolist() == [7, 8, 9], 'slice 7:' +assert a[::3].tolist() == [0, 3, 6, 9], 'slice ::3' +assert a[1:7:2].tolist() == [1, 3, 5], 'slice 1:7:2' +assert a[::-2].tolist() == [9, 7, 5, 3, 1], 'slice ::-2' + +# === Fancy indexing additional === +a = np.arange(10) +idx = np.array([0, 5, 9]) +assert a[idx].tolist() == [0, 5, 9], 'fancy idx sparse' +idx2 = np.array([9, 0]) +assert a[idx2].tolist() == [9, 0], 'fancy idx reversed pair' + +# === np.nonzero / np.argwhere additional === +idx = np.nonzero(np.array([1, 0, 0, 1, 1])) +assert idx[0].tolist() == [0, 3, 4], 'nonzero more' +assert np.argwhere(np.array([0, 0, 0])).shape == (0, 1), 'argwhere all zero' +assert np.argwhere(np.array([1, 1, 1])).flatten().tolist() == [0, 1, 2], 'argwhere all nonzero' + +# ============================================================ +# 28. CONSTANTS AND TYPE OBJECTS +# ============================================================ + +# === np.pi === +assert abs(np.pi - 3.141592653589793) < 1e-15, 'np.pi value' +assert np.pi == math.pi, 'np.pi matches math.pi' + +# === np.e === +assert abs(np.e - 2.718281828459045) < 1e-15, 'np.e value' +assert np.e == math.e, 'np.e matches math.e' + +# === np.inf === +assert np.inf == math.inf, 'np.inf matches math.inf' +assert np.inf > 1e308, 'np.inf is large' +assert math.isinf(np.inf), 'np.inf is inf' + +# === np.nan === +assert np.nan != np.nan, 'np.nan is NaN (not equal to self)' +assert math.isnan(np.nan), 'np.nan is nan' + +# === np.newaxis === +assert np.newaxis is None, 'np.newaxis is None' + +# ============================================================ +# 29. INVERSE TRIGONOMETRIC FUNCTIONS +# ============================================================ + +# === np.arcsin === +a = np.array([0.0, 0.5, 1.0]) +r = np.arcsin(a) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'arcsin 0' +assert abs(r.tolist()[1] - 0.5235987755982988) < 1e-10, 'arcsin 0.5' +assert abs(r.tolist()[2] - 1.5707963267948966) < 1e-10, 'arcsin 1' + +# === np.arccos === +r = np.arccos(a) +assert abs(r.tolist()[0] - 1.5707963267948966) < 1e-10, 'arccos 0' +assert abs(r.tolist()[2] - 0.0) < 1e-10, 'arccos 1' + +# === np.arctan === +r = np.arctan(a) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'arctan 0' +assert abs(r.tolist()[2] - 0.7853981633974483) < 1e-10, 'arctan 1' + +# === np.arctan2 === +r = np.arctan2(np.array([1.0, 0.0]), np.array([1.0, 1.0])) +assert abs(r.tolist()[0] - 0.7853981633974483) < 1e-10, 'arctan2(1,1)' +assert abs(r.tolist()[1] - 0.0) < 1e-10, 'arctan2(0,1)' + +# ============================================================ +# 30. HYPERBOLIC FUNCTIONS +# ============================================================ + +# === np.sinh === +r = np.sinh(np.array([0.0, 1.0])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'sinh 0' +assert abs(r.tolist()[1] - 1.1752011936438014) < 1e-10, 'sinh 1' + +# === np.cosh === +r = np.cosh(np.array([0.0, 1.0])) +assert abs(r.tolist()[0] - 1.0) < 1e-10, 'cosh 0' +assert abs(r.tolist()[1] - 1.5430806348152437) < 1e-10, 'cosh 1' + +# === np.tanh === +r = np.tanh(np.array([0.0, 1.0])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'tanh 0' +assert abs(r.tolist()[1] - 0.7615941559557649) < 1e-10, 'tanh 1' + +# === np.arcsinh === +r = np.arcsinh(np.array([0.0, 1.0])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'arcsinh 0' +assert abs(r.tolist()[1] - 0.881373587019543) < 1e-10, 'arcsinh 1' + +# === np.arccosh === +r = np.arccosh(np.array([1.0, 2.0])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'arccosh 1' +assert abs(r.tolist()[1] - 1.3169578969248166) < 1e-10, 'arccosh 2' + +# === np.arctanh === +r = np.arctanh(np.array([0.0, 0.5])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'arctanh 0' +assert abs(r.tolist()[1] - 0.5493061443340549) < 1e-10, 'arctanh 0.5' + +# ============================================================ +# 31. REMAINING ELEMENT-WISE MATH +# ============================================================ + +# === np.sign === +assert np.sign(np.array([-3.0, 0.0, 5.0])).tolist() == [-1.0, 0.0, 1.0], 'sign' +assert np.sign(np.array([-1, 0, 1])).tolist() == [-1, 0, 1], 'sign int' + +# === np.square === +assert np.square(np.array([2.0, 3.0, 4.0])).tolist() == [4.0, 9.0, 16.0], 'square' +assert np.square(np.array([0.0, -2.0])).tolist() == [0.0, 4.0], 'square neg' + +# === np.cbrt === +assert np.cbrt(np.array([8.0, 27.0])).tolist() == [2.0, 3.0], 'cbrt' +assert abs(np.cbrt(np.array([1.0])).tolist()[0] - 1.0) < 1e-10, 'cbrt 1' + +# === np.reciprocal === +assert np.reciprocal(np.array([2.0, 4.0, 5.0])).tolist() == [0.5, 0.25, 0.2], 'reciprocal' + +# === np.log1p === +r = np.log1p(np.array([0.0, 1.0])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'log1p 0' +assert abs(r.tolist()[1] - 0.6931471805599453) < 1e-10, 'log1p 1' + +# === np.exp2 === +assert np.exp2(np.array([0.0, 1.0, 3.0])).tolist() == [1.0, 2.0, 8.0], 'exp2' + +# === np.expm1 === +r = np.expm1(np.array([0.0, 1.0])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'expm1 0' +assert abs(r.tolist()[1] - 1.7182818284590453) < 1e-10, 'expm1 1' + +# === np.deg2rad === +r = np.deg2rad(np.array([0.0, 90.0, 180.0])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'deg2rad 0' +assert abs(r.tolist()[1] - 1.5707963267948966) < 1e-10, 'deg2rad 90' +assert abs(r.tolist()[2] - 3.141592653589793) < 1e-10, 'deg2rad 180' + +# === np.rad2deg === +r = np.rad2deg(np.array([0.0, np.pi / 2, np.pi])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'rad2deg 0' +assert abs(r.tolist()[1] - 90.0) < 1e-10, 'rad2deg pi/2' +assert abs(r.tolist()[2] - 180.0) < 1e-10, 'rad2deg pi' + +# === np.degrees (alias for rad2deg) === +r = np.degrees(np.array([0.0, np.pi])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'degrees 0' +assert abs(r.tolist()[1] - 180.0) < 1e-10, 'degrees pi' + +# === np.radians (alias for deg2rad) === +r = np.radians(np.array([0.0, 180.0])) +assert abs(r.tolist()[0] - 0.0) < 1e-10, 'radians 0' +assert abs(r.tolist()[1] - 3.141592653589793) < 1e-10, 'radians 180' + +# === np.hypot === +assert np.hypot(np.array([3.0]), np.array([4.0])).tolist() == [5.0], 'hypot 3-4-5' +assert np.hypot(np.array([0.0]), np.array([5.0])).tolist() == [5.0], 'hypot 0-5' + +# === np.nan_to_num === +a = np.array([1.0, float('nan'), float('inf'), float('-inf')]) +r = np.nan_to_num(a) +assert r.tolist()[0] == 1.0, 'nan_to_num keeps normal' +assert r.tolist()[1] == 0.0, 'nan_to_num replaces nan' + +# === np.fmin / np.fmax === +assert np.fmin(np.array([1.0, 3.0]), np.array([2.0, 1.0])).tolist() == [1.0, 1.0], 'fmin basic' +assert np.fmax(np.array([1.0, 3.0]), np.array([2.0, 1.0])).tolist() == [2.0, 3.0], 'fmax basic' +# fmin/fmax ignore NaN +a = np.array([1.0, float('nan')]) +b = np.array([2.0, 3.0]) +assert np.fmin(a, b).tolist()[0] == 1.0, 'fmin nan ignore 1' +assert np.fmin(a, b).tolist()[1] == 3.0, 'fmin nan ignore 2' +assert np.fmax(a, b).tolist()[0] == 2.0, 'fmax nan ignore 1' +assert np.fmax(a, b).tolist()[1] == 3.0, 'fmax nan ignore 2' + +# === np.fmod === +assert np.fmod(np.array([5.0, 7.0]), np.array([3.0, 2.0])).tolist() == [2.0, 1.0], 'fmod' + +# === np.rint === +assert np.rint(np.array([1.5, 2.5, 3.5])).tolist() == [2.0, 2.0, 4.0], 'rint banker' +assert np.rint(np.array([0.5, 1.5])).tolist() == [0.0, 2.0], 'rint half even' + +# === np.fabs === +assert np.fabs(np.array([-1.0, 2.0, -3.0])).tolist() == [1.0, 2.0, 3.0], 'fabs' + +# === np.positive / np.negative === +assert np.positive(np.array([-1.0, 2.0])).tolist() == [-1.0, 2.0], 'positive' +assert np.negative(np.array([-1.0, 2.0])).tolist() == [1.0, -2.0], 'negative' + +# ============================================================ +# 32. NAN-AWARE AGGREGATIONS +# ============================================================ + +a_nan = np.array([1.0, float('nan'), 3.0, float('nan'), 5.0]) + +# === np.nansum === +assert np.nansum(a_nan) == 9.0, 'nansum' +assert np.nansum(np.array([1.0, 2.0, 3.0])) == 6.0, 'nansum no nan' + +# === np.nanmean === +assert np.nanmean(a_nan) == 3.0, 'nanmean' + +# === np.nanmin === +assert np.nanmin(a_nan) == 1.0, 'nanmin' + +# === np.nanmax === +assert np.nanmax(a_nan) == 5.0, 'nanmax' + +# === np.nanstd === +assert abs(np.nanstd(a_nan) - 1.632993161855452) < 1e-10, 'nanstd' + +# === np.nanvar === +assert abs(np.nanvar(a_nan) - 2.6666666666666665) < 1e-10, 'nanvar' + +# === np.nanprod === +assert np.nanprod(a_nan) == 15.0, 'nanprod' + +# === np.nanmedian === +assert np.nanmedian(a_nan) == 3.0, 'nanmedian' + +# === np.nanargmin === +assert np.nanargmin(a_nan) == 0, 'nanargmin' + +# === np.nanargmax === +assert np.nanargmax(a_nan) == 4, 'nanargmax' + +# === np.nancumsum === +assert np.nancumsum(a_nan).tolist() == [1.0, 1.0, 4.0, 4.0, 9.0], 'nancumsum' + +# === np.nancumprod === +assert np.nancumprod(a_nan).tolist() == [1.0, 1.0, 3.0, 3.0, 15.0], 'nancumprod' + +# ============================================================ +# 33. ADDITIONAL STATISTICS +# ============================================================ + +# === np.ptp === +assert np.ptp(np.array([3.0, 1.0, 5.0, 2.0])) == 4.0, 'ptp' +assert np.ptp(np.array([7.0])) == 0.0, 'ptp single' + +# === np.cumprod === +assert np.cumprod(np.array([1.0, 2.0, 3.0, 4.0])).tolist() == [1.0, 2.0, 6.0, 24.0], 'cumprod' +assert np.cumprod(np.array([5.0])).tolist() == [5.0], 'cumprod single' + +# === np.percentile === +assert np.percentile(np.array([1.0, 2.0, 3.0, 4.0]), 50) == 2.5, 'percentile 50' +assert np.percentile(np.array([1.0, 2.0, 3.0, 4.0]), 0) == 1.0, 'percentile 0' +assert np.percentile(np.array([1.0, 2.0, 3.0, 4.0]), 100) == 4.0, 'percentile 100' + +# === np.quantile === +assert np.quantile(np.array([1.0, 2.0, 3.0, 4.0]), 0.5) == 2.5, 'quantile 0.5' +assert np.quantile(np.array([1.0, 2.0, 3.0, 4.0]), 0.0) == 1.0, 'quantile 0' +assert np.quantile(np.array([1.0, 2.0, 3.0, 4.0]), 1.0) == 4.0, 'quantile 1' + +# === np.average === +assert np.average(np.array([1.0, 2.0, 3.0])) == 2.0, 'average' + +# ============================================================ +# 34. LOGICAL AND TESTING FUNCTIONS +# ============================================================ + +# === np.logical_and === +assert np.logical_and(np.array([1, 1, 0]), np.array([1, 0, 0])).tolist() == [True, False, False], 'logical_and' + +# === np.logical_or === +assert np.logical_or(np.array([1, 1, 0]), np.array([1, 0, 0])).tolist() == [True, True, False], 'logical_or' + +# === np.logical_not === +assert np.logical_not(np.array([1, 0, 1])).tolist() == [False, True, False], 'logical_not' + +# === np.logical_xor === +assert np.logical_xor(np.array([1, 1, 0]), np.array([1, 0, 0])).tolist() == [False, True, False], 'logical_xor' + +# === np.allclose === +assert np.allclose(np.array([1.0, 2.0]), np.array([1.0, 2.0])) == True, 'allclose exact' +assert np.allclose(np.array([1.0, 2.0]), np.array([1.0, 2.1])) == False, 'allclose not close' +assert np.allclose(np.array([1.0]), np.array([1.0 + 1e-9])) == True, 'allclose within tol' + +# === np.isclose === +r = np.isclose(np.array([1.0, 2.0]), np.array([1.0, 2.1])) +assert r.tolist() == [True, False], 'isclose basic' + +# === np.isin === +r = np.isin(np.array([1.0, 2.0, 3.0, 4.0]), np.array([2.0, 4.0])) +assert r.tolist() == [False, True, False, True], 'isin' + +# ============================================================ +# 35. ARRAY MANIPULATION +# ============================================================ + +# === np.flip === +assert np.flip(np.array([1.0, 2.0, 3.0])).tolist() == [3.0, 2.0, 1.0], 'flip 1d' +assert np.flip(np.array([1, 2, 3])).tolist() == [3, 2, 1], 'flip 1d int' + +# === np.fliplr === +a2d = np.array([[1.0, 2.0], [3.0, 4.0]]) +assert np.fliplr(a2d).tolist() == [[2.0, 1.0], [4.0, 3.0]], 'fliplr' + +# === np.flipud === +assert np.flipud(a2d).tolist() == [[3.0, 4.0], [1.0, 2.0]], 'flipud' + +# === np.roll === +assert np.roll(np.array([1.0, 2.0, 3.0, 4.0]), 2).tolist() == [3.0, 4.0, 1.0, 2.0], 'roll +2' +assert np.roll(np.array([1.0, 2.0, 3.0, 4.0]), -1).tolist() == [2.0, 3.0, 4.0, 1.0], 'roll -1' +assert np.roll(np.array([1.0, 2.0, 3.0]), 0).tolist() == [1.0, 2.0, 3.0], 'roll 0' + +# === np.expand_dims === +a = np.array([1.0, 2.0, 3.0]) +assert np.expand_dims(a, 0).shape == (1, 3), 'expand_dims axis=0' +assert np.expand_dims(a, 1).shape == (3, 1), 'expand_dims axis=1' + +# === np.squeeze === +a = np.array([[[1.0], [2.0]]]) +assert np.squeeze(a).shape == (2,), 'squeeze removes 1-dims' +a2 = np.array([[1.0, 2.0], [3.0, 4.0]]) +assert np.squeeze(a2).shape == (2, 2), 'squeeze no effect' + +# === np.ravel (module-level) === +assert np.ravel(np.array([[1.0, 2.0], [3.0, 4.0]])).tolist() == [1.0, 2.0, 3.0, 4.0], 'ravel 2d' + +# === np.delete === +assert np.delete(np.array([1.0, 2.0, 3.0, 4.0]), 1).tolist() == [1.0, 3.0, 4.0], 'delete idx 1' +assert np.delete(np.array([1.0, 2.0, 3.0]), 0).tolist() == [2.0, 3.0], 'delete idx 0' +assert np.delete(np.array([1.0, 2.0, 3.0]), 2).tolist() == [1.0, 2.0], 'delete last' + +# === np.insert === +assert np.insert(np.array([1.0, 2.0, 4.0]), 2, 3.0).tolist() == [1.0, 2.0, 3.0, 4.0], 'insert middle' +assert np.insert(np.array([2.0, 3.0]), 0, 1.0).tolist() == [1.0, 2.0, 3.0], 'insert front' + +# === np.diag === +# 1D → 2D diagonal matrix +assert np.diag(np.array([1.0, 2.0, 3.0])).tolist() == [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]], 'diag 1d->2d' +# 2D → 1D diagonal extraction +assert np.diag(np.array([[1.0, 2.0], [3.0, 4.0]])).tolist() == [1.0, 4.0], 'diag 2d->1d' + +# === np.diagonal === +assert np.diagonal(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).tolist() == [1.0, 5.0], 'diagonal rect' +assert np.diagonal(np.array([[1.0, 2.0], [3.0, 4.0]])).tolist() == [1.0, 4.0], 'diagonal square' + +# === np.trace === +assert np.trace(np.array([[1.0, 2.0], [3.0, 4.0]])) == 5.0, 'trace 2x2' +assert np.trace(np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]])) == 6.0, 'trace 3x3' + +# === np.flatnonzero === +assert np.flatnonzero(np.array([0.0, 1.0, 0.0, 3.0])).tolist() == [1, 3], 'flatnonzero' +assert np.flatnonzero(np.array([0.0, 0.0])).tolist() == [], 'flatnonzero all zero' +assert np.flatnonzero(np.array([1.0, 2.0])).tolist() == [0, 1], 'flatnonzero all nonzero' + +# === np.asarray === +r = np.asarray([1.0, 2.0, 3.0]) +assert r.tolist() == [1.0, 2.0, 3.0], 'asarray from list' + +# === np.column_stack === +a = np.array([1.0, 2.0]) +b = np.array([3.0, 4.0]) +r = np.column_stack([a, b]) +assert r.tolist() == [[1.0, 3.0], [2.0, 4.0]], 'column_stack' + +# === np.array_split === +parts = np.array_split(np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 3) +assert parts[0].tolist() == [1.0, 2.0], 'array_split part 0' +assert parts[1].tolist() == [3.0, 4.0], 'array_split part 1' +assert parts[2].tolist() == [5.0], 'array_split part 2' + +# === np.full_like === +a = np.array([1.0, 2.0, 3.0]) +assert np.full_like(a, 7.0).tolist() == [7.0, 7.0, 7.0], 'full_like' +assert np.full_like(a, 7.0).shape == (3,), 'full_like shape' + +# ============================================================ +# 36. SORTING, SEARCHING, SET OPERATIONS +# ============================================================ + +# === np.argsort (module-level) === +assert np.argsort(np.array([3.0, 1.0, 2.0])).tolist() == [1, 2, 0], 'argsort mod' + +# === np.searchsorted === +assert np.searchsorted(np.array([1.0, 3.0, 5.0, 7.0]), 4.0) == 2, 'searchsorted' +assert np.searchsorted(np.array([1.0, 3.0, 5.0]), 0.0) == 0, 'searchsorted left' +assert np.searchsorted(np.array([1.0, 3.0, 5.0]), 6.0) == 3, 'searchsorted right' + +# === np.extract === +cond = np.array([1, 0, 1, 0]) +arr = np.array([10.0, 20.0, 30.0, 40.0]) +assert np.extract(cond, arr).tolist() == [10.0, 30.0], 'extract' +assert np.extract(np.array([0, 0]), np.array([1.0, 2.0])).tolist() == [], 'extract none' + +# === np.intersect1d === +assert np.intersect1d(np.array([1.0, 2.0, 3.0]), np.array([2.0, 3.0, 4.0])).tolist() == [2.0, 3.0], 'intersect1d' + +# === np.union1d === +assert np.union1d(np.array([1.0, 2.0]), np.array([2.0, 3.0])).tolist() == [1.0, 2.0, 3.0], 'union1d' + +# === np.setdiff1d === +assert np.setdiff1d(np.array([1.0, 2.0, 3.0]), np.array([2.0])).tolist() == [1.0, 3.0], 'setdiff1d' + +# === np.setxor1d === +assert np.setxor1d(np.array([1.0, 2.0, 3.0]), np.array([2.0, 3.0, 4.0])).tolist() == [1.0, 4.0], 'setxor1d' + +# === np.bincount === +assert np.bincount(np.array([0, 1, 1, 2, 2, 2])).tolist() == [1, 2, 3], 'bincount' +assert np.bincount(np.array([3])).tolist() == [0, 0, 0, 1], 'bincount sparse' + +# === np.digitize === +assert np.digitize(np.array([0.5, 1.5, 2.5, 3.5]), np.array([1.0, 2.0, 3.0])).tolist() == [0, 1, 2, 3], 'digitize' + +# ============================================================ +# 37. LINEAR ALGEBRA BASICS +# ============================================================ + +# === np.outer === +assert np.outer(np.array([1.0, 2.0]), np.array([3.0, 4.0])).tolist() == [[3.0, 4.0], [6.0, 8.0]], 'outer product' +assert np.outer(np.array([1.0]), np.array([2.0, 3.0])).tolist() == [[2.0, 3.0]], 'outer 1x2' + +# === np.cross === +assert np.cross(np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])).tolist() == [0.0, 0.0, 1.0], 'cross i x j' +assert np.cross(np.array([0.0, 1.0, 0.0]), np.array([0.0, 0.0, 1.0])).tolist() == [1.0, 0.0, 0.0], 'cross j x k' + +# ============================================================ +# 38. CREATION FUNCTIONS (ADVANCED) +# ============================================================ + +# === np.logspace === +r = np.logspace(0, 2, 3) +assert abs(r.tolist()[0] - 1.0) < 1e-10, 'logspace start' +assert abs(r.tolist()[1] - 10.0) < 1e-10, 'logspace mid' +assert abs(r.tolist()[2] - 100.0) < 1e-10, 'logspace end' + +# === np.geomspace === +r = np.geomspace(1, 100, 3) +assert abs(r.tolist()[0] - 1.0) < 1e-10, 'geomspace start' +assert abs(r.tolist()[1] - 10.0) < 1e-10, 'geomspace mid' +assert abs(r.tolist()[2] - 100.0) < 1e-10, 'geomspace end' + +# === np.tri === +assert np.tri(3).tolist() == [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]], 'tri 3x3' + +# === np.tril === +m = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) +assert np.tril(m).tolist() == [[1.0, 0.0, 0.0], [4.0, 5.0, 0.0], [7.0, 8.0, 9.0]], 'tril' + +# === np.triu === +assert np.triu(m).tolist() == [[1.0, 2.0, 3.0], [0.0, 5.0, 6.0], [0.0, 0.0, 9.0]], 'triu' + +# === np.identity === +assert np.identity(3).tolist() == [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], 'identity 3' +assert np.identity(1).tolist() == [[1.0]], 'identity 1' + +# === np.meshgrid === +x, y = np.meshgrid(np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0])) +assert x.tolist() == [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]], 'meshgrid x' +assert y.tolist() == [[4.0, 4.0, 4.0], [5.0, 5.0, 5.0]], 'meshgrid y' + +# === np.gradient === +r = np.gradient(np.array([1.0, 3.0, 6.0, 10.0])) +assert r.tolist() == [2.0, 2.5, 3.5, 4.0], 'gradient' +r2 = np.gradient(np.array([1.0, 4.0])) +assert r2.tolist() == [3.0, 3.0], 'gradient 2-elem' + +# === np.convolve === +r = np.convolve(np.array([1.0, 2.0, 3.0]), np.array([0.0, 1.0, 0.5])) +assert r.tolist() == [0.0, 1.0, 2.5, 4.0, 1.5], 'convolve full' + +# === np.interp === +xp = np.array([1.0, 2.0, 3.0]) +fp = np.array([10.0, 20.0, 30.0]) +assert np.interp(np.array([1.5, 2.5]), xp, fp).tolist() == [15.0, 25.0], 'interp' +assert np.interp(np.array([0.0]), xp, fp).tolist() == [10.0], 'interp left clamp' +assert np.interp(np.array([4.0]), xp, fp).tolist() == [30.0], 'interp right clamp' + +# === np.select === +c1 = np.array([1, 0, 0]) +c2 = np.array([0, 1, 0]) +ch1 = np.array([10.0, 20.0, 30.0]) +ch2 = np.array([40.0, 50.0, 60.0]) +r = np.select([c1, c2], [ch1, ch2], 0.0) +assert r.tolist() == [10.0, 50.0, 0.0], 'select' + +# ============================================================ +# 39. EMPTY-LIKE / CORRELATE +# ============================================================ + +# === np.empty_like === +a = np.array([1.0, 2.0, 3.0]) +r = np.empty_like(a) +assert r.shape == (3,), 'empty_like shape' +assert r.dtype == a.dtype, 'empty_like dtype' + +# === np.correlate === +a = np.array([1.0, 2.0, 3.0]) +v = np.array([0.0, 1.0, 0.5]) +r = np.correlate(a, v) +assert len(r.tolist()) > 0, 'correlate produces output' + +# ============================================================ +# PHASE 2: CRITICAL MISSING OPERATORS +# ============================================================ + +# === 2.1 Bitwise operators: &, |, ^ on ndarray === + +# Bool arrays — element-wise logical AND/OR/XOR +a = np.array([1, 0, 1, 0]) +b = np.array([1, 1, 0, 0]) +ba = a.astype('bool') +bb = b.astype('bool') + +r = ba & bb +assert r.tolist() == [True, False, False, False], 'bool array & (AND)' +assert r.dtype == 'bool', 'bool & dtype' + +r = ba | bb +assert r.tolist() == [True, True, True, False], 'bool array | (OR)' +assert r.dtype == 'bool', 'bool | dtype' + +r = ba ^ bb +assert r.tolist() == [False, True, True, False], 'bool array ^ (XOR)' +assert r.dtype == 'bool', 'bool ^ dtype' + +# Int arrays — bitwise AND/OR/XOR +a = np.array([5, 3, 12, 10]) +b = np.array([3, 6, 10, 15]) + +r = a & b +assert r.tolist() == [1, 2, 8, 10], 'int array & (AND)' +assert r.dtype == 'int64', 'int & dtype' + +r = a | b +assert r.tolist() == [7, 7, 14, 15], 'int array | (OR)' +assert r.dtype == 'int64', 'int | dtype' + +r = a ^ b +assert r.tolist() == [6, 5, 6, 5], 'int array ^ (XOR)' +assert r.dtype == 'int64', 'int ^ dtype' + +# Scalar bitwise operations +a = np.array([5, 3, 12]) +r = a & 3 +assert r.tolist() == [1, 3, 0], 'int array & scalar' + +r = a | 2 +assert r.tolist() == [7, 3, 14], 'int array | scalar' + +r = a ^ 6 +assert r.tolist() == [3, 5, 10], 'int array ^ scalar' + +# Scalar on left +r = 3 & np.array([5, 3, 12]) +assert r.tolist() == [1, 3, 0], 'scalar & int array' + +# === 2.2 __setitem__ — arr[i] = val, arr[mask] = val, arr[slice] = val === + +# arr[i] = val — set single element by int index +a = np.array([1, 2, 3, 4, 5]) +a[0] = 10 +assert a.tolist() == [10, 2, 3, 4, 5], 'setitem int index 0' + +a[4] = 50 +assert a.tolist() == [10, 2, 3, 4, 50], 'setitem int index 4' + +a[-1] = 99 +assert a.tolist() == [10, 2, 3, 4, 99], 'setitem negative index' + +# arr[mask] = val — set elements where bool mask is True +a = np.array([1, 2, 3, 4, 5]) +mask = np.array([1, 0, 1, 0, 1]).astype('bool') +a[mask] = 0 +assert a.tolist() == [0, 2, 0, 4, 0], 'setitem bool mask' + +# arr[slice] = val — set slice of elements +a = np.array([1, 2, 3, 4, 5]) +a[1:4] = 99 +assert a.tolist() == [1, 99, 99, 99, 5], 'setitem slice' + +a = np.array([1, 2, 3, 4, 5]) +a[::2] = 0 +assert a.tolist() == [0, 2, 0, 4, 0], 'setitem slice step 2' + +a = np.array([0, 1, 2, 3]) +a[1:3] = np.array([9]) +assert a.tolist() == [0, 9, 9, 3], 'setitem slice broadcasts single-element array' + +a = np.array([0, 1, 2, 3]) +a[3:0:-1] = np.array([9]) +assert a.tolist() == [0, 9, 9, 9], 'setitem negative slice broadcasts single-element array' + +a = np.array([0, 1, 2, 3]) +a[1:1] = np.array([9]) +assert a.tolist() == [0, 1, 2, 3], 'setitem empty slice with ndarray is no-op' + +try: + a[0] = 'bad' + assert False, 'expected non-numeric setitem to fail' +except TypeError as exc: + assert str(exc) == 'ndarray numeric argument must be int, float, or bool', 'setitem rejects non-numeric value' + +try: + a[1:3] = np.array([7, 8, 9]) + assert False, 'expected longer ndarray slice assignment to fail' +except ValueError as exc: + assert str(exc) == 'could not broadcast input array from shape (3,) into shape (2,)', 'setitem long slice rhs error' + +try: + a[1:3] = np.array([]) + assert False, 'expected empty ndarray slice assignment into non-empty target to fail' +except ValueError as exc: + assert str(exc) == 'could not broadcast input array from shape (0,) into shape (2,)', ( + 'setitem empty slice rhs error' + ) + +# === 2.3 __iter__ — for x in arr === + +# 1D array yields scalars +a = np.array([10, 20, 30]) +items = [] +for x in a: + items.append(x) +assert items == [10, 20, 30], 'iter 1D yields scalars' + +# Float array +a = np.array([1.5, 2.5, 3.5]) +items = [] +for x in a: + items.append(x) +assert items == [1.5, 2.5, 3.5], 'iter 1D float yields scalars' + +# list() conversion via iter +a = np.array([1, 2, 3]) +assert list(a) == [1, 2, 3], 'list(ndarray) via iter' + +try: + list(np.array(7)) + assert False, 'expected iterating 0d ndarray to fail' +except TypeError as exc: + assert str(exc) == 'iteration over a 0-d array', '0d ndarray iteration error' + +# === 2.4 __contains__ — val in arr === + +a = np.array([1, 2, 3, 4, 5]) +assert 3 in a, 'int in array' +assert 6 not in a, 'int not in array' +assert 1.0 in np.array([1.0, 2.0, 3.0]), 'float in array' +assert 9223372036854775808 in np.array([9.223372036854776e18]), 'LongInt in float array' +assert 9223372036854775808 not in np.array([1.0]), 'LongInt not in unrelated float array' + +# Bool check +a = np.array([0, 1, 0]) +assert 1 in a, '1 in array with zeros' +assert 2 not in a, '2 not in array' + +# === 2.5 In-place operators: +=, -=, *=, /= === + +# += scalar +a = np.array([1, 2, 3]) +a += 10 +assert a.tolist() == [11, 12, 13], 'iadd scalar' + +try: + a = np.array([1, 2, 3]) + a += 0.5 + assert False, 'expected int iadd float scalar to fail' +except TypeError as exc: + assert str(exc) == ( + "Cannot cast ufunc 'add' output from dtype('float64') to dtype('int64') with casting rule 'same_kind'" + ), 'iadd float scalar cast error' + +a = np.array([1.0, 2.0, 3.0]) +a += 0.5 +assert a.tolist() == [1.5, 2.5, 3.5], 'float iadd float scalar values' +assert a.dtype == 'float64', 'float iadd float scalar dtype' + +# += array +a = np.array([1, 2, 3]) +b = np.array([10, 20, 30]) +a += b +assert a.tolist() == [11, 22, 33], 'iadd array' + +try: + a = np.array([1, 2, 3]) + a += np.array([0.5, 0.5, 0.5]) + assert False, 'expected int iadd float array to fail' +except TypeError as exc: + assert str(exc) == ( + "Cannot cast ufunc 'add' output from dtype('float64') to dtype('int64') with casting rule 'same_kind'" + ), 'iadd float array cast error' + +try: + a = np.array([[1, 2], [3, 4]]) + a += np.array([10, 20, 30, 40]) + assert False, 'expected same-length different-shape iadd to fail' +except TypeError as exc: + assert str(exc) == "unsupported operand type(s) for +=: 'numpy.ndarray' and 'numpy.ndarray'", ( + 'iadd shape mismatch error' + ) + +# -= scalar +a = np.array([10, 20, 30]) +a -= 5 +assert a.tolist() == [5, 15, 25], 'isub scalar' + +# -= array +a = np.array([10, 20, 30]) +b = np.array([1, 2, 3]) +a -= b +assert a.tolist() == [9, 18, 27], 'isub array' + +# *= scalar +a = np.array([1, 2, 3]) +a *= 5 +assert a.tolist() == [5, 10, 15], 'imul scalar' + +# *= array +a = np.array([2, 3, 4]) +b = np.array([10, 20, 30]) +a *= b +assert a.tolist() == [20, 60, 120], 'imul array' + +# /= scalar +a = np.array([10.0, 20.0, 30.0]) +a /= 2 +assert a.tolist() == [5.0, 10.0, 15.0], 'idiv scalar' + +# /= array +a = np.array([10.0, 20.0, 30.0]) +b = np.array([2.0, 4.0, 5.0]) +a /= b +assert a.tolist() == [5.0, 5.0, 6.0], 'idiv array' + +# === 2.6 @ operator (matmul) and np.matmul === + +# 1D @ 1D — dot product +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) +r = a @ b +assert r == 32, '1D @ 1D dot product' + +# 2D @ 2D — matrix multiplication +a = np.array([[1, 2], [3, 4]]) +b = np.array([[5, 6], [7, 8]]) +r = a @ b +assert r.tolist() == [[19, 22], [43, 50]], '2D @ 2D matmul' + +# 2D @ 1D — matrix-vector +a = np.array([[1, 2], [3, 4]]) +b = np.array([5, 6]) +r = a @ b +assert r.tolist() == [17, 39], '2D @ 1D matvec' + +# 1D @ 2D — vector-matrix +a = np.array([1, 2]) +b = np.array([[3, 4], [5, 6]]) +r = a @ b +assert r.tolist() == [13, 16], '1D @ 2D vecmat' + +# np.matmul function +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) +r = np.matmul(a, b) +assert r == 32, 'np.matmul 1D dot product' + +# np.matmul 2D +a = np.array([[1, 0], [0, 1]]) +b = np.array([[5, 6], [7, 8]]) +r = np.matmul(a, b) +assert r.tolist() == [[5, 6], [7, 8]], 'np.matmul identity 2D' + +# === ndarray.item() === +a = np.array([42]) +assert a.item() == 42, 'item int scalar' +a = np.array([3.14]) +assert a.item() == 3.14, 'item float scalar' + +# === ndarray.cumprod() === +a = np.array([1, 2, 3, 4]) +assert a.cumprod().tolist() == [1, 2, 6, 24], 'cumprod int' +a = np.array([1.0, 2.0, 3.0]) +assert a.cumprod().tolist() == [1.0, 2.0, 6.0], 'cumprod float' + +# === ndarray.squeeze() === +a = np.array([[1, 2, 3]]) +assert a.squeeze().tolist() == [1, 2, 3], 'squeeze removes unit dim' +assert a.squeeze().shape == (3,), 'squeeze shape' + +# === ndarray.take() === +a = np.array([10, 20, 30, 40, 50]) +idx = np.array([0, 2, 4]) +assert a.take(idx).tolist() == [10, 30, 50], 'take indices' + +# === ndarray.diagonal() === +a = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +assert a.diagonal().tolist() == [1, 5, 9], 'diagonal 3x3' +a = np.array([[1, 2], [3, 4]]) +assert a.diagonal().tolist() == [1, 4], 'diagonal 2x2' + +# === ndarray.trace() === +a = np.array([[1, 2], [3, 4]]) +assert a.trace() == 5, 'trace int 2x2' +a = np.array([[1.0, 2.0], [3.0, 4.0]]) +assert a.trace() == 5.0, 'trace float 2x2' + +# === ndarray.fill() === +a = np.array([1.0, 2.0, 3.0]) +a.fill(7.0) +assert a.tolist() == [7.0, 7.0, 7.0], 'fill float' + +# === ndarray.compress() === +a = np.array([10, 20, 30, 40]) +c = np.array([1, 0, 1, 0]) +assert a.compress(c).tolist() == [10, 30], 'compress bool-like' + +# === ndarray.swapaxes() === +a = np.array([[1, 2], [3, 4]]) +assert a.swapaxes(0, 1).tolist() == [[1, 3], [2, 4]], 'swapaxes 2D transpose' + +# === ndarray.nbytes === +a = np.array([1.0, 2.0, 3.0]) +assert a.nbytes == 24, 'nbytes 3 floats' + +# === ndarray.itemsize === +a = np.array([1.0, 2.0, 3.0]) +assert a.itemsize == 8, 'itemsize float64' + +# === np.row_stack (alias for vstack) === +a = np.array([1, 2]) +b = np.array([3, 4]) +assert np.row_stack([a, b]).tolist() == [[1, 2], [3, 4]], 'row_stack' + +# === np.hsplit === +a = np.array([1, 2, 3, 4, 5, 6]) +parts = np.hsplit(a, 3) +assert parts[0].tolist() == [1, 2], 'hsplit part 0' +assert parts[1].tolist() == [3, 4], 'hsplit part 1' +assert parts[2].tolist() == [5, 6], 'hsplit part 2' + +# === np.vsplit === +a = np.array([1, 2, 3, 4]) +parts = np.vsplit(a, 2) +assert parts[0].tolist() == [1, 2], 'vsplit part 0' +assert parts[1].tolist() == [3, 4], 'vsplit part 1' + +# === np.inner === +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) +assert np.inner(a, b) == 32, 'inner product' + +# === np.vdot === +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) +assert np.vdot(a, b) == 32, 'vdot product' + +# ============================================================ +# 42. COMPREHENSIVE EDGE CASES AND ADDITIONAL COVERAGE +# ============================================================ + +# === Operator edge cases === +# Bitwise on single-element arrays +a1 = np.array([True]) +b1 = np.array([False]) +assert (a1 & b1).tolist() == [False], 'bitwise and single' +assert (a1 | b1).tolist() == [True], 'bitwise or single' +assert (a1 ^ b1).tolist() == [True], 'bitwise xor single' + +# Bitwise int operations +ia = np.array([15, 255, 0]) +ib = np.array([240, 15, 255]) +assert (ia & ib).tolist() == [0, 15, 0], 'int bitwise and' +assert (ia | ib).tolist() == [255, 255, 255], 'int bitwise or' +assert (ia ^ ib).tolist() == [255, 240, 255], 'int bitwise xor' + +# Matmul 2D +m1 = np.array([[1, 0], [0, 1]]) # identity matrix +m2 = np.array([[5, 6], [7, 8]]) +assert (m1 @ m2).tolist() == [[5, 6], [7, 8]], 'matmul identity' +assert np.matmul(m1, m2).tolist() == [[5, 6], [7, 8]], 'np.matmul identity' + +# In-place compound +a = np.array([10.0, 20.0, 30.0]) +a += 5 +assert a.tolist() == [15.0, 25.0, 35.0], 'iadd then check' +a -= 5 +assert a.tolist() == [10.0, 20.0, 30.0], 'isub restore' +a *= 2 +assert a.tolist() == [20.0, 40.0, 60.0], 'imul double' +a /= 4 +assert a.tolist() == [5.0, 10.0, 15.0], 'idiv quarter' + +# === Setitem edge cases === +a = np.array([1, 2, 3, 4, 5]) +a[0] = 99 +assert a[0] == 99, 'setitem first' +a[4] = 88 +assert a[4] == 88, 'setitem last' +a[-1] = 77 +assert a[-1] == 77, 'setitem negative' + +# Setitem with bool mask +b = np.array([10, 20, 30, 40, 50]) +mask = np.array([1, 0, 1, 0, 1]) +b[mask > 0] = 0 +assert b.tolist() == [0, 20, 0, 40, 0], 'setitem mask' + +# Setitem with slice +c = np.array([1, 2, 3, 4, 5]) +c[1:4] = np.array([20, 30, 40]) +assert c.tolist() == [1, 20, 30, 40, 5], 'setitem slice' + +# === Contains edge cases === +a = np.array([1.0, 2.0, 3.0]) +assert 1.0 in a, 'contains float True' +assert 4.0 not in a, 'contains float False' +assert 2 in np.array([1, 2, 3]), 'contains int True' + +# === Iter edge cases === +a = np.array([10, 20, 30]) +collected = [] +for x in a: + collected.append(x) +assert collected == [10, 20, 30], 'iter collect' +assert sum(np.array([1, 2, 3, 4])) == 10, 'sum via iter' + +# === Additional trig/math coverage === +# arcsin edge +r = np.arcsin(np.array([-1.0, 0.0, 1.0])) +assert abs(r.tolist()[0] + 1.5707963267948966) < 1e-10, 'arcsin -1' +assert abs(r.tolist()[1]) < 1e-10, 'arcsin 0' +assert abs(r.tolist()[2] - 1.5707963267948966) < 1e-10, 'arcsin 1' + +# arccos edge +r = np.arccos(np.array([-1.0, 0.0, 1.0])) +assert abs(r.tolist()[0] - 3.141592653589793) < 1e-10, 'arccos -1' +assert abs(r.tolist()[1] - 1.5707963267948966) < 1e-10, 'arccos 0' +assert abs(r.tolist()[2]) < 1e-10, 'arccos 1' + +# deg2rad/rad2deg roundtrip +angles = np.array([0.0, 45.0, 90.0, 180.0, 360.0]) +r = np.rad2deg(np.deg2rad(angles)) +assert abs(r.tolist()[0]) < 1e-10, 'roundtrip 0' +assert abs(r.tolist()[1] - 45.0) < 1e-10, 'roundtrip 45' +assert abs(r.tolist()[2] - 90.0) < 1e-10, 'roundtrip 90' +assert abs(r.tolist()[3] - 180.0) < 1e-10, 'roundtrip 180' +assert abs(r.tolist()[4] - 360.0) < 1e-10, 'roundtrip 360' + +# sign on integers +assert np.sign(np.array([-5, -1, 0, 1, 5])).tolist() == [-1, -1, 0, 1, 1], 'sign int values' + +# square on integers +assert np.square(np.array([0, 1, 2, 3])).tolist() == [0, 1, 4, 9], 'square int' + +# reciprocal edge +assert np.reciprocal(np.array([1.0, 0.5, 0.25])).tolist() == [1.0, 2.0, 4.0], 'reciprocal inverse' + +# fmod edge +assert np.fmod(np.array([10.0, -10.0, 10.0]), np.array([3.0, 3.0, -3.0])).tolist()[0] == 1.0, 'fmod pos' + +# rint edge cases +assert np.rint(np.array([0.0, 1.0, -1.0])).tolist() == [0.0, 1.0, -1.0], 'rint integers' +assert np.rint(np.array([0.1, 0.9, -0.9])).tolist() == [0.0, 1.0, -1.0], 'rint near-integers' + +# === NaN-aware edge cases === +# nansum on all-nan +assert np.nansum(np.array([float('nan'), float('nan')])) == 0.0, 'nansum all nan' + +# nanprod on no-nan +assert np.nanprod(np.array([2.0, 3.0, 4.0])) == 24.0, 'nanprod no nan' + +# nancumsum on no-nan +assert np.nancumsum(np.array([1.0, 2.0, 3.0])).tolist() == [1.0, 3.0, 6.0], 'nancumsum no nan' + +# nancumprod on no-nan +assert np.nancumprod(np.array([1.0, 2.0, 3.0])).tolist() == [1.0, 2.0, 6.0], 'nancumprod no nan' + +# nan_to_num with custom behavior +a = np.array([1.0, float('nan'), 3.0]) +r = np.nan_to_num(a) +assert r.tolist() == [1.0, 0.0, 3.0], 'nan_to_num simple' + +# === Logical function edge cases === +# logical_and with float arrays (nonzero = True) +assert np.logical_and(np.array([1.0, 0.0, 3.0]), np.array([1.0, 1.0, 0.0])).tolist() == [True, False, False], ( + 'logical_and float' +) +assert np.logical_or(np.array([0.0, 0.0, 1.0]), np.array([0.0, 1.0, 0.0])).tolist() == [False, True, True], ( + 'logical_or float' +) +assert np.logical_not(np.array([0, 1, 0, 1])).tolist() == [True, False, True, False], 'logical_not int' +assert np.logical_xor(np.array([0, 0, 1, 1]), np.array([0, 1, 0, 1])).tolist() == [False, True, True, False], ( + 'logical_xor int' +) + +# allclose with different tolerances +assert np.allclose(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0])) == True, 'allclose identical' +assert np.allclose(np.array([1.0]), np.array([1.1])) == False, 'allclose diff 0.1' + +# isclose more cases +r = np.isclose(np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 3.0])) +assert r.tolist() == [True, True, True], 'isclose all same' + +# isin more cases +r = np.isin(np.array([1, 2, 3, 4, 5]), np.array([1, 3, 5])) +assert r.tolist() == [True, False, True, False, True], 'isin odd numbers' + +# === Manipulation edge cases === +# flip empty +assert np.flip(np.array([])).tolist() == [], 'flip empty' +assert np.flip(np.array([42.0])).tolist() == [42.0], 'flip single' + +# roll edge cases +assert np.roll(np.array([1, 2, 3]), 3).tolist() == [1, 2, 3], 'roll full cycle' +assert np.roll(np.array([1, 2, 3]), 6).tolist() == [1, 2, 3], 'roll double cycle' + +# delete edge +assert np.delete(np.array([1, 2, 3, 4, 5]), 0).tolist() == [2, 3, 4, 5], 'delete first' +assert np.delete(np.array([1, 2, 3, 4, 5]), 4).tolist() == [1, 2, 3, 4], 'delete last' + +# insert at boundaries +assert np.insert(np.array([2, 3, 4]), 0, 1).tolist() == [1, 2, 3, 4], 'insert at 0' +assert np.insert(np.array([1, 2, 3]), 3, 4).tolist() == [1, 2, 3, 4], 'insert at end' + +# diag of identity +assert np.diag(np.eye(3)).tolist() == [1.0, 1.0, 1.0], 'diag of eye' + +# trace of identity +assert np.trace(np.eye(4)) == 4.0, 'trace of eye 4' +assert np.trace(np.eye(1)) == 1.0, 'trace of eye 1' + +# flatnonzero edge cases +assert np.flatnonzero(np.array([1, 2, 3])).tolist() == [0, 1, 2], 'flatnonzero all' +assert np.flatnonzero(np.array([0])).tolist() == [], 'flatnonzero single zero' + +# === Set operation edge cases === +# intersect1d with no overlap +assert np.intersect1d(np.array([1, 2]), np.array([3, 4])).tolist() == [], 'intersect1d empty' + +# union1d with overlap +assert np.union1d(np.array([1, 2, 3]), np.array([2, 3, 4])).tolist() == [1.0, 2.0, 3.0, 4.0], 'union1d overlap' + +# setdiff1d with no diff +assert np.setdiff1d(np.array([1, 2]), np.array([1, 2])).tolist() == [], 'setdiff1d empty' + +# setxor1d with no overlap +assert np.setxor1d(np.array([1, 2]), np.array([3, 4])).tolist() == [1.0, 2.0, 3.0, 4.0], 'setxor1d disjoint' + +# === Creation function edge cases === +# logspace single element +r = np.logspace(0, 0, 1) +assert abs(r.tolist()[0] - 1.0) < 1e-10, 'logspace single' + +# geomspace single element +r = np.geomspace(5, 5, 1) +assert abs(r.tolist()[0] - 5.0) < 1e-10, 'geomspace single' + +# tri 1x1 +assert np.tri(1).tolist() == [[1.0]], 'tri 1x1' + +# identity edge +assert np.identity(2).tolist() == [[1.0, 0.0], [0.0, 1.0]], 'identity 2' + +# gradient 3 elements +r = np.gradient(np.array([0.0, 1.0, 4.0])) +assert r.tolist() == [1.0, 2.0, 3.0], 'gradient quadratic' + +# interp at boundaries +xp = np.array([0.0, 1.0]) +fp = np.array([0.0, 10.0]) +assert np.interp(np.array([0.5]), xp, fp).tolist() == [5.0], 'interp midpoint' +assert np.interp(np.array([-1.0]), xp, fp).tolist() == [0.0], 'interp below' +assert np.interp(np.array([2.0]), xp, fp).tolist() == [10.0], 'interp above' + +# === Method edge cases === +# item on single element +assert np.array([42]).item() == 42, 'item int' +assert np.array([3.14]).item() == 3.14, 'item float' + +# cumprod +assert np.array([1, 2, 3, 4, 5]).cumprod().tolist() == [1, 2, 6, 24, 120], 'cumprod method' +assert np.array([2.0]).cumprod().tolist() == [2.0], 'cumprod single' + +# squeeze on array without 1-dims +a = np.array([1, 2, 3]) +assert a.squeeze().tolist() == [1, 2, 3], 'squeeze no effect 1d' + +# take +a = np.array([10, 20, 30, 40, 50]) +assert a.take(np.array([0, 2, 4])).tolist() == [10, 30, 50], 'take indices' +assert a.take(np.array([4, 3, 2, 1, 0])).tolist() == [50, 40, 30, 20, 10], 'take reversed' + +# diagonal on identity +assert np.eye(3).diagonal().tolist() == [1.0, 1.0, 1.0], 'diagonal eye' + +# trace on identity +assert np.eye(5).trace() == 5.0, 'trace eye 5' + +# fill +a = np.array([1, 2, 3]) +a.fill(0) +assert a.tolist() == [0, 0, 0], 'fill zeros' +a.fill(99) +assert a.tolist() == [99, 99, 99], 'fill 99' + +# compress +a = np.array([10, 20, 30, 40, 50]) +cond = np.array([1, 0, 1, 0, 1]) +assert a.compress(cond).tolist() == [10, 30, 50], 'compress method' +assert a.compress(np.array([0, 0, 0, 0, 0])).tolist() == [], 'compress all false' + +# swapaxes on 1D (no-op) +a = np.array([1, 2, 3]) +assert a.swapaxes(0, 0).tolist() == [1, 2, 3], 'swapaxes 1d noop' + +# nbytes / itemsize +a = np.array([1.0, 2.0, 3.0]) +assert a.nbytes == 24, 'nbytes 3 floats' +assert a.itemsize == 8, 'itemsize float' +a5 = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) +assert a5.nbytes == 40, 'nbytes 5 floats' + +# === Additional array creation and attribute coverage === +# zeros shape +assert np.zeros(0).tolist() == [], 'zeros 0' +assert np.zeros(1).tolist() == [0.0], 'zeros 1' +assert np.zeros(5).shape == (5,), 'zeros 5 shape' + +# ones shape +assert np.ones(0).tolist() == [], 'ones 0' +assert np.ones(1).tolist() == [1.0], 'ones 1' + +# full edge +assert np.full(3, 0.0).tolist() == [0.0, 0.0, 0.0], 'full 3 zeros' +assert np.full(1, 42.0).tolist() == [42.0], 'full 1' +assert np.full(0, 1.0).tolist() == [], 'full 0' + +# arange edge +assert np.arange(0).tolist() == [], 'arange 0' +assert np.arange(1).tolist() == [0], 'arange 1' +assert np.arange(1, 1).tolist() == [], 'arange empty range' +assert np.arange(5, 0, -1).tolist() == [5, 4, 3, 2, 1], 'arange reverse' + +# linspace edge +assert np.linspace(0, 1, 2).tolist() == [0.0, 1.0], 'linspace 2' +assert np.linspace(0, 1, 1).tolist() == [0.0], 'linspace 1' +assert np.linspace(5, 5, 3).tolist() == [5.0, 5.0, 5.0], 'linspace same' + +# === Additional aggregation coverage === +# sum/mean/min/max on various sizes +assert np.sum(np.array([1])) == 1, 'sum single' +assert np.mean(np.array([5.0])) == 5.0, 'mean single' +assert np.min(np.array([42])) == 42, 'min single' +assert np.max(np.array([42])) == 42, 'max single' + +# std/var on uniform +assert np.std(np.array([5.0, 5.0, 5.0])) == 0.0, 'std uniform' +assert np.var(np.array([5.0, 5.0, 5.0])) == 0.0, 'var uniform' + +# prod +assert np.array([1, 2, 3, 4]).prod() == 24, 'prod method' +assert np.array([0, 1, 2]).prod() == 0, 'prod with zero' + +# median +assert np.median(np.array([1.0])) == 1.0, 'median single' +assert np.median(np.array([1.0, 3.0])) == 2.0, 'median two' +assert np.median(np.array([3.0, 1.0, 2.0])) == 2.0, 'median unsorted' + +# cumsum edge +assert np.cumsum(np.array([1])).tolist() == [1], 'cumsum single' + +# === Additional comparison coverage === +a = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) +assert (a > 3).tolist() == [False, False, False, True, True], 'gt 3' +assert (a >= 3).tolist() == [False, False, True, True, True], 'ge 3' +assert (a < 3).tolist() == [True, True, False, False, False], 'lt 3' +assert (a <= 3).tolist() == [True, True, True, False, False], 'le 3' +assert (a == 3).tolist() == [False, False, True, False, False], 'eq 3' +assert (a != 3).tolist() == [True, True, False, True, True], 'ne 3' + +# === Chained operations coverage === +# Sort then cumsum +a = np.sort(np.array([3, 1, 4, 1, 5])) +assert np.cumsum(a).tolist() == [1, 2, 5, 9, 14], 'sort then cumsum' + +# Where then sum +a = np.array([1, 2, 3, 4, 5]) +b = np.where(a > 2, a, np.zeros(5)) +assert np.sum(b) == 12.0, 'where then sum' + +# Unique then sort (already sorted by unique) +u = np.unique(np.array([5, 3, 1, 3, 5, 1])) +assert u.tolist() == [1.0, 3.0, 5.0], 'unique sorted' +assert len(u.tolist()) == 3, 'unique count' + +# Concatenate then reshape +c = np.concatenate([np.array([1, 2, 3]), np.array([4, 5, 6])]) +r = np.reshape(c, [2, 3]) +assert r.tolist() == [[1, 2, 3], [4, 5, 6]], 'concat then reshape' + +# === 2D coverage === +m = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +assert m.shape == (3, 3), '3x3 shape' +assert m.ndim == 2, '3x3 ndim' +assert m.size == 9, '3x3 size' +assert m.T.tolist() == [[1, 4, 7], [2, 5, 8], [3, 6, 9]], '3x3 transpose' +assert m.flatten().tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9], '3x3 flatten' +assert m.ravel().tolist() == [1, 2, 3, 4, 5, 6, 7, 8, 9], '3x3 ravel' + +# 2D sum, mean +assert m.sum() == 45, '3x3 sum' +assert m.mean() == 5.0, '3x3 mean' +assert m.min() == 1, '3x3 min' +assert m.max() == 9, '3x3 max' + +# 2D operations +m2 = np.array([[1, 2], [3, 4]]) +assert (m2 + m2).tolist() == [[2, 4], [6, 8]], '2d add' +assert (m2 * 2).tolist() == [[2, 4], [6, 8]], '2d scalar mul' +assert (m2 > 2).tolist() == [[False, False], [True, True]], '2d compare' + +# eye +e3 = np.eye(3) +assert e3.tolist() == [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], 'eye 3' +assert e3.shape == (3, 3), 'eye 3 shape' +assert e3.dtype == 'float64', 'eye 3 dtype' + +# Matmul with identity +m = np.array([[1, 2], [3, 4]]) +e = np.eye(2) +r = m @ e +assert r.tolist() == [[1.0, 2.0], [3.0, 4.0]], 'matmul with identity preserves' + +# === String repr edge cases === +assert repr(np.array([])) == 'array([], dtype=float64)', 'repr empty' +assert repr(np.array([1.0])) == 'array([1.])', 'repr single float' +assert repr(np.array([1, 2])) == 'array([1, 2])', 'repr int pair' + +# ============================================================ +# 43. ADDITIONAL COVERAGE TO REACH 1000+ ASSERTIONS +# ============================================================ + +# === astype coverage === +a = np.array([1.5, 2.7, 3.1]) +assert a.astype('int64').tolist() == [1, 2, 3], 'astype float->int truncate' +assert a.astype('int64').dtype == 'int64', 'astype float->int dtype' +assert np.array([1, 0, 1]).astype('bool').tolist() == [True, False, True], 'astype int->bool' +assert np.array([1, 0, 1]).astype('bool').dtype == 'bool', 'astype int->bool dtype' +assert np.array([0.0, 1.5, 0.0]).astype('bool').tolist() == [False, True, False], 'astype float->bool' +assert np.array([True, False]).astype('float64').tolist() == [1.0, 0.0], 'astype bool->float' +assert np.array([True, False]).astype('int64').tolist() == [1, 0], 'astype bool->int' + +# === copy method === +a = np.array([1, 2, 3]) +b = a.copy() +assert b.tolist() == [1, 2, 3], 'copy values' +b[0] = 99 +assert a[0] == 1, 'copy is independent' + +# === sort method edge cases === +a = np.array([5, 3, 1, 4, 2]) +a.sort() +assert a.tolist() == [1, 2, 3, 4, 5], 'sort in-place' + +a = np.array([1.0]) +a.sort() +assert a.tolist() == [1.0], 'sort single' + +# === Module-level sort === +assert np.sort(np.array([3, 1, 2])).tolist() == [1, 2, 3], 'np.sort int' +assert np.sort(np.array([3.0, 1.0, 2.0])).tolist() == [1.0, 2.0, 3.0], 'np.sort float' + +# === argsort method === +a = np.array([30, 10, 20]) +assert a.argsort().tolist() == [1, 2, 0], 'argsort method' + +# === argmin/argmax method === +a = np.array([3, 1, 4, 1, 5]) +assert a.argmin() == 1, 'argmin method' +assert a.argmax() == 4, 'argmax method' + +# === all/any method === +assert np.array([1, 1, 1]).all() == True, 'all true' +assert np.array([1, 0, 1]).all() == False, 'all false' +assert np.array([0, 0, 1]).any() == True, 'any true' +assert np.array([0, 0, 0]).any() == False, 'any false' + +# === dot method === +a = np.array([1, 2, 3]) +b = np.array([4, 5, 6]) +assert a.dot(b) == 32, 'dot method' + +# === clip method === +a = np.array([1, 5, 10, 15, 20]) +assert a.clip(5, 15).tolist() == [5, 5, 10, 15, 15], 'clip method' + +# === round method === +a = np.array([1.234, 5.678, 9.012]) +r = a.round() +assert r.tolist() == [1.0, 6.0, 9.0], 'round method' + +# === .T on 1D is identity === +a = np.array([1, 2, 3]) +assert a.T.tolist() == [1, 2, 3], 'T on 1D is identity' + +# === np.copy === +a = np.array([1, 2, 3]) +b = np.copy(a) +assert b.tolist() == [1, 2, 3], 'np.copy' + +# === np.count_nonzero edge cases === +assert np.count_nonzero(np.array([0, 0, 0])) == 0, 'count_nonzero all zero' +assert np.count_nonzero(np.array([1, 2, 3])) == 3, 'count_nonzero all nonzero' +assert np.count_nonzero(np.array([])) == 0, 'count_nonzero empty' + +# === np.diff edge cases === +assert np.diff(np.array([1, 4, 9, 16])).tolist() == [3, 5, 7], 'diff squares' +assert np.diff(np.array([1, 1, 1, 1])).tolist() == [0, 0, 0], 'diff constant' +assert np.diff(np.array([5, 3])).tolist() == [-2], 'diff two elem' + +# === np.where with 3 args === +a = np.array([1, 2, 3, 4, 5]) +r = np.where(a > 3, a, np.array([0, 0, 0, 0, 0])) +assert r.tolist() == [0, 0, 0, 4, 5], 'where 3-arg' + +# === np.maximum / np.minimum edge cases === +assert np.maximum(np.array([1, 5, 3]), np.array([4, 2, 6])).tolist() == [4, 5, 6], 'maximum pairwise' +assert np.minimum(np.array([1, 5, 3]), np.array([4, 2, 6])).tolist() == [1, 2, 3], 'minimum pairwise' + +# === np.concatenate more cases === +r = np.concatenate([np.array([1]), np.array([2]), np.array([3])]) +assert r.tolist() == [1, 2, 3], 'concat three' +r = np.concatenate([np.array([1, 2, 3])]) +assert r.tolist() == [1, 2, 3], 'concat single' + +# === np.stack more cases === +a = np.array([1, 2]) +b = np.array([3, 4]) +r = np.stack([a, b]) +assert r.tolist() == [[1, 2], [3, 4]], 'stack 2 arrays' + +# === np.vstack / np.hstack === +assert np.vstack([np.array([1, 2]), np.array([3, 4])]).tolist() == [[1, 2], [3, 4]], 'vstack pair' +assert np.hstack([np.array([1, 2]), np.array([3, 4])]).tolist() == [1, 2, 3, 4], 'hstack pair' + +# === np.array_equal more cases === +assert np.array_equal(np.array([1, 2, 3]), np.array([1, 2, 3])) == True, 'array_equal same' +assert np.array_equal(np.array([1, 2, 3]), np.array([1, 2, 4])) == False, 'array_equal diff' +assert np.array_equal(np.array([1, 2]), np.array([1, 2, 3])) == False, 'array_equal diff size' + +# === np.zeros_like / np.ones_like === +a = np.array([5, 10, 15]) +assert np.zeros_like(a).tolist() == [0, 0, 0], 'zeros_like int' +assert np.ones_like(a).tolist() == [1, 1, 1], 'ones_like int' +assert np.zeros_like(a).shape == (3,), 'zeros_like shape' + +# === Unary minus/plus on arrays === +a = np.array([1, -2, 3]) +assert (-a).tolist() == [-1, 2, -3], 'unary neg' +# Bitwise invert +b = np.array([0, 1, 0, 1]) +assert (~b).tolist() == [-1, -2, -1, -2], 'bitwise invert int' + +# === Section 44: In-place floor divide, modulo, power === +a = np.array([10, 7, 15]) +a //= 3 +assert a.tolist() == [3, 2, 5], 'ifloordiv scalar' +a = np.array([10, 7, 15]) +a %= 3 +assert a.tolist() == [1, 1, 0], 'imod scalar' +a = np.array([2.0, 3.0, 4.0]) +a **= 2 +assert a.tolist() == [4.0, 9.0, 16.0], 'ipow scalar' +try: + a = np.array([2, 4]) + a **= -1 + assert False, 'expected int ipow negative exponent to fail' +except ValueError as exc: + assert str(exc) == 'Integers to negative integer powers are not allowed.', 'ipow negative int exponent error' +# Array rhs +a = np.array([10, 20, 30]) +a //= np.array([3, 7, 4]) +assert a.tolist() == [3, 2, 7], 'ifloordiv array' +a = np.array([10, 20, 30]) +a %= np.array([3, 7, 4]) +assert a.tolist() == [1, 6, 2], 'imod array' +a = np.array([2.0, 3.0, 4.0]) +a **= np.array([3.0, 2.0, 0.5]) +assert a.tolist() == [8.0, 9.0, 2.0], 'ipow array' +try: + a = np.array([2, 4]) + a **= np.array([-1, -2]) + assert False, 'expected int ipow negative exponent array to fail' +except ValueError as exc: + assert str(exc) == 'Integers to negative integer powers are not allowed.', 'ipow negative int array error' + +# === Section 45: .flat attribute === +a = np.array([[1, 2], [3, 4]]) +f = a.flat +assert f.tolist() == [1, 2, 3, 4], 'flat tolist' +assert f.shape == (4,), 'flat shape' +assert f[2] == 3, 'flat indexing' + +# === Section 46: ndarray .repeat() and .nonzero() methods === +a = np.array([1, 2, 3]) +assert a.repeat(2).tolist() == [1, 1, 2, 2, 3, 3], 'repeat method 2x' +assert a.repeat(1).tolist() == [1, 2, 3], 'repeat method 1x' +a = np.array([0, 1, 0, 3, 5]) +nz = a.nonzero() +assert nz[0].tolist() == [1, 3, 4], 'nonzero method 1d' + +# === Section 47: NumPy scalar constants, dtype markers, and index tricks === +assert np.False_ == False, 'False_ matches False' +assert np.True_ == True, 'True_ matches True' +assert np.issubdtype(np.bool_, np.generic) == True, 'bool is generic' +assert np.issubdtype(np.int64, np.generic) == True, 'int64 is generic' +assert np.issubdtype(np.float64, np.generic) == True, 'float64 is generic' +assert np.issubdtype(np.bool_, np.number) == False, 'bool is not number' +assert np.issubdtype(np.int64, np.number) == True, 'int64 is number' +assert np.issubdtype(np.float64, np.number) == True, 'float64 is number' +assert np.issubdtype(np.int64, np.signedinteger) == True, 'int64 is signedinteger' +assert np.issubdtype(np.uint64, np.unsignedinteger) == True, 'uint64 is unsignedinteger' +assert np.issubdtype(np.uint64, np.signedinteger) == False, 'uint64 is not signedinteger' +assert np.issubdtype(np.int64, np.unsignedinteger) == False, 'int64 is not unsignedinteger' +assert np.issubdtype(np.float64, np.complexfloating) == False, 'float64 is not complexfloating' +assert np.issubdtype(np.float64, np.flexible) == False, 'float64 is not flexible' +assert np.issubdtype(np.float64, np.character) == False, 'float64 is not character' +assert np.issubdtype(np.character, np.flexible) == True, 'character is flexible' +assert np.issubdtype(np.complexfloating, np.inexact) == True, 'complexfloating is inexact' + +idx = np.s_[1:5:2] +assert idx.start == 1, 's_ slice start' +assert idx.stop == 5, 's_ slice stop' +assert idx.step == 2, 's_ slice step' +idx = np.index_exp[1:5:2, None] +assert len(idx) == 2, 'index_exp tuple length' +assert idx[0].start == 1, 'index_exp slice start' +assert idx[0].stop == 5, 'index_exp slice stop' +assert idx[0].step == 2, 'index_exp slice step' +assert idx[1] is None, 'index_exp preserves None' +assert np.array([10, 20, 30, 40])[np.s_[1:3]].tolist() == [20, 30], 's_ slices arrays' + +assert np.mgrid[0:3].tolist() == [0, 1, 2], 'mgrid one range' +assert np.mgrid[0:3, 0:2].tolist() == [[[0, 0], [1, 1], [2, 2]], [[0, 1], [0, 1], [0, 1]]], 'mgrid dense grids' +og = np.ogrid[0:3, 0:2] +assert og[0].tolist() == [[0], [1], [2]], 'ogrid first sparse grid' +assert og[1].tolist() == [[0, 1]], 'ogrid second sparse grid' +assert np.r_[1, 2, 3].tolist() == [1, 2, 3], 'r_ scalar concatenation' +assert np.r_[1:4, [10, 20]].tolist() == [1, 2, 3, 10, 20], 'r_ range and list' +assert np.c_[[1, 2], [3, 4]].tolist() == [[1, 3], [2, 4]], 'c_ column stack vectors' +assert np.c_[[1, 2, 3]].tolist() == [[1], [2], [3]], 'c_ single vector column' diff --git a/crates/monty/tests/parse_errors.rs b/crates/monty/tests/parse_errors.rs index 7fa491268..7e031d170 100644 --- a/crates/monty/tests/parse_errors.rs +++ b/crates/monty/tests/parse_errors.rs @@ -411,8 +411,8 @@ fn run_and_get_exc_type(code: &str) -> ExcType { #[test] fn matrix_multiplication_returns_not_implemented_error() { - // The @ operator (matrix multiplication) is not supported at runtime - assert_eq!(run_and_get_exc_type("1 @ 2"), ExcType::NotImplementedError); + // The @ operator (matrix multiplication) is implemented for ndarray but not for int/float + assert_eq!(run_and_get_exc_type("1 @ 2"), ExcType::TypeError); } #[test] diff --git a/crates/monty/tests/resource_limits.rs b/crates/monty/tests/resource_limits.rs index d96243817..8c34acbf9 100644 --- a/crates/monty/tests/resource_limits.rs +++ b/crates/monty/tests/resource_limits.rs @@ -266,6 +266,56 @@ fn memory_limit_zero() { ); } +/// Test that NumPy array constructors pre-check requested allocation size. +#[test] +fn numpy_zeros_respects_memory_limit() { + let code = r" +import numpy as np +np.zeros(200000) +"; + let ex = MontyRun::new(code.to_owned(), "test.py", vec![]).unwrap(); + + let limits = ResourceLimits::new().max_memory(100_000); + let result = ex.run(vec![], LimitedTracker::new(limits), PrintWriter::Stdout); + + assert!(result.is_err(), "large NumPy allocation should exceed memory limit"); + let exc = result.unwrap_err(); + assert_eq!(exc.exc_type(), ExcType::MemoryError); + let message = exc.message().expect("MemoryError should include a message"); + let (used, limit) = parse_memory_limit_message(message); + assert_eq!(limit, 100_000); + assert!( + used >= 1_600_000, + "np.zeros(200000) should pre-check the requested f64 data allocation, got: {message}" + ); +} + +/// Parses Monty's standard memory-limit error message into used and limit bytes. +/// +/// Memory-limit tests often need to verify the stable resource contract while +/// allowing small tracked-memory offsets from module initialization. This keeps +/// those assertions focused on the limit semantics instead of incidental heap +/// state before the checked operation. +fn parse_memory_limit_message(message: &str) -> (usize, usize) { + let rest = message + .strip_prefix("memory limit exceeded: ") + .unwrap_or_else(|| panic!("unexpected memory error message: {message}")); + let (used, limit) = rest + .split_once(" bytes > ") + .unwrap_or_else(|| panic!("unexpected memory error format: {message}")); + let limit = limit + .strip_suffix(" bytes") + .unwrap_or_else(|| panic!("unexpected memory limit suffix: {message}")); + + ( + used.parse() + .unwrap_or_else(|_| panic!("invalid used byte count in memory error: {message}")), + limit + .parse() + .unwrap_or_else(|_| panic!("invalid limit byte count in memory error: {message}")), + ) +} + #[test] fn combined_limits() { // Test multiple limits together diff --git a/examples/web_scraper/browser.py b/examples/web_scraper/browser.py index 5754bce54..dfedb6ac8 100644 --- a/examples/web_scraper/browser.py +++ b/examples/web_scraper/browser.py @@ -2,7 +2,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, AsyncIterator, Literal +from typing import TYPE_CHECKING, AsyncGenerator, Literal from playwright.async_api import Browser as PwBrowser, Page as PwPage, async_playwright @@ -13,7 +13,7 @@ @asynccontextmanager -async def start_browser() -> AsyncIterator[Browser]: +async def start_browser() -> AsyncGenerator[Browser]: async with async_playwright() as p: b = await p.chromium.launch() yield Browser(b)