diff --git a/pedalboard/ArrayUtils.h b/pedalboard/ArrayUtils.h new file mode 100644 index 000000000..cfa0cb1d9 --- /dev/null +++ b/pedalboard/ArrayUtils.h @@ -0,0 +1,121 @@ +/* + * pedalboard + * Copyright 2025 Spotify AB + * + * Licensed under the GNU Public License, Version 3.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.gnu.org/licenses/gpl-3.0.html + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace py = pybind11; + +namespace Pedalboard { + +/** + * Utility function to convert various array-like objects to py::array. + * Supports: + * - NumPy arrays (pass-through) + * - PyTorch tensors (via .numpy() method) + * - JAX arrays (via .__array__() method) + * - TensorFlow tensors (via .numpy() method) + * - CuPy arrays (via .get() method for CPU copy) + * - Any object with __array__ interface + */ +inline py::array ensureArrayLike(py::object input) { + // If we were already given a numpy array, just return it + if (py::isinstance(input)) { + return py::reinterpret_borrow(input); + } + + // Check if it's a PyTorch tensor (has a .numpy() method) + if (py::hasattr(input, "numpy") && py::hasattr(input, "dtype") && + py::hasattr(input, "device")) { + // Check if tensor is on CPU + py::object device = input.attr("device"); + std::string device_type = py::str(device.attr("type")).cast(); + + if (device_type != "cpu") { + // Move tensor to CPU first + input = input.attr("cpu")(); + } + + // Call .numpy() to get the numpy array + // This shares memory with the tensor when possible + return input.attr("numpy")().cast(); + } + + // Check if it's a TensorFlow tensor (has .numpy() method but no .device) + if (py::hasattr(input, "numpy") && !py::hasattr(input, "device")) { + try { + return input.attr("numpy")().cast(); + } catch (...) { + // Fall through to next option + } + } + + // Check if it's a CuPy array (has .get() method) + if (py::hasattr(input, "get") && py::hasattr(input, "dtype") && + py::hasattr(input, "ndim")) { + try { + // .get() copies from GPU to CPU and returns numpy array + return input.attr("get")().cast(); + } catch (...) { + // Fall through to next option + } + } + + // Check if it implements the array protocol (__array__) + if (py::hasattr(input, "__array__")) { + try { + return input.attr("__array__")().cast(); + } catch (...) { + // Fall through to error + } + } + + // Try to convert directly to array as a last resort + // py::array::ensure() will attempt to convert the object to an array + // or return an invalid array handle if conversion fails + py::array result = py::array::ensure(input); + + if (!result) { + throw py::type_error( + "Expected an array-like object (numpy array, torch tensor, etc.), " + "but received: " + + py::repr(input).cast()); + } + + return result; +} + +/** + * Template version that ensures the array has a specific dtype + */ +template +inline py::array_t ensureArrayLikeWithType(py::object input) { + py::array arr = ensureArrayLike(input); + + // If the array already has the correct type, return it + if (py::isinstance>(arr)) { + return py::reinterpret_borrow>(arr); + } + + // Otherwise, cast to the desired type + // Note: this may create a copy + return arr.cast>(); +} + +} // namespace Pedalboard diff --git a/pedalboard/ExternalPlugin.h b/pedalboard/ExternalPlugin.h index 98584e03e..55339ad9a 100644 --- a/pedalboard/ExternalPlugin.h +++ b/pedalboard/ExternalPlugin.h @@ -1578,10 +1578,9 @@ inline void init_external_plugins(py::module &m) { py::arg("reset") = true) .def( "process", - [](std::shared_ptr self, const py::array inputArray, + [](std::shared_ptr self, py::object input, double sampleRate, unsigned int bufferSize, bool reset) { - return process(inputArray, sampleRate, {self}, bufferSize, - reset); + return process(input, sampleRate, {self}, bufferSize, reset); }, EXTERNAL_PLUGIN_PROCESS_DOCSTRING, py::arg("input_array"), py::arg("sample_rate"), @@ -1589,10 +1588,9 @@ inline void init_external_plugins(py::module &m) { py::arg("reset") = true) .def( "__call__", - [](std::shared_ptr self, const py::array inputArray, + [](std::shared_ptr self, py::object input, double sampleRate, unsigned int bufferSize, bool reset) { - return process(inputArray, sampleRate, {self}, bufferSize, - reset); + return process(input, sampleRate, {self}, bufferSize, reset); }, "Run an audio or MIDI buffer through this plugin, returning " "audio. Alias for :py:meth:`process`.", @@ -1809,18 +1807,18 @@ example: a Windows VST3 plugin bundle will not load on Linux or macOS.) SHOW_EDITOR_DOCSTRING, py::arg("close_event") = py::none()) .def( "process", - [](std::shared_ptr self, const py::array inputArray, - double sampleRate, unsigned int bufferSize, bool reset) { - return process(inputArray, sampleRate, {self}, bufferSize, reset); + [](std::shared_ptr self, py::object input, double sampleRate, + unsigned int bufferSize, bool reset) { + return process(input, sampleRate, {self}, bufferSize, reset); }, EXTERNAL_PLUGIN_PROCESS_DOCSTRING, py::arg("input_array"), py::arg("sample_rate"), py::arg("buffer_size") = DEFAULT_BUFFER_SIZE, py::arg("reset") = true) .def( "__call__", - [](std::shared_ptr self, const py::array inputArray, - double sampleRate, unsigned int bufferSize, bool reset) { - return process(inputArray, sampleRate, {self}, bufferSize, reset); + [](std::shared_ptr self, py::object input, double sampleRate, + unsigned int bufferSize, bool reset) { + return process(input, sampleRate, {self}, bufferSize, reset); }, "Run an audio or MIDI buffer through this plugin, returning " "audio. Alias for :py:meth:`process`.", @@ -2035,18 +2033,18 @@ see :class:`pedalboard.VST3Plugin`.) SHOW_EDITOR_DOCSTRING, py::arg("close_event") = py::none()) .def( "process", - [](std::shared_ptr self, const py::array inputArray, - double sampleRate, unsigned int bufferSize, bool reset) { - return process(inputArray, sampleRate, {self}, bufferSize, reset); + [](std::shared_ptr self, py::object input, double sampleRate, + unsigned int bufferSize, bool reset) { + return process(input, sampleRate, {self}, bufferSize, reset); }, EXTERNAL_PLUGIN_PROCESS_DOCSTRING, py::arg("input_array"), py::arg("sample_rate"), py::arg("buffer_size") = DEFAULT_BUFFER_SIZE, py::arg("reset") = true) .def( "__call__", - [](std::shared_ptr self, const py::array inputArray, - double sampleRate, unsigned int bufferSize, bool reset) { - return process(inputArray, sampleRate, {self}, bufferSize, reset); + [](std::shared_ptr self, py::object input, double sampleRate, + unsigned int bufferSize, bool reset) { + return process(input, sampleRate, {self}, bufferSize, reset); }, "Run an audio or MIDI buffer through this plugin, returning " "audio. Alias for :py:meth:`process`.", diff --git a/pedalboard/io/AudioFileInit.h b/pedalboard/io/AudioFileInit.h index f447277e2..c83d70386 100644 --- a/pedalboard/io/AudioFileInit.h +++ b/pedalboard/io/AudioFileInit.h @@ -23,6 +23,7 @@ #include #include +#include "../ArrayUtils.h" #include "../JuceHeader.h" #include "AudioFile.h" @@ -278,16 +279,19 @@ inline void init_audio_file( py::arg("format") = py::none()) .def_static( "encode", - [](const py::array samples, double sampleRate, std::string format, + [](py::object samples, double sampleRate, std::string format, int numChannels, int bitDepth, std::optional> quality) { + // Convert the input to a numpy array (supports torch tensors, etc.) + py::array samplesArray = ensureArrayLike(samples); + juce::MemoryBlock outputBlock; auto audioFile = std::make_unique( format, std::make_unique(outputBlock, false), sampleRate, numChannels, bitDepth, quality); - audioFile->write(samples); + audioFile->write(samplesArray); audioFile->close(); return py::bytes((const char *)outputBlock.getData(), @@ -299,6 +303,10 @@ inline void init_audio_file( R"( Encode an audio buffer to a Python :class:`bytes` object. +The input audio buffer can be any array-like object, including NumPy arrays, +PyTorch tensors, TensorFlow tensors, JAX arrays, or any other object that +supports the buffer protocol or has a __array__ method. + This function will encode an entire audio buffer at once and return a :class:`bytes` object representing the bytes of the resulting audio file. diff --git a/pedalboard/io/WriteableAudioFile.h b/pedalboard/io/WriteableAudioFile.h index 8a874bad3..5b8e0c36a 100644 --- a/pedalboard/io/WriteableAudioFile.h +++ b/pedalboard/io/WriteableAudioFile.h @@ -23,6 +23,7 @@ #include #include +#include "../ArrayUtils.h" #include "../BufferUtils.h" #include "../JuceHeader.h" #include "AudioFile.h" @@ -445,11 +446,15 @@ class WriteableAudioFile /** * A generic type-dispatcher for all writes. + * Accepts various array-like objects including torch tensors. * pybind11 supports dispatch here, but both pybind11-stubgen * and Sphinx currently (2022-07-16) struggle with how to render * docstrings of overloaded functions, so we don't overload. */ - void write(py::array inputArray) { + void write(py::object input) { + // Convert the input to a numpy array (supports torch tensors, etc.) + py::array inputArray = ensureArrayLike(input); + switch (inputArray.dtype().char_()) { case 'f': return write(py::array_t(inputArray.release(), false)); @@ -1017,13 +1022,15 @@ inline void init_writeable_audio_file( py::arg("format") = py::none()) .def( "write", - [](WriteableAudioFile &file, py::array samples) { + [](WriteableAudioFile &file, py::object samples) { file.write(samples); }, - py::arg("samples").noconvert(), + py::arg("samples"), "Encode an array of audio data and write " "it to this file. The number of channels in the array must match the " - "number of channels used to open the file. The array may contain " + "number of channels used to open the file. The audio data may be " + "provided as a NumPy array, PyTorch tensor, TensorFlow tensor, " + "JAX array, or any other array-like object. The array may contain " "audio in any shape. If the file's bit depth or format does not " "match the provided data type, the audio will be automatically " "converted.\n\n" diff --git a/pedalboard/process.h b/pedalboard/process.h index 22e0e2653..81c5c5020 100644 --- a/pedalboard/process.h +++ b/pedalboard/process.h @@ -21,6 +21,7 @@ #include #include +#include "ArrayUtils.h" #include "BufferUtils.h" #include "Plugin.h" #include "PluginContainer.h" @@ -275,9 +276,12 @@ processFloat32(const py::array_t inputArray, inputArray.request().ndim); } -py::array_t process(py::array inputArray, double sampleRate, +py::array_t process(py::object input, double sampleRate, const std::vector> plugins, unsigned int bufferSize, bool reset) { + // Convert the input to a numpy array (supports torch tensors, etc.) + py::array inputArray = ensureArrayLike(input); + py::array_t float32InputArray; switch (inputArray.dtype().char_()) { case 'f': diff --git a/pedalboard/python_bindings.cpp b/pedalboard/python_bindings.cpp index fd83cf059..1462a2aa0 100644 --- a/pedalboard/python_bindings.cpp +++ b/pedalboard/python_bindings.cpp @@ -87,16 +87,20 @@ PYBIND11_MODULE(pedalboard_native, m, py::mod_gil_not_used()) { m.def( "process", - [](const py::array inputArray, double sampleRate, + [](py::object input, double sampleRate, const std::vector> plugins, unsigned int bufferSize, bool reset) { - return process(inputArray, sampleRate, plugins, bufferSize, reset); + return process(input, sampleRate, plugins, bufferSize, reset); }, R"( Run a 32-bit or 64-bit floating point audio buffer through a list of Pedalboard plugins. If the provided buffer uses a 64-bit datatype, it will be converted to 32-bit for processing. +The input audio buffer can be any array-like object, including NumPy arrays, +PyTorch tensors, TensorFlow tensors, JAX arrays, or any other object that +supports the buffer protocol or has a __array__ method. + The provided ``buffer_size`` argument will be used to control the size of each chunk of audio provided into the plugins. Higher buffer sizes may speed up processing at the expense of memory usage. @@ -134,15 +138,19 @@ or buffer, set ``reset`` to ``False``. "parameters will remain unchanged. ") .def( "process", - [](std::shared_ptr self, const py::array inputArray, - double sampleRate, unsigned int bufferSize, bool reset) { - return process(inputArray, sampleRate, {self}, bufferSize, reset); + [](std::shared_ptr self, py::object input, double sampleRate, + unsigned int bufferSize, bool reset) { + return process(input, sampleRate, {self}, bufferSize, reset); }, R"( Run a 32-bit or 64-bit floating point audio buffer through this plugin. (If calling this multiple times with multiple plugins, consider creating a :class:`pedalboard.Pedalboard` object instead.) +The input audio buffer can be any array-like object, including NumPy arrays, +PyTorch tensors, TensorFlow tensors, JAX arrays, or any other object that +supports the buffer protocol or has a __array__ method. + The returned array may contain up to (but not more than) the same number of samples as were provided. If fewer samples were returned than expected, the plugin has likely buffered audio inside itself. To receive the remaining @@ -176,9 +184,9 @@ If the number of samples and the number of channels are the same, each py::arg("buffer_size") = DEFAULT_BUFFER_SIZE, py::arg("reset") = true) .def( "__call__", - [](std::shared_ptr self, const py::array inputArray, - double sampleRate, unsigned int bufferSize, bool reset) { - return process(inputArray, sampleRate, {self}, bufferSize, reset); + [](std::shared_ptr self, py::object input, double sampleRate, + unsigned int bufferSize, bool reset) { + return process(input, sampleRate, {self}, bufferSize, reset); }, "Run an audio buffer through this plugin. Alias for " ":py:meth:`process`.", diff --git a/pedalboard_native/__init__.pyi b/pedalboard_native/__init__.pyi index 79523f3be..f0eedb786 100644 --- a/pedalboard_native/__init__.pyi +++ b/pedalboard_native/__init__.pyi @@ -4,7 +4,7 @@ For audio I/O classes (i.e.: reading and writing audio files), see ``pedalboard. from __future__ import annotations import pedalboard_native -import enum + import typing original_overload = typing.overload @@ -26,11 +26,23 @@ def patch_overload(func): typing.overload = patch_overload +from typing import Optional +from typing_extensions import Literal from enum import Enum import threading -from numpy import ndarray, float32 -from numpy.typing import NDArray +# Array-like type that includes numpy arrays, torch tensors, etc. +# At runtime, we accept any array-like object (numpy arrays, torch tensors, +# tensorflow tensors, jax arrays, or anything with __array__ method). +# For type checking, we use numpy.typing.ArrayLike which covers most cases. +if typing.TYPE_CHECKING: + import numpy + import numpy.typing + + ArrayLike = numpy.typing.ArrayLike +else: + ArrayLike = typing.Any +import numpy _Shape = typing.Tuple[int, ...] @@ -44,7 +56,6 @@ __all__ = [ "Delay", "Distortion", "ExternalPlugin", - "ExternalPluginReloadType", "GSMFullRateCompressor", "Gain", "HighShelfFilter", @@ -77,27 +88,31 @@ class Plugin: def __call__( self, - audio: NDArray[float32], + input_array: ArrayLike, sample_rate: float, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Run an audio buffer through this plugin. Alias for :py:meth:`process`. """ def process( self, - input_array: NDArray[float32], + input_array: ArrayLike, sample_rate: float, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Run a 32-bit or 64-bit floating point audio buffer through this plugin. (If calling this multiple times with multiple plugins, consider creating a :class:`pedalboard.Pedalboard` object instead.) + The input audio buffer can be any array-like object, including NumPy arrays, + PyTorch tensors, TensorFlow tensors, JAX arrays, or any other object that + supports the buffer protocol or has a __array__ method. + The returned array may contain up to (but not more than) the same number of samples as were provided. If fewer samples were returned than expected, the plugin has likely buffered audio inside itself. To receive the remaining @@ -319,7 +334,7 @@ class Convolution(Plugin): def __init__( self, impulse_response_filename: typing.Union[ - str, NDArray[float32] + str, numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]] ], mix: float = 1.0, sample_rate: typing.Optional[float] = None, @@ -328,7 +343,7 @@ class Convolution(Plugin): @property def impulse_response( self, - ) -> typing.Optional[NDArray[float32]]: + ) -> typing.Optional[numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]]: """ """ @property @@ -396,36 +411,6 @@ class Distortion(Plugin): pass pass - -MIDIMessageLike = typing.Union[ - "mido.Message", - typing.Tuple[bytes, float], - typing.Tuple[typing.List[int], float], -] - -class ExternalPluginReloadType(enum.Enum): - """ - Unknown: we need to determine the reload type. - """ - Unknown = 0 - - """ - Most plugins are of this type: calling .reset() on them will clear their - internal state. This is quick and easy: to start processing a new buffer, - all we need to do is call .reset() and optionally prepareToPlay(). - """ - ClearsAudioOnReset = 1 - - """ - This plugin type is a bit more of a pain to deal with; it could be argued - * that plugins that don't clear their internal buffers when reset() is called - * are buggy. To start processing a new buffer, we'll have to find another way - * to clear the buffer, usually by reloading the plugin from scratch and - persisting its parameters somehow. - """ - PersistsAudioOnReset = 2 - - class ExternalPlugin(Plugin): """ A wrapper around a third-party effect plugin. @@ -436,34 +421,37 @@ class ExternalPlugin(Plugin): @typing.overload def __call__( self, - midi_messages: typing.List[MIDIMessageLike], + midi_messages: object, duration: float, - sample_rate: float | int, + sample_rate: float, num_channels: int = 2, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: ... + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: + """ + Run an audio or MIDI buffer through this plugin, returning audio. Alias for :py:meth:`process`. + + Run an audio or MIDI buffer through this plugin, returning audio. Alias for :py:meth:`process`. + """ @typing.overload def __call__( self, - input_array: NDArray[float32], - sample_rate: float | int, + input_array: ArrayLike, + sample_rate: float, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: - """ - Run an audio or MIDI buffer through this plugin, returning audio. Alias for :py:meth:`process`. - """ - + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: ... @typing.overload def process( self, - input_array: NDArray[float32], - sample_rate: float | int, + midi_messages: object, + duration: float, + sample_rate: float, + num_channels: int = 2, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Pass a buffer of audio (as a 32- or 64-bit NumPy array) *or* a list of MIDI messages to this plugin, returning audio. @@ -636,22 +624,13 @@ class ExternalPlugin(Plugin): @typing.overload def process( self, - midi_messages: typing.List[MIDIMessageLike], - duration: float, - sample_rate: float | int, - num_channels: int = 2, + input_array: ArrayLike, + sample_rate: float, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: ... + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: ... pass - @staticmethod - def get_plugin_names_for_file(filename: str) -> typing.List[str]: - """ - Return a list of plugin names contained within a given VST3 plugin (i.e.: a ".vst3"). If the provided file cannot be scanned, an ImportError will be raised. - """ - # Not implemented by the abstract base class, but implemented by the concrete subclasses. - class Gain(Plugin): """ A gain plugin that increases or decreases the volume of a signal by amplifying or attenuating it by the provided value (in decibels). No distortion or other effects are applied. @@ -1100,11 +1079,13 @@ class AudioUnitPlugin(ExternalPlugin): @typing.overload def __call__( self, - input_array: NDArray[float32], - sample_rate: float | int, + midi_messages: object, + duration: float, + sample_rate: float, + num_channels: int = 2, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Run an audio or MIDI buffer through this plugin, returning audio. Alias for :py:meth:`process`. @@ -1114,13 +1095,11 @@ class AudioUnitPlugin(ExternalPlugin): @typing.overload def __call__( self, - midi_messages: typing.List[MIDIMessageLike], - duration: float, - sample_rate: float | int, - num_channels: int = 2, + input_array: ArrayLike, + sample_rate: float, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: ... + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: ... def __init__( self, path_to_plugin_file: str, @@ -1141,11 +1120,13 @@ class AudioUnitPlugin(ExternalPlugin): @typing.overload def process( self, - input_array: NDArray[float32], - sample_rate: float | int, + midi_messages: object, + duration: float, + sample_rate: float, + num_channels: int = 2, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Pass a buffer of audio (as a 32- or 64-bit NumPy array) *or* a list of MIDI messages to this plugin, returning audio. @@ -1318,13 +1299,11 @@ class AudioUnitPlugin(ExternalPlugin): @typing.overload def process( self, - midi_messages: typing.List[MIDIMessageLike], - duration: float, - sample_rate: float | int, - num_channels: int = 2, + input_array: ArrayLike, + sample_rate: float, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: ... + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: ... def show_editor(self, close_event: typing.Optional[threading.Event] = None) -> None: """ Show the UI of this plugin as a native window. @@ -1362,20 +1341,6 @@ class AudioUnitPlugin(ExternalPlugin): def _parameters(self) -> typing.List[_AudioProcessorParameter]: """ """ - @property - def _reload_type(self) -> ExternalPluginReloadType: - """ - The behavior that this plugin exhibits when .reset() is called. This is an internal attribute which gets set on plugin instantiation and should only be accessed for debugging and testing. - - - """ - - @_reload_type.setter - def _reload_type(self, arg0: ExternalPluginReloadType) -> None: - """ - The behavior that this plugin exhibits when .reset() is called. This is an internal attribute which gets set on plugin instantiation and should only be accessed for debugging and testing. - """ - @property def category(self) -> str: """ @@ -1812,6 +1777,30 @@ class VST3Plugin(ExternalPlugin): *Support for running VST3® plugins on background threads introduced in v0.8.8.* """ + @typing.overload + def __call__( + self, + midi_messages: object, + duration: float, + sample_rate: float, + num_channels: int = 2, + buffer_size: int = 8192, + reset: bool = True, + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: ... + @typing.overload + def __call__( + self, + input_array: ArrayLike, + sample_rate: float, + buffer_size: int = 8192, + reset: bool = True, + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: + """ + Run an audio or MIDI buffer through this plugin, returning audio. Alias for :py:meth:`process`. + + Run an audio or MIDI buffer through this plugin, returning audio. Alias for :py:meth:`process`. + """ + def __init__( self, path_to_plugin_file: str, @@ -1835,11 +1824,21 @@ class VST3Plugin(ExternalPlugin): @typing.overload def process( self, - input_array: NDArray[float32], - sample_rate: float | int, + midi_messages: object, + duration: float, + sample_rate: float, + num_channels: int = 2, buffer_size: int = 8192, reset: bool = True, - ) -> NDArray[float32]: + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: ... + @typing.overload + def process( + self, + input_array: ArrayLike, + sample_rate: float, + buffer_size: int = 8192, + reset: bool = True, + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Pass a buffer of audio (as a 32- or 64-bit NumPy array) *or* a list of MIDI messages to this plugin, returning audio. @@ -2009,16 +2008,6 @@ class VST3Plugin(ExternalPlugin): """ - @typing.overload - def process( - self, - midi_messages: typing.List[MIDIMessageLike], - duration: float, - sample_rate: float | int, - num_channels: int = 2, - buffer_size: int = 8192, - reset: bool = True, - ) -> NDArray[float32]: ... def show_editor(self, close_event: typing.Optional[threading.Event] = None) -> None: """ Show the UI of this plugin as a native window. @@ -2056,20 +2045,6 @@ class VST3Plugin(ExternalPlugin): def _parameters(self) -> typing.List[_AudioProcessorParameter]: """ """ - @property - def _reload_type(self) -> ExternalPluginReloadType: - """ - The behavior that this plugin exhibits when .reset() is called. This is an internal attribute which gets set on plugin instantiation and should only be accessed for debugging and testing. - - - """ - - @_reload_type.setter - def _reload_type(self, arg0: ExternalPluginReloadType) -> None: - """ - The behavior that this plugin exhibits when .reset() is called. This is an internal attribute which gets set on plugin instantiation and should only be accessed for debugging and testing. - """ - @property def category(self) -> str: """ @@ -2328,17 +2303,21 @@ class _AudioProcessorParameter: pass def process( - input_array: ndarray, + input_array: ArrayLike, sample_rate: float, plugins: typing.List[Plugin], buffer_size: int = 8192, reset: bool = True, -) -> NDArray[float32]: +) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Run a 32-bit or 64-bit floating point audio buffer through a list of Pedalboard plugins. If the provided buffer uses a 64-bit datatype, it will be converted to 32-bit for processing. + The input audio buffer can be any array-like object, including NumPy arrays, + PyTorch tensors, TensorFlow tensors, JAX arrays, or any other object that + supports the buffer protocol or has a __array__ method. + The provided ``buffer_size`` argument will be used to control the size of each chunk of audio provided into the plugins. Higher buffer sizes may speed up processing at the expense of memory usage. diff --git a/pedalboard_native/_internal/__init__.pyi b/pedalboard_native/_internal/__init__.pyi index d97a206f5..86090aa33 100644 --- a/pedalboard_native/_internal/__init__.pyi +++ b/pedalboard_native/_internal/__init__.pyi @@ -22,9 +22,22 @@ def patch_overload(func): typing.overload = patch_overload +from typing import Optional from typing_extensions import Literal from enum import Enum import threading + +# Array-like type that includes numpy arrays, torch tensors, etc. +# At runtime, we accept any array-like object (numpy arrays, torch tensors, +# tensorflow tensors, jax arrays, or anything with __array__ method). +# For type checking, we use numpy.typing.ArrayLike which covers most cases. +if typing.TYPE_CHECKING: + import numpy + import numpy.typing + + ArrayLike = numpy.typing.ArrayLike +else: + ArrayLike = typing.Any import pedalboard_native __all__ = [ diff --git a/pedalboard_native/io/__init__.pyi b/pedalboard_native/io/__init__.pyi index 4ad836986..7fa5b318d 100644 --- a/pedalboard_native/io/__init__.pyi +++ b/pedalboard_native/io/__init__.pyi @@ -7,11 +7,47 @@ import pedalboard_native.io import typing +original_overload = typing.overload +__OVERLOADED_DOCSTRINGS = {} + +def patch_overload(func): + original_overload(func) + if func.__doc__: + __OVERLOADED_DOCSTRINGS[func.__qualname__] = func.__doc__ + else: + func.__doc__ = __OVERLOADED_DOCSTRINGS.get(func.__qualname__) + if func.__doc__: + # Work around the fact that pybind11-stubgen generates + # duplicate docstrings sometimes, once for each overload: + docstring = func.__doc__ + if docstring[len(docstring) // 2 :].strip() == docstring[: -len(docstring) // 2].strip(): + func.__doc__ = docstring[len(docstring) // 2 :].strip() + return func + +typing.overload = patch_overload + +from typing import Optional from typing_extensions import Literal -import numpy as np +from enum import Enum +import threading +import io + +# Array-like type that includes numpy arrays, torch tensors, etc. +# At runtime, we accept any array-like object (numpy arrays, torch tensors, +# tensorflow tensors, jax arrays, or anything with __array__ method). +# For type checking, we use numpy.typing.ArrayLike which covers most cases. +if typing.TYPE_CHECKING: + import numpy + import numpy.typing + + ArrayLike = numpy.typing.ArrayLike +else: + ArrayLike = typing.Any +import numpy + +# Type alias for file-like objects that can be used with AudioFile +FileLike = typing.Union[typing.BinaryIO, io.BytesIO, io.BufferedIOBase, io.RawIOBase] import pedalboard_native.utils -from numpy.typing import NDArray -from numpy import float32 _Shape = typing.Tuple[int, ...] @@ -122,52 +158,41 @@ class AudioFile: those classes below for documentation. """ - @classmethod - @typing.overload - def __new__(cls, filename: str) -> ReadableAudioFile: - """Open an audio file for reading (mode 'r' is implied).""" - ... - - @classmethod - @typing.overload - def __new__(cls, filename: str, mode: Literal["r"]) -> ReadableAudioFile: - """Open an audio file for reading with an explicit mode 'r'.""" - ... - - @classmethod - @typing.overload - def __new__( - cls, file_like: typing.Union[typing.BinaryIO, memoryview], mode: Literal["r"] = "r" - ) -> ReadableAudioFile: ... - - @classmethod - @typing.overload - def __new__( - cls, - filename: str, - mode: Literal["w"], - samplerate: typing.Optional[float] = None, - num_channels: int = 1, - bit_depth: int = 16, - quality: typing.Optional[typing.Union[str, float]] = None, - ) -> WriteableAudioFile: ... - - @classmethod - @typing.overload + # Overloads for type narrowing based on mode + # Write mode overload (most parameters, must come first) + @original_overload + @staticmethod def __new__( - cls, - file_like: typing.BinaryIO, + cls: typing.Type[AudioFile], + filename_or_file_like: typing.Union[str, FileLike], mode: Literal["w"], - samplerate: typing.Optional[float] = None, + samplerate: float, num_channels: int = 1, bit_depth: int = 16, quality: typing.Optional[typing.Union[str, float]] = None, format: typing.Optional[str] = None, ) -> WriteableAudioFile: ... - + # Read mode with explicit "r" + @original_overload + @staticmethod + def __new__( + cls: typing.Type[AudioFile], + filename_or_file_like: typing.Union[str, FileLike, memoryview], + mode: Literal["r"], + ) -> ReadableAudioFile: ... + # Read mode default (no mode argument) + @original_overload + @staticmethod + def __new__( + cls: typing.Type[AudioFile], + filename_or_file_like: typing.Union[str, FileLike, memoryview], + ) -> ReadableAudioFile: ... + # Use *args/**kwargs for __init__ to bypass argument validation + # since the actual __init__ is called on the subclass returned by __new__ + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: ... @staticmethod def encode( - samples: NDArray[typing.Union[np.int8, np.int16, np.int32, np.float32, np.float64]], + samples: ArrayLike, samplerate: float, format: str, num_channels: int = 1, @@ -177,6 +202,10 @@ class AudioFile: """ Encode an audio buffer to a Python :class:`bytes` object. + The input audio buffer can be any array-like object, including NumPy arrays, + PyTorch tensors, TensorFlow tensors, JAX arrays, or any other object that + supports the buffer protocol or has a __array__ method. + This function will encode an entire audio buffer at once and return a :class:`bytes` object representing the bytes of the resulting audio file. @@ -203,6 +232,7 @@ class AudioFile: """ pass + class AudioStream: """ A class that allows interacting with live streams of audio from an input @@ -299,7 +329,7 @@ class AudioStream: @staticmethod def play( - audio: NDArray[float32], + audio: numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]], sample_rate: float, output_device_name: typing.Optional[str] = None, ) -> None: @@ -307,7 +337,7 @@ class AudioStream: Play audio data to the speaker, headphones, or other output device. This method will block until the audio is finished playing. """ - def read(self, num_samples: int = 0) -> NDArray[float32]: + def read(self, num_samples: int = 0) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ .. warning:: Recording audio is a **real-time** operation, so if your code doesn't call :py:meth:`read` quickly enough, some audio will be lost. To warn about this, :py:meth:`read` will throw an exception if audio data is dropped. This behavior can be disabled by setting :py:attr:`ignore_dropped_input` to :py:const:`True`. The number of dropped samples since the last call to :py:meth:`read` can be retrieved by accessing the :py:attr:`dropped_input_frame_count` property. @@ -319,7 +349,7 @@ class AudioStream: """ def write( - self, audio: NDArray[float32], sample_rate: float + self, audio: numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]], sample_rate: float ) -> None: """ If the provided sample rate does not match the output device's sample rate, an error will be thrown. In this case, you can use :py:class:`StreamResampler` to resample the audio before calling :py:meth:`write`. @@ -460,23 +490,32 @@ class ReadableAudioFile(AudioFile): Stop using this :class:`ReadableAudioFile` as a context manager, close the file, release its resources. """ - @typing.overload - def __init__(self, filename: str) -> None: ... - @typing.overload - def __init__(self, file_like: typing.Union[typing.BinaryIO, memoryview]) -> None: ... - - # These don't exist, but Pyright assumes they do: - @typing.overload - def __init__(self, filename: str, mode: Literal["r"]) -> None: ... - @typing.overload - def __init__(self, file_like: typing.Union[typing.BinaryIO, memoryview], mode: Literal["r"]) -> None: ... - - @classmethod - @typing.overload - def __new__(cls, filename: str) -> ReadableAudioFile: ... - @classmethod - @typing.overload - def __new__(cls, file_like: typing.Union[typing.BinaryIO, memoryview]) -> ReadableAudioFile: ... + # Direct call: ReadableAudioFile(file) + @original_overload + def __new__( + cls, + filename_or_file_like: typing.Union[str, FileLike, memoryview], + ) -> ReadableAudioFile: ... + # Via AudioFile: AudioFile(file, "r") + @original_overload + def __new__( + cls, + filename_or_file_like: typing.Union[str, FileLike, memoryview], + mode: Literal["r"], + ) -> ReadableAudioFile: ... + # Direct call: ReadableAudioFile(file) + @original_overload + def __init__( + self, + filename_or_file_like: typing.Union[str, FileLike, memoryview], + ) -> None: ... + # Via AudioFile: AudioFile(file, "r") + @original_overload + def __init__( + self, + filename_or_file_like: typing.Union[str, FileLike, memoryview], + mode: Literal["r"], + ) -> None: ... def __repr__(self) -> str: ... def close(self) -> None: """ @@ -485,7 +524,7 @@ class ReadableAudioFile(AudioFile): def read( self, num_frames: typing.Union[float, int] = 0 - ) -> NDArray[float32]: + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Read the given number of frames (samples in each channel) from this audio file at its current position. @@ -513,9 +552,7 @@ class ReadableAudioFile(AudioFile): an exception will be thrown, as a fractional number of samples cannot be returned. """ - def read_raw( - self, num_frames: typing.Union[float, int] = 0 - ) -> NDArray[typing.Union[np.int8, np.int16, np.int32, np.float32]]: + def read_raw(self, num_frames: typing.Union[float, int] = 0) -> numpy.ndarray: """ Read the given number of frames (samples in each channel) from this audio file at its current position. @@ -743,9 +780,9 @@ class ResampledReadableAudioFile(AudioFile): target_sample_rate: float, resampling_quality: pedalboard_native.Resample.Quality = pedalboard_native.Resample.Quality.WindowedSinc32, ) -> None: ... - @classmethod + @staticmethod def __new__( - cls, + cls: object, audio_file: ReadableAudioFile, target_sample_rate: float, resampling_quality: pedalboard_native.Resample.Quality = pedalboard_native.Resample.Quality.WindowedSinc32, @@ -758,7 +795,7 @@ class ResampledReadableAudioFile(AudioFile): def read( self, num_frames: typing.Union[float, int] = 0 - ) -> NDArray[float32]: + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Read the given number of frames (samples in each channel, at the target sample rate) from this audio file at its current position, automatically resampling on-the-fly to @@ -930,8 +967,8 @@ class StreamResampler: def __repr__(self) -> str: ... def process( - self, input: typing.Optional[NDArray[float32]] = None - ) -> NDArray[float32]: + self, input: typing.Optional[numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]] = None + ) -> numpy.ndarray[typing.Any, numpy.dtype[numpy.float32]]: """ Resample a 32-bit floating-point audio buffer. The returned buffer may be smaller than the provided buffer depending on the quality method used. Call :meth:`process()` without any arguments to flush the internal buffers and return all remaining audio. """ @@ -1030,74 +1067,52 @@ class WriteableAudioFile(AudioFile): def __enter__(self) -> WriteableAudioFile: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... - @typing.overload - def __init__( - self, - filename: str, - samplerate: float, - num_channels: int = 1, - bit_depth: int = 16, - quality: typing.Optional[typing.Union[str, float]] = None, - ) -> None: ... - @typing.overload - def __init__( - self, - file_like: typing.BinaryIO, - samplerate: float, - num_channels: int = 1, - bit_depth: int = 16, - quality: typing.Optional[typing.Union[str, float]] = None, - format: typing.Optional[str] = None, - ) -> None: ... - @classmethod - @typing.overload + # Direct call: WriteableAudioFile(file, samplerate, ...) + @original_overload def __new__( cls, - filename: str, - samplerate: typing.Optional[float] = None, + filename_or_file_like: typing.Union[str, FileLike], + samplerate: float, num_channels: int = 1, bit_depth: int = 16, quality: typing.Optional[typing.Union[str, float]] = None, + format: typing.Optional[str] = None, ) -> WriteableAudioFile: ... - @classmethod - @typing.overload + # Via AudioFile: AudioFile(file, "w", samplerate, ...) + @original_overload def __new__( cls, - file_like: typing.BinaryIO, - samplerate: typing.Optional[float] = None, + filename_or_file_like: typing.Union[str, FileLike], + mode: Literal["w"], + samplerate: float, num_channels: int = 1, bit_depth: int = 16, quality: typing.Optional[typing.Union[str, float]] = None, format: typing.Optional[str] = None, ) -> WriteableAudioFile: ... - - # This overload does not actually exist; just makes Pyright happy as - # it assumes that __init__ is called with the same arguments as __new__: - @typing.overload + # Direct call: WriteableAudioFile(file, samplerate, ...) + @original_overload def __init__( self, - filename: str, - mode: Literal["w"], + filename_or_file_like: typing.Union[str, FileLike], samplerate: float, num_channels: int = 1, bit_depth: int = 16, quality: typing.Optional[typing.Union[str, float]] = None, + format: typing.Optional[str] = None, ) -> None: ... - - # This overload does not actually exist; just makes Pyright happy as - # it assumes that __init__ is called with the same arguments as __new__: - @typing.overload + # Via AudioFile: AudioFile(file, "w", samplerate, ...) + @original_overload def __init__( self, - file_like: typing.BinaryIO, + filename_or_file_like: typing.Union[str, FileLike], mode: Literal["w"], - samplerate: typing.Optional[float] = None, + samplerate: float, num_channels: int = 1, bit_depth: int = 16, quality: typing.Optional[typing.Union[str, float]] = None, format: typing.Optional[str] = None, ) -> None: ... - def __repr__(self) -> str: ... def close(self) -> None: """ @@ -1114,9 +1129,9 @@ class WriteableAudioFile(AudioFile): Return the current position of the write pointer in this audio file, in frames at the target sample rate. This value will increase as :meth:`write` is called, and will never decrease. """ - def write(self, samples: numpy.ndarray) -> None: + def write(self, samples: ArrayLike) -> None: """ - Encode an array of audio data and write it to this file. The number of channels in the array must match the number of channels used to open the file. The array may contain audio in any shape. If the file's bit depth or format does not match the provided data type, the audio will be automatically converted. + Encode an array of audio data and write it to this file. The number of channels in the array must match the number of channels used to open the file. The audio data may be provided as a NumPy array, PyTorch tensor, TensorFlow tensor, JAX array, or any other array-like object. The array may contain audio in any shape. If the file's bit depth or format does not match the provided data type, the audio will be automatically converted. Arrays of type int8, int16, int32, float32, and float64 are supported. If an array of an unsupported ``dtype`` is provided, a ``TypeError`` will be raised. diff --git a/pedalboard_native/utils/__init__.pyi b/pedalboard_native/utils/__init__.pyi index af1b7fd00..d8e5edc21 100644 --- a/pedalboard_native/utils/__init__.pyi +++ b/pedalboard_native/utils/__init__.pyi @@ -22,9 +22,22 @@ def patch_overload(func): typing.overload = patch_overload +from typing import Optional from typing_extensions import Literal from enum import Enum import threading + +# Array-like type that includes numpy arrays, torch tensors, etc. +# At runtime, we accept any array-like object (numpy arrays, torch tensors, +# tensorflow tensors, jax arrays, or anything with __array__ method). +# For type checking, we use numpy.typing.ArrayLike which covers most cases. +if typing.TYPE_CHECKING: + import numpy + import numpy.typing + + ArrayLike = numpy.typing.ArrayLike +else: + ArrayLike = typing.Any import numpy import pedalboard_native @@ -37,9 +50,7 @@ class Chain(pedalboard_native.PluginContainer, pedalboard_native.Plugin): Run zero or more plugins as a plugin. Useful when used with the Mix plugin. """ - @typing.overload def __init__(self, plugins: typing.List[pedalboard_native.Plugin]) -> None: ... - @typing.overload def __repr__(self) -> str: ... pass @@ -48,9 +59,7 @@ class Mix(pedalboard_native.PluginContainer, pedalboard_native.Plugin): A utility plugin that allows running other plugins in parallel. All plugins provided will be mixed equally. """ - @typing.overload def __init__(self, plugins: typing.List[pedalboard_native.Plugin]) -> None: ... - @typing.overload def __repr__(self) -> str: ... pass diff --git a/scripts/generate_type_stubs_and_docs.py b/scripts/generate_type_stubs_and_docs.py index 5e72b3179..f2e985431 100644 --- a/scripts/generate_type_stubs_and_docs.py +++ b/scripts/generate_type_stubs_and_docs.py @@ -49,15 +49,120 @@ ] MULTILINE_REPLACEMENTS = [ - # Users don't want "Peter Sobot’s iPhone Microphone" to show up in their type hints: + # Users don't want "Peter Sobot's iPhone Microphone" to show up in their type hints: (r"input_device_names = \[[^]]*\]\n", "input_device_names: typing.List[str] = []\n"), ( r"output_device_names = \[[^]]*\]\n", "output_device_names: typing.List[str] = []\n", ), + # Replace AudioFile.__new__ and __init__ with overloads for proper type narrowing. + # Using original_overload (not typing.overload) because patch_overload confuses pyright. + ( + r"(class AudioFile:[\s\S]*?)(\n def __new__\([\s\S]*?-> typing\.Union\[ReadableAudioFile, WriteableAudioFile\]: \.\.\.)", + r"\1\n" + r" # Overloads for type narrowing based on mode\n" + r" @original_overload\n" + r" @staticmethod\n" + r" def __new__(\n" + r" cls: typing.Type[AudioFile],\n" + r" filename_or_file_like: typing.Union[str, typing.BinaryIO, memoryview],\n" + r" mode: Literal[\"w\"],\n" + r" samplerate: float,\n" + r" num_channels: int = 1,\n" + r" bit_depth: int = 16,\n" + r" quality: typing.Optional[typing.Union[str, float]] = None,\n" + r" format: typing.Optional[str] = None,\n" + r" ) -> WriteableAudioFile: ...\n" + r" @original_overload\n" + r" @staticmethod\n" + r" def __new__(\n" + r" cls: typing.Type[AudioFile],\n" + r" filename_or_file_like: typing.Union[str, typing.BinaryIO, memoryview],\n" + r" mode: Literal[\"r\"],\n" + r" ) -> ReadableAudioFile: ...\n" + r" @original_overload\n" + r" @staticmethod\n" + r" def __new__(\n" + r" cls: typing.Type[AudioFile],\n" + r" filename_or_file_like: typing.Union[str, typing.BinaryIO, memoryview],\n" + r" ) -> ReadableAudioFile: ...\n" + r" def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: ...", + ), + # Add mode overloads to ReadableAudioFile for calls routed through AudioFile + ( + r"(class ReadableAudioFile\(AudioFile\):[\s\S]*?)\n def __new__\(\n cls,\n filename_or_file_like:[^)]+\) -> ReadableAudioFile: \.\.\.\n def __init__\(\n self,\n filename_or_file_like:[^)]+\) -> None: \.\.\.", + r"\1\n" + r" @original_overload\n" + r" def __new__(cls, filename_or_file_like: typing.Union[str, typing.BinaryIO, memoryview]) -> ReadableAudioFile: ...\n" + r" @original_overload\n" + r" def __new__(cls, filename_or_file_like: typing.Union[str, typing.BinaryIO, memoryview], mode: Literal[\"r\"]) -> ReadableAudioFile: ...\n" + r" @original_overload\n" + r" def __init__(self, filename_or_file_like: typing.Union[str, typing.BinaryIO, memoryview]) -> None: ...\n" + r" @original_overload\n" + r" def __init__(self, filename_or_file_like: typing.Union[str, typing.BinaryIO, memoryview], mode: Literal[\"r\"]) -> None: ...", + ), + # Add mode overloads to WriteableAudioFile for calls routed through AudioFile + ( + r"(class WriteableAudioFile\(AudioFile\):[\s\S]*?def __exit__[^.]+\.\.\.)\n def __new__\(\n cls,\n filename_or_file_like:[^)]+\n samplerate:[^)]+\) -> WriteableAudioFile: \.\.\.\n def __init__\(\n self,\n filename_or_file_like:[^)]+\n samplerate:[^)]+\) -> None: \.\.\.", + r"\1\n" + r" @original_overload\n" + r" def __new__(cls, filename_or_file_like: typing.Union[str, typing.BinaryIO], samplerate: float, num_channels: int = 1, bit_depth: int = 16, quality: typing.Optional[typing.Union[str, float]] = None, format: typing.Optional[str] = None) -> WriteableAudioFile: ...\n" + r" @original_overload\n" + r" def __new__(cls, filename_or_file_like: typing.Union[str, typing.BinaryIO], mode: Literal[\"w\"], samplerate: float, num_channels: int = 1, bit_depth: int = 16, quality: typing.Optional[typing.Union[str, float]] = None, format: typing.Optional[str] = None) -> WriteableAudioFile: ...\n" + r" @original_overload\n" + r" def __init__(self, filename_or_file_like: typing.Union[str, typing.BinaryIO], samplerate: float, num_channels: int = 1, bit_depth: int = 16, quality: typing.Optional[typing.Union[str, float]] = None, format: typing.Optional[str] = None) -> None: ...\n" + r" @original_overload\n" + r" def __init__(self, filename_or_file_like: typing.Union[str, typing.BinaryIO], mode: Literal[\"w\"], samplerate: float, num_channels: int = 1, bit_depth: int = 16, quality: typing.Optional[typing.Union[str, float]] = None, format: typing.Optional[str] = None) -> None: ...", + ), + # Remove _reload_type property blocks (internal implementation detail using ExternalPluginReloadType) + # Match @property followed by def _reload_type and its docstring (using [\s\S] for multiline) + (r' @property\n def _reload_type\(self\)[^\n]*\n """[\s\S]*?"""\n', ""), + (r' @_reload_type\.setter\n def _reload_type\(self[^\n]*\n """[\s\S]*?"""\n', ""), + # MyPy chokes on classes that contain both __new__ and __init__. + # Remove all bare, arg-free inits (and their @typing.overload decorators): + (r" @typing\.overload\n def __init__\(self\) -> None: \.\.\.\n", ""), + # After removing one overload from a pair in Chain/Mix classes, a single @typing.overload + # remains which is invalid. Remove it ONLY for the plugins: parameter pattern (from utils) + ( + r" @typing\.overload\n( def __init__\(self, plugins:)", + r"\1", + ), + # Fix VST3Plugin __call__ overload order to match ExternalPlugin (midi_messages first) + # The raw stubs have input_array first, but parent has midi_messages first. + # Raw stubs have single-line signatures before black formatting. + ( + r'(class VST3Plugin\(ExternalPlugin, Plugin\):[\s\S]*?"""\n)' + r'( @typing\.overload\n def __call__\(self, input_array: object[^\n]*\n """[\s\S]*?"""\n)' + r"( @typing\.overload\n def __call__\(self, midi_messages[^\n]*\.\.\.)", + r"\1\3\n\2", + ), + # Fix VST3Plugin process overload order similarly + ( + r'( def load_preset\(self, preset_file_path: str\)[^\n]*\n """[\s\S]*?"""\n)' + r'( @typing\.overload\n def process\(self, input_array: object[^\n]*\n """[\s\S]*?"""\n)' + r"( @typing\.overload\n def process\(self, midi_messages[^\n]*\.\.\.)", + r"\1\3\n\2", + ), + # Fix AudioUnitPlugin __call__ and process overload order similarly + ( + r'(class AudioUnitPlugin\(ExternalPlugin, Plugin\):[\s\S]*?"""\n)' + r'( @typing\.overload\n def __call__\(self, input_array: object[^\n]*\n """[\s\S]*?"""\n)' + r"( @typing\.overload\n def __call__\(self, midi_messages[^\n]*\.\.\.)", + r"\1\3\n\2", + ), + ( + r'(class AudioUnitPlugin[\s\S]*?def load_preset\(self, preset_file_path: str\)[^\n]*\n """[\s\S]*?"""\n)' + r'( @typing\.overload\n def process\(self, input_array: object[^\n]*\n """[\s\S]*?"""\n)' + r"( @typing\.overload\n def process\(self, midi_messages[^\n]*\.\.\.)", + r"\1\3\n\2", + ), ] REPLACEMENTS = [ + # Replace generic 'object' with a proper ArrayLike type hint for audio data: + (r"input_array: object", r"input_array: ArrayLike"), + (r"input: object", r"input: ArrayLike"), + (r"samples: object", r"samples: ArrayLike"), # object is a superclass of `str`, which would make these declarations ambiguous: ( r"file_like: object, mode: str = 'r'", @@ -79,9 +184,21 @@ "\n".join( [ "import typing", + "from typing import Optional", "from typing_extensions import Literal", "from enum import Enum", "import threading", + "", + "# Array-like type that includes numpy arrays, torch tensors, etc.", + "# At runtime, we accept any array-like object (numpy arrays, torch tensors,", + "# tensorflow tensors, jax arrays, or anything with __array__ method).", + "# For type checking, we use numpy.typing.ArrayLike which covers most cases.", + "if typing.TYPE_CHECKING:", + " import numpy", + " import numpy.typing", + " ArrayLike = numpy.typing.ArrayLike", + "else:", + " ArrayLike = typing.Any", ] ), ), @@ -104,7 +221,8 @@ "pedalboard_native.Resample.Quality = pedalboard_native.Resample.Quality", ), ( - r".*: pedalboard_native.ExternalPluginReloadType.*", + # Remove any lines referencing ExternalPluginReloadType (internal implementation detail) + r".*ExternalPluginReloadType.*", "", ), # Enum values that should not be in __all__: @@ -115,8 +233,8 @@ # Remove type hints in docstrings, added unnecessarily by pybind11-stubgen (r".*?:type:.*$", ""), # MyPy chokes on classes that contain both __new__ and __init__. - # Remove all bare, arg-free inits: - (r"def __init__\(self\) -> None: ...", ""), + # Remove all bare, arg-free inits (moved to MULTILINE_REPLACEMENTS): + (r"def __init__\(self\) -> None: \.\.\.", ""), # Sphinx gets confused when inheriting twice from the same base class: (r"\(ExternalPlugin, Plugin\)", "(ExternalPlugin)"), # Python <3.9 doesn't like bare lists in type hints: diff --git a/stubtest.allowlist b/stubtest.allowlist index 65ec9b32d..95dcf5c3c 100644 --- a/stubtest.allowlist +++ b/stubtest.allowlist @@ -1,18 +1,25 @@ -pedalboard.Plugin.__call__ -pedalboard.__all__ -pedalboard.io.__all__ -pedalboard._pedalboard.Dict -pedalboard._pedalboard.Iterable -pedalboard._pedalboard.List -pedalboard._pedalboard.Plugin.__call__ -pedalboard._pedalboard.WeakTypeWrapper@\d+ -pedalboard._pedalboard._AudioProcessorParameter.__init__ -pedalboard._pedalboard.Pedalboard -pedalboard._pedalboard.VST3Plugin -pedalboard.utils.Chain -pedalboard.utils.Mix -pedalboard.utils.__all__ -pedalboard.version.__all__ -.*?AudioUnitPlugin.* +# Allowlist for stubtest - these are expected differences between stubs and runtime +# Metaclass differences are expected with pybind11 .*?metaclass.* -.*?ExternalPluginReloadType.* \ No newline at end of file +# AudioUnitPlugin is macOS-only +.*?AudioUnitPlugin.* +# ExternalPluginReloadType is internal +.*?ExternalPluginReloadType.* +.*?ClearsAudioOnReset.* +.*?PersistsAudioOnReset.* +.*?Unknown.* +# ArrayLike is a type alias for type checking only +.*?ArrayLike.* +# __all__ is stub-only +.*?__all__.* +# original_overload and patch_overload are stub helpers +.*?original_overload.* +.*?patch_overload.* +# _AudioProcessorParameter has variadic __init__ +.*?_AudioProcessorParameter.__init__.* +# Enum methods that pybind11 adds at runtime +.*?__index__.* +.*?__int__.* +.*?__members__.* +# installed_plugins is a runtime-populated list +.*?installed_plugins.* \ No newline at end of file