From 0f264d16d870286fe3c486bbfd9fd47ad8462225 Mon Sep 17 00:00:00 2001 From: Margubur rahman Date: Mon, 11 May 2026 07:45:54 +0000 Subject: [PATCH 1/2] Refactor buffer to have zero copy --- gcsfs/extended_gcsfs.py | 63 ++++---- gcsfs/tests/test_zb_hns_utils.py | 255 +++++++++++++++++------------- gcsfs/zb_hns_utils.py | 259 ++++++++++++++++++++++--------- 3 files changed, 362 insertions(+), 215 deletions(-) diff --git a/gcsfs/extended_gcsfs.py b/gcsfs/extended_gcsfs.py index 59dd6b0d..be21748f 100644 --- a/gcsfs/extended_gcsfs.py +++ b/gcsfs/extended_gcsfs.py @@ -27,12 +27,7 @@ from gcsfs import zb_hns_utils from gcsfs.core import GCSFile, GCSFileSystem from gcsfs.retry import DEFAULT_RETRY_CONFIG, get_storage_control_retry_config -from gcsfs.zb_hns_utils import ( - DirectMemmoveBuffer, - MRDPool, - PyBytes_AsString, - PyBytes_FromStringAndSize, -) +from gcsfs.zb_hns_utils import DirectMemmoveBuffer, MRDPool from gcsfs.zonal_file import ZonalFile logger = logging.getLogger("gcsfs") @@ -411,39 +406,34 @@ async def _fetch_range_split( await mrd.close() async def _concurrent_mrd_fetch(self, offset, length, concurrency, mrd_or_pool): - """Helper to handle concurrent chunk downloads into a DirectMemmoveBuffer.""" + """Helper to handle concurrent chunk downloads cleanly.""" concurrency = ( concurrency if length >= self.MIN_CHUNK_SIZE_FOR_CONCURRENCY else 1 ) - result_bytes = PyBytes_FromStringAndSize(None, length) - buffer_ptr = PyBytes_AsString(result_bytes) - part_size = length // concurrency - tasks = [] - buffers = [] - loop = asyncio.get_running_loop() - # Track if the core download process failed + tasks = [] + views = [] has_error = False - async def _download(o, s, b, mrd_or_pool): + # The master buffer manages its own allocation under the hood + master_buffer = DirectMemmoveBuffer(length, self._memmove_executor) + + async def _download(o, s, view, mrd_or_pool): async with _get_mrd_from_pool_or_mrd(mrd_or_pool) as m_client: - await m_client.download_ranges([(o, s, b)]) + await m_client.download_ranges([(o, s, view)]) for i in range(concurrency): part_offset = offset + (i * part_size) actual_size = part_size if i < concurrency - 1 else length - (i * part_size) - part_address = buffer_ptr + (part_offset - offset) - buf = DirectMemmoveBuffer( - part_address, - part_address + actual_size, - self._memmove_executor, - ) - buffers.append(buf) + # Give each task a restricted view of the master buffer + view = master_buffer.get_view(part_offset - offset, actual_size) + views.append(view) + tasks.append( asyncio.create_task( - _download(part_offset, actual_size, buf, mrd_or_pool) + _download(part_offset, actual_size, view, mrd_or_pool) ) ) @@ -453,6 +443,8 @@ async def _download(o, s, b, mrd_or_pool): if isinstance(res, Exception): has_error = True raise res + for view in views: + view.close() except BaseException: has_error = True for t in tasks: @@ -461,18 +453,17 @@ async def _download(o, s, b, mrd_or_pool): await asyncio.gather(*tasks, return_exceptions=True) raise finally: - for buf in buffers: - try: - await loop.run_in_executor(None, buf.close) - except BufferError: - # If we are already handling a network/download exception, - # ignore the BufferError (which is just a symptom of the drop). - # If there's no download error, this means the buffer logic - # itself failed, so we must surface the error. - if not has_error: - raise - - return result_bytes + try: + master_buffer.close() + except Exception: + # If we are already handling a network/download exception, + # ignore the exception from buffer (which is just a symptom of the drop). + # If there's no download error, this means the buffer logic + # itself failed, so we must surface the error. + if not has_error: + raise + + return master_buffer.get_value() async def _cat_file( self, diff --git a/gcsfs/tests/test_zb_hns_utils.py b/gcsfs/tests/test_zb_hns_utils.py index 4733116f..0b06ae0b 100644 --- a/gcsfs/tests/test_zb_hns_utils.py +++ b/gcsfs/tests/test_zb_hns_utils.py @@ -1,5 +1,4 @@ import concurrent.futures -import ctypes import logging from unittest import mock @@ -380,104 +379,6 @@ async def test_mrd_pool_close_with_exceptions(create_mrd_mock, mock_gcsfs): assert len(pool._all_mrds) == 0 -@mock.patch("gcsfs.zb_hns_utils.ctypes.memmove") -def test_direct_memmove_buffer_error_handling(mock_memmove): - size = 20 - buffer_array = (ctypes.c_char * size)() - start_address = ctypes.addressof(buffer_array) - end_address = start_address + size - - # Simulate an access violation or similar error during memory copy - mock_memmove.side_effect = MemoryError("Segfault simulated") - - executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) - - # First write triggers the background error - future = buf.write(b"bad data") - - # Wait for the background thread to actually fail - with pytest.raises(MemoryError): - future.result() - - # Subsequent writes should raise the stored error immediately - with pytest.raises(MemoryError, match="Segfault simulated"): - buf.write(b"more data") - - # Close should also raise the stored error. - with pytest.raises(MemoryError, match="Segfault simulated"): - buf.close() - - executor.shutdown() - - -def test_direct_memmove_buffer(): - data1 = b"hello" - data2 = b"world" - - # Calculate exact size to prevent the new underflow check from failing - size = len(data1) + len(data2) - buffer_array = (ctypes.c_char * size)() - start_address = ctypes.addressof(buffer_array) - end_address = start_address + size - - executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) - buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) - - future1 = buf.write(data1) - future2 = buf.write(data2) - - future1.result() - future2.result() - buf.close() - - result_bytes = ctypes.string_at(start_address, len(data1) + len(data2)) - assert result_bytes == b"helloworld" - - executor.shutdown() - - -def test_direct_memmove_buffer_overflow(): - """Tests that writing past the allocated end_address raises a BufferError.""" - size = 10 - buffer_array = (ctypes.c_char * size)() - start_address = ctypes.addressof(buffer_array) - end_address = start_address + size - - executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) - - # Fill the buffer exactly to capacity - buf.write(b"1234567890") - - # Attempting to write even 1 more byte should trigger the overflow protection - with pytest.raises(BufferError, match="Attempted to write"): - buf.write(b"1") - - buf.close() - executor.shutdown() - - -def test_direct_memmove_buffer_underflow(): - """Tests that closing an incompletely filled buffer raises a BufferError.""" - size = 10 - buffer_array = (ctypes.c_char * size)() - start_address = ctypes.addressof(buffer_array) - end_address = start_address + size - - executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) - - # Write fewer bytes than the expected capacity - buf.write(b"12345") - - # Closing should detect that current_offset (5) < expected size (10) - with pytest.raises(BufferError, match="Buffer contains uninitialized data"): - buf.close() - - executor.shutdown() - - @pytest.mark.asyncio async def test_mrd_pool_queue_filled_during_lock_wait(mock_gcsfs): pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=1) @@ -547,6 +448,103 @@ async def fake_create_mrd(): assert pool._rr_index == 1 +@mock.patch("gcsfs.zb_hns_utils.ctypes.memmove") +def test_direct_memmove_buffer_error_handling(mock_memmove): + # Use a size > 128KB to trigger the executor background path + size = 130 * 1024 + 10 + data1 = b"a" * (130 * 1024) + data2 = b"b" * 10 + + # Simulate an access violation or similar error during memory copy + mock_memmove.side_effect = MemoryError("Segfault simulated") + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + buf = DirectMemmoveBuffer(size, executor, max_pending=2) + view = buf.get_view(0, size) + + # First write triggers the background error (slow path) + future = view.write(data1) + + # Wait for the background thread to actually fail + with pytest.raises(MemoryError): + future.result() + + # Subsequent writes should raise the stored error immediately + with pytest.raises(MemoryError, match="Segfault simulated"): + view.write(data2) + + # Close should also raise the stored error. + with pytest.raises(MemoryError, match="Segfault simulated"): + buf.close() + + executor.shutdown() + + +def test_direct_memmove_buffer(): + data1 = b"hello" + data2 = b"world" + size = len(data1) + len(data2) + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + buf = DirectMemmoveBuffer(size, executor, max_pending=2) + view = buf.get_view(0, size) + + future1 = view.write(data1) + future2 = view.write(data2) + + future1.result() + future2.result() + + view.close() + buf.close() + + result_bytes = buf.get_value() + assert result_bytes == b"helloworld" + + executor.shutdown() + + +def test_direct_memmove_buffer_overflow(): + """Tests that writing past the view boundaries raises a BufferError.""" + size = 10 + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + buf = DirectMemmoveBuffer(size, executor, max_pending=2) + view = buf.get_view(0, size) + + # Fill the buffer exactly to capacity + view.write(b"1234567890") + + # Attempting to write even 1 more byte should trigger the overflow protection + with pytest.raises(BufferError, match="Attempted to write"): + view.write(b"1") + + view.close() + buf.close() + executor.shutdown() + + +def test_direct_memmove_buffer_underflow(): + """Tests that closing an incompletely filled view/buffer raises a BufferError.""" + size = 10 + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + buf = DirectMemmoveBuffer(size, executor, max_pending=2) + view = buf.get_view(0, size) + + # Write fewer bytes than the expected capacity + view.write(b"12345") + + # Closing the view should detect that current_offset (5) < expected size (10) + with pytest.raises(BufferError, match="Buffer contains uninitialized data"): + view.close() + + # Calling get_value after an incompletely filled buffer should also error + buf.close() + with pytest.raises(BufferError, match="Buffer incomplete"): + buf.get_value() + + executor.shutdown() + + @mock.patch("gcsfs.zb_hns_utils.ctypes.memmove") def test_direct_memmove_buffer_submit_failure(mock_memmove): """ @@ -554,13 +552,17 @@ def test_direct_memmove_buffer_submit_failure(mock_memmove): the internal locks, semaphores, and events are properly reset, and close() does not hang. """ - size = 10 - buffer_array = (ctypes.c_char * size)() - start_address = ctypes.addressof(buffer_array) - end_address = start_address + size + # 1. Chunk > 128KB to force executor scheduling (skip the synchronous fast path) + chunk_size = 130 * 1024 + + # 2. Expected size > chunk_size to skip the Zero-Copy optimization + expected_size = 140 * 1024 + + data = b"a" * chunk_size executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) - buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2) + buf = DirectMemmoveBuffer(expected_size, executor, max_pending=2) + view = buf.get_view(0, expected_size) # Mock the submit method to simulate a closed executor throwing a RuntimeError with mock.patch.object( @@ -568,7 +570,7 @@ def test_direct_memmove_buffer_submit_failure(mock_memmove): ): # The write operation should raise the simulated RuntimeError with pytest.raises(RuntimeError, match="Executor closed"): - buf.write(b"12345") + view.write(data) # Verify that the internal tracking state was correctly rolled back assert buf._pending_count == 0 @@ -579,3 +581,46 @@ def test_direct_memmove_buffer_submit_failure(mock_memmove): buf.close() executor.shutdown() + + +def test_direct_memmove_buffer_zero_copy(): + """Tests that a perfect aligned single payload avoids memory allocation completely.""" + data = b"exact_size_payload" + size = len(data) + + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + buf = DirectMemmoveBuffer(size, executor, max_pending=2) + view = buf.get_view(0, size) + + # Writing a single payload identical to the expected size + future = view.write(data) + future.result() + + view.close() + buf.close() + + # Should be the EXACT same string object returned without copying + result = buf.get_value() + assert result is data + + executor.shutdown() + + +def test_direct_memmove_buffer_overlapping_views(): + """Tests that getting overlapping views raises a ValueError.""" + size = 100 + executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) + buf = DirectMemmoveBuffer(size, executor, max_pending=2) + + # Get a view for the first half + _ = buf.get_view(0, 50) + + # Attempting to get an overlapping view should fail + with pytest.raises(ValueError, match="Overlapping view requested"): + _ = buf.get_view(25, 50) + + # Getting a view for the second half should succeed + _ = buf.get_view(50, 50) + + buf.close() + executor.shutdown() diff --git a/gcsfs/zb_hns_utils.py b/gcsfs/zb_hns_utils.py index 1b85005d..f6176e64 100644 --- a/gcsfs/zb_hns_utils.py +++ b/gcsfs/zb_hns_utils.py @@ -1,4 +1,5 @@ import asyncio +import concurrent.futures import contextlib import ctypes import logging @@ -182,33 +183,47 @@ class DirectMemmoveBuffer: """ A buffer-like object that writes data directly to memory asynchronously. - This class provides a `write` interface that queues `ctypes.memmove` operations - to a thread pool executor, limiting the maximum number of concurrent pending - writes using a semaphore. It is useful for high-performance data transfers - where memory copies need to be offloaded from the main thread. + This class provides an interface that queues `ctypes.memmove` operations + to a thread pool executor. It provides synchronous backpressure: if `max_pending` + operations are currently writing, the `write()` call will safely block the + calling thread (e.g., an asyncio loop) until capacity frees up. + + Memory allocation is natively deferred. If the payload precisely aligns + with expected bounds sequentially, it gracefully overrides manual memmoves + using true Zero-Copy payload replacement safely under the hood. + + Note: This class is now strictly Thread-Safe """ - def __init__(self, start_address, end_address, executor, max_pending=5): + THRESHOLD_BYTES_FOR_SCHEDULING = 128 * 1024 + + def __init__(self, expected_size, executor, max_pending=5): """ Initializes the DirectMemmoveBuffer. Args: - start_address (int): The starting memory address where data will be written. - end_address (int): The absolute ending memory address. Writes exceeding - this boundary will be rejected to prevent overflows. + expected_size (int): The total amount of bytes expected to populate memory. executor (concurrent.futures.Executor): The thread pool executor to run the memmove operations. The lifecycle of this executor is managed by the caller. max_pending (int, optional): The maximum number of pending write operations allowed in the queue. Defaults to 5. """ - self.start_address = start_address - self.end_address = end_address + self.expected_size = expected_size self.executor = executor # Volatile state variables. Must only be amended while holding self._lock. - self.current_offset = 0 self._pending_count = 0 self._error = None + self._total_bytes_written = 0 + self._stop_accepting_writes = False + self._is_closed = False + + # Track allocated (start, end) intervals to prevent overlapping views. + self._allocated_intervals = [] + + # PyBytes Native Pointers & Allocation tracking natively handled + self._result_bytes = None + self._start_address = None # Primitives: # 1. semaphore: Provides backpressure by limiting the number of active tasks. @@ -219,6 +234,74 @@ def __init__(self, start_address, end_address, executor, max_pending=5): self._done_event = threading.Event() self._done_event.set() + class PartialView: + """A bounded memory writer providing robust overfill/underfill constraint validations.""" + + def __init__(self, parent, start_offset, expected_size): + self.parent = parent + self.start_offset = start_offset + self.expected_size = expected_size + self.current_offset = 0 + self._view_lock = threading.Lock() + + def write(self, data): + """ + Schedules a write operation to memory mapping. + """ + if not isinstance(data, bytes): + raise ValueError(f"Expected bytes, but got {type(data)}") + size = len(data) + with self._view_lock: + if self.current_offset + size > self.expected_size: + error_msg = ( + f"Attempted to write {size} bytes " + f"at offset {self.current_offset}. " + f"Max capacity is {self.expected_size} bytes." + ) + raise BufferError(error_msg) + + abs_offset = self.start_offset + self.current_offset + self.current_offset += size + + return self.parent._submit_write(abs_offset, data, size) + + def close(self): + """ + Validates boundaries enforcing complete local payload consistency. + """ + with self._view_lock: + if self.current_offset < self.expected_size: + error_msg = ( + f"Expected {self.expected_size} bytes, " + f"but only received {self.current_offset} bytes. " + f"Buffer contains uninitialized data." + ) + raise BufferError(error_msg) + + def get_view(self, offset, size): + """Constructs secure mapped offset references correctly handling constraint layouts.""" + if offset < 0 or offset + size > self.expected_size: + raise ValueError("Invalid view requested: exceeds physical boundaries!") + + start = offset + end = offset + size + + with self._lock: + if self._stop_accepting_writes or self._is_closed: + raise ValueError("Cannot get view on a closed/closing buffer.") + + # Enforce Write-Once memory semantics: prevent overlapping views + for a_start, a_end in self._allocated_intervals: + if max(start, a_start) < min(end, a_end): + raise ValueError( + f"Overlapping view requested: [{start}, {end}) " + f"overlaps with already allocated view [{a_start}, {a_end})" + ) + + self._allocated_intervals.append((start, end)) + + return self.PartialView(self, offset, size) + def _decrement_pending(self): """Helper to cleanly release concurrency primitives after a task finishes.""" self.semaphore.release() @@ -227,86 +310,114 @@ def _decrement_pending(self): if self._pending_count == 0: self._done_event.set() - def write(self, data): - """ - Schedules a write operation to memory. - - Calculates the destination address based on the current offset, increments the offset, - and submits the memory move operation to the executor. Blocks if the number of - pending operations reaches `max_pending`. - - Args: - data: The data to be written to memory. Must support the buffer protocol. - - Returns: - concurrent.futures.Future: A future object representing the execution of the - memory move operation. - - Raises: - Exception: If any previous asynchronous write operation encountered an error. - BufferError: If the write exceeds the allocated memory boundaries. - """ - if self._error: - raise self._error - - size = len(data) - with self._lock: - dest = self.start_address + self.current_offset - if dest + size > self.end_address: - error_msg = ( - f"Attempted to write {size} bytes " - f"at offset {self.current_offset}. " - f"Max capacity is {self.end_address - self.start_address} bytes." - ) - raise BufferError(error_msg) - - self.current_offset += size - data_bytes = bytes(data) if not isinstance(data, bytes) else data - + def _submit_write(self, dest_offset, data_bytes, size): self.semaphore.acquire() - with self._lock: - if self._pending_count == 0: - self._done_event.clear() - self._pending_count += 1 try: - return self.executor.submit(self._do_memmove, dest, data_bytes, size) + with self._lock: + if self._stop_accepting_writes or self._is_closed: + raise ValueError("I/O operation on closed buffer.") + + if self._error: + raise self._error + + if self._result_bytes is None: + if dest_offset == 0 and size == self.expected_size: + self._result_bytes = data_bytes + self.semaphore.release() # Release because we skip the executor + fut = concurrent.futures.Future() + fut.set_result(None) + self._total_bytes_written += size + return fut + else: + self._result_bytes = PyBytes_FromStringAndSize( + None, self.expected_size + ) + self._start_address = PyBytes_AsString(self._result_bytes) + + # Defensive programming: gracefully catch internal overwrite attempts + if self._start_address is None: + raise BufferError( + "Attempted to execute standard write over a Zero-Copied payload." + ) + + dest = self._start_address + dest_offset + if self._pending_count == 0: + self._done_event.clear() + self._pending_count += 1 + + except BaseException: + self.semaphore.release() + raise + + try: + if size <= self.THRESHOLD_BYTES_FOR_SCHEDULING: + # Fast path, no need to send it to executor + self._do_memmove(dest, data_bytes, size) + fut = concurrent.futures.Future() + with self._lock: + local_err = self._error + if local_err: + fut.set_exception(local_err) + else: + fut.set_result(None) + return fut + else: + # Slow path, schedule it on executor. + return self.executor.submit(self._do_memmove, dest, data_bytes, size) except BaseException as e: - self._error = e + with self._lock: + self._error = e self._decrement_pending() raise e def _do_memmove(self, dest, data_bytes, size): try: + with self._lock: + if self._error: + return + ctypes.memmove(dest, data_bytes, size) - except Exception as e: - self._error = e - raise e + with self._lock: + self._total_bytes_written += size + + except BaseException as e: + with self._lock: + if self._error is None: + self._error = e + raise finally: self._decrement_pending() + def get_value(self): + with self._lock: + if self._error: + raise self._error + if not self._is_closed: + raise RuntimeError("Buffer is still not closed yet!") + if self._result_bytes is None and self.expected_size == 0: + return b"" + if self._total_bytes_written < self.expected_size: + raise BufferError( + f"Buffer incomplete: Expected {self.expected_size} bytes but " + f"only populated {self._total_bytes_written}. Returning this " + f"payload would leak uninitialized memory." + ) + return self._result_bytes + def close(self): """ - Waits for all pending write operations to complete and checks for errors. - Blocks the calling thread until the queue of memory operations is entirely - processed. - - Raises: - Exception: If any background write operation failed during execution. - BufferError: If the buffer was not filled to the expected capacity. + Locks the buffer preventing further incoming writes, waits for all pending + write operations to complete, and checks for errors. """ + with self._lock: + self._stop_accepting_writes = True + self._done_event.wait() - if self._error: - raise self._error - - expected_size = self.end_address - self.start_address - if self.current_offset < expected_size: - error_msg = ( - f"Expected {expected_size} bytes, " - f"but only received {self.current_offset} bytes. " - f"Buffer contains uninitialized data." - ) - raise BufferError(error_msg) + with self._lock: + self._is_closed = True + if self._error: + raise self._error class MRDPool: From 8c8f2a5737d5d810025aaae16cd51a3efc27c0e0 Mon Sep 17 00:00:00 2001 From: Margubur rahman Date: Tue, 12 May 2026 09:32:56 +0000 Subject: [PATCH 2/2] move fast path out of try-catch --- gcsfs/zb_hns_utils.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/gcsfs/zb_hns_utils.py b/gcsfs/zb_hns_utils.py index f6176e64..f40d9828 100644 --- a/gcsfs/zb_hns_utils.py +++ b/gcsfs/zb_hns_utils.py @@ -350,26 +350,31 @@ def _submit_write(self, dest_offset, data_bytes, size): self.semaphore.release() raise - try: - if size <= self.THRESHOLD_BYTES_FOR_SCHEDULING: - # Fast path, no need to send it to executor + if size <= self.THRESHOLD_BYTES_FOR_SCHEDULING: + # Fast path, no need to send it to executor + try: self._do_memmove(dest, data_bytes, size) - fut = concurrent.futures.Future() - with self._lock: - local_err = self._error - if local_err: - fut.set_exception(local_err) - else: - fut.set_result(None) - return fut + except BaseException: + # The exception is already captured in self._error by _do_memmove + pass + + fut = concurrent.futures.Future() + with self._lock: + local_err = self._error + if local_err: + fut.set_exception(local_err) else: + fut.set_result(None) + return fut + else: + try: # Slow path, schedule it on executor. return self.executor.submit(self._do_memmove, dest, data_bytes, size) - except BaseException as e: - with self._lock: - self._error = e - self._decrement_pending() - raise e + except BaseException as e: + with self._lock: + self._error = e + self._decrement_pending() + raise e def _do_memmove(self, dest, data_bytes, size): try: