Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 27 additions & 36 deletions gcsfs/extended_gcsfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,7 @@
from gcsfs import zb_hns_utils
from gcsfs.core import GCSFile, GCSFileSystem
from gcsfs.retry import DEFAULT_RETRY_CONFIG, get_storage_control_retry_config
from gcsfs.zb_hns_utils import (
DirectMemmoveBuffer,
MRDPool,
PyBytes_AsString,
PyBytes_FromStringAndSize,
)
from gcsfs.zb_hns_utils import DirectMemmoveBuffer, MRDPool
from gcsfs.zonal_file import ZonalFile

logger = logging.getLogger("gcsfs")
Expand Down Expand Up @@ -411,39 +406,34 @@ async def _fetch_range_split(
await mrd.close()

async def _concurrent_mrd_fetch(self, offset, length, concurrency, mrd_or_pool):
"""Helper to handle concurrent chunk downloads into a DirectMemmoveBuffer."""
"""Helper to handle concurrent chunk downloads cleanly."""
concurrency = (
concurrency if length >= self.MIN_CHUNK_SIZE_FOR_CONCURRENCY else 1
)
result_bytes = PyBytes_FromStringAndSize(None, length)
buffer_ptr = PyBytes_AsString(result_bytes)

part_size = length // concurrency
tasks = []
buffers = []
loop = asyncio.get_running_loop()

# Track if the core download process failed
tasks = []
views = []
has_error = False

async def _download(o, s, b, mrd_or_pool):
# The master buffer manages its own allocation under the hood
master_buffer = DirectMemmoveBuffer(length, self._memmove_executor)

async def _download(o, s, view, mrd_or_pool):
async with _get_mrd_from_pool_or_mrd(mrd_or_pool) as m_client:
await m_client.download_ranges([(o, s, b)])
await m_client.download_ranges([(o, s, view)])

for i in range(concurrency):
part_offset = offset + (i * part_size)
actual_size = part_size if i < concurrency - 1 else length - (i * part_size)

part_address = buffer_ptr + (part_offset - offset)
buf = DirectMemmoveBuffer(
part_address,
part_address + actual_size,
self._memmove_executor,
)
buffers.append(buf)
# Give each task a restricted view of the master buffer
view = master_buffer.get_view(part_offset - offset, actual_size)
views.append(view)

tasks.append(
asyncio.create_task(
_download(part_offset, actual_size, buf, mrd_or_pool)
_download(part_offset, actual_size, view, mrd_or_pool)
)
)

Expand All @@ -453,6 +443,8 @@ async def _download(o, s, b, mrd_or_pool):
if isinstance(res, Exception):
has_error = True
raise res
for view in views:
view.close()
except BaseException:
has_error = True
for t in tasks:
Expand All @@ -461,18 +453,17 @@ async def _download(o, s, b, mrd_or_pool):
await asyncio.gather(*tasks, return_exceptions=True)
raise
finally:
for buf in buffers:
try:
await loop.run_in_executor(None, buf.close)
except BufferError:
# If we are already handling a network/download exception,
# ignore the BufferError (which is just a symptom of the drop).
# If there's no download error, this means the buffer logic
# itself failed, so we must surface the error.
if not has_error:
raise

return result_bytes
try:
master_buffer.close()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Large reads can block the asyncio event loop while waiting for memmove: master_buffer.close() is called synchronously inside the finally block of _concurrent_mrd_fetch, and because this coroutine runs directly on the main event-loop thread, nothing else on the loop can make progress until close() returns.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was actually an intentional change, and I want to block the async event loop in this specific case. I know it seems like an anti-pattern, but it is required for performance in these heavily multithreaded scenarios.

When testing with 32 coroutines, I found that calling master_buffer.close() synchronously was about 500 MB/s faster than using await loop.run_in_executor(None, master_buffer.close) with the SDK. Because the system is already saturated with ctypes.memmove operations, the overhead of offloading the close operation to an executor becomes too high.

If you look at the previous implementation, the initial write operation using io.BytesIO was synchronous and consumed the GIL. We were already blocking the event loop to write data into the buffer. With the new approach, we are simply blocking the event loop to wait for a thread to finish writing that data. The end result on the event loop is effectively the same, but with much better throughput.

I'm open to suggestions, so feel free to let me know if there is a better way to do this :) The only constraint is that we need a synchronous `def write` method on the buffer (similar to the io.BytesIO interface) that can be called from an async method.

Regarding the concern about large reads, I do not think it will cause problems. Downloading large blocks of data over the network will inherently take significantly longer than writing that data into the buffer.

I have included a script below for your reference. If you run it with master.close(), you will see it hits over 2 GB/s. If you switch to await loop.run_in_executor(None, master.close), it drops to around 1.4 GB/s on average.

Script
import asyncio
import argparse
import logging
import os
import random
import time
import warnings
import ctypes
import threading
import concurrent.futures
from google.cloud.storage.asyncio.async_grpc_client import AsyncGrpcClient
from google.cloud.storage.asyncio.async_multi_range_downloader import AsyncMultiRangeDownloader 

# Restrict gRPC's own log output to errors; set before any gRPC work starts.
os.environ.update({"GRPC_VERBOSITY": "ERROR"})

# Hide the FutureWarning emitted by google.api_core's Python-version check.
warnings.filterwarnings(
    action="ignore",
    category=FutureWarning,
    module="google.api_core._python_version_support",
)

# Benchmark target; both names are filled in by main() from --type.
BUCKET_NAME = None
OBJECT_NAME = None

# Size of the object being read: 20 GiB, in bytes.
TOTAL_OBJECT_SIZE = 20 * 1024**3

# Number of measured runs performed after the warmup run.
NUM_RUNS = 5


# ctypes bindings to the CPython C-API. They are used to allocate a ``bytes``
# object of a given size without initialising it, and to obtain the address
# of its internal character buffer so memmove can fill it in place.
PyBytes_FromStringAndSize = ctypes.pythonapi.PyBytes_FromStringAndSize
PyBytes_FromStringAndSize.argtypes = (ctypes.c_void_p, ctypes.c_ssize_t)
PyBytes_FromStringAndSize.restype = ctypes.py_object

PyBytes_AsString = ctypes.pythonapi.PyBytes_AsString
PyBytes_AsString.argtypes = (ctypes.py_object,)
PyBytes_AsString.restype = ctypes.c_void_p


class DirectMemmoveBuffer:
    """
    A buffer-like object that assembles a ``bytes`` result via ``ctypes.memmove``.

    Writers obtain non-overlapping :class:`PartialView` objects through
    :meth:`get_view` and call ``view.write(data)``. Small writes are copied
    synchronously; writes larger than ``THRESHOLD_BYTES_FOR_SCHEDULING`` are
    submitted to the supplied executor. At most ``max_pending`` writes may be
    outstanding: a further ``write()`` blocks the calling thread on a
    semaphore until a slot frees up.

    The destination ``bytes`` object is allocated lazily on the first write.
    If that first write starts at offset 0 and covers the whole expected
    size, the written ``bytes`` object itself becomes the result and no copy
    is made.

    Shared state is guarded by an internal ``threading.Lock``.
    """

    # Writes of at most this many bytes are copied inline instead of being
    # handed to the executor.
    THRESHOLD_BYTES_FOR_SCHEDULING = 128 * 1024

    def __init__(self, expected_size, executor, max_pending=5):
        """
        Initializes the DirectMemmoveBuffer.

        Args:
            expected_size (int): The total amount of bytes expected to populate memory.
            executor (concurrent.futures.Executor): The thread pool executor to run the
                memmove operations. The lifecycle of this executor is managed by the caller.
            max_pending (int, optional): The maximum number of pending write operations
                allowed in the queue. Defaults to 5.
        """
        self.expected_size = expected_size
        self.executor = executor

        self._pending_count = 0
        self._error = None
        self._total_bytes_written = 0
        self._stop_accepting_writes = False
        self._is_closed = False

        self._allocated_intervals = []
        self._result_bytes = None
        self._start_address = None
        self.semaphore = threading.Semaphore(max_pending)
        self._lock = threading.Lock()
        # Set whenever no write is in flight; close() waits on it.
        self._done_event = threading.Event()
        self._done_event.set()

    class PartialView:
        """A bounded, append-only writer over one region of the parent buffer."""

        def __init__(self, parent, start_offset, expected_size):
            self.parent = parent
            self.start_offset = start_offset
            self.expected_size = expected_size
            self.current_offset = 0
            self._view_lock = threading.Lock()

        def write(self, data):
            """
            Appends ``data`` to this view and returns a future for the copy.

            Raises:
                ValueError: If ``data`` is not a ``bytes`` object.
                BufferError: If the write would exceed the view's size.
            """
            if not isinstance(data, bytes):
                raise ValueError(f"Expected bytes, but got {type(data)}")
            size = len(data)
            with self._view_lock:
                if self.current_offset + size > self.expected_size:
                    error_msg = (
                        f"Attempted to write {size} bytes "
                        f"at offset {self.current_offset}. "
                        f"Max capacity is {self.expected_size} bytes."
                    )
                    raise BufferError(error_msg)

                # Reserve the region before releasing the lock so concurrent
                # writers to the same view never target the same bytes.
                abs_offset = self.start_offset + self.current_offset
                self.current_offset += size

            return self.parent._submit_write(abs_offset, data, size)

        def close(self):
            """
            Verifies that the view was completely filled.

            Raises:
                BufferError: If fewer bytes than expected were written.
            """
            with self._view_lock:
                if self.current_offset < self.expected_size:
                    error_msg = (
                        f"Expected {self.expected_size} bytes, "
                        f"but only received {self.current_offset} bytes. "
                        f"Buffer contains uninitialized data."
                    )
                    raise BufferError(error_msg)

    def get_view(self, offset, size):
        """
        Returns a :class:`PartialView` covering ``[offset, offset + size)``.

        Raises:
            ValueError: If the range is negative, falls outside the buffer,
                overlaps a previously requested view, or the buffer is
                closed/closing.
        """
        # A negative size would otherwise slip past the upper-bound check.
        if offset < 0 or size < 0 or offset + size > self.expected_size:
            raise ValueError("Invalid view requested: exceeds physical boundaries!")

        start = offset
        end = offset + size

        with self._lock:
            if self._stop_accepting_writes or self._is_closed:
                raise ValueError("Cannot get view on a closed/closing buffer.")

            for a_start, a_end in self._allocated_intervals:
                if max(start, a_start) < min(end, a_end):
                    raise ValueError(
                        f"Overlapping view requested: [{start}, {end}) "
                        f"overlaps with already allocated view [{a_start}, {a_end})"
                    )

            self._allocated_intervals.append((start, end))

        return self.PartialView(self, offset, size)

    def _decrement_pending(self):
        """Releases one write slot and signals completion when none remain."""
        self.semaphore.release()
        with self._lock:
            self._pending_count -= 1
            if self._pending_count == 0:
                self._done_event.set()

    def _submit_write(self, dest_offset, data_bytes, size):
        """
        Copies ``data_bytes`` to ``dest_offset``, inline or via the executor.

        Blocks until one of the ``max_pending`` slots is available. Returns a
        ``concurrent.futures.Future`` that completes when the copy is done.
        """
        self.semaphore.acquire()

        try:
            with self._lock:
                if self._stop_accepting_writes or self._is_closed:
                    raise ValueError("I/O operation on closed buffer.")

                if self._error:
                    raise self._error

                if self._result_bytes is None:
                    if dest_offset == 0 and size == self.expected_size:
                        # Zero-copy: the payload already is the full result.
                        self._result_bytes = data_bytes
                        self.semaphore.release()  # Release because we skip the executor
                        fut = concurrent.futures.Future()
                        fut.set_result(None)
                        self._total_bytes_written += size
                        return fut
                    else:
                        self._result_bytes = PyBytes_FromStringAndSize(
                            None, self.expected_size
                        )
                        self._start_address = PyBytes_AsString(self._result_bytes)

                # Defensive programming: gracefully catch internal overwrite attempts
                if self._start_address is None:
                    raise BufferError(
                        "Attempted to execute standard write over a Zero-Copied payload."
                    )

                dest = self._start_address + dest_offset
                if self._pending_count == 0:
                    self._done_event.clear()
                self._pending_count += 1

        except BaseException:
            self.semaphore.release()
            raise

        if size <= self.THRESHOLD_BYTES_FOR_SCHEDULING:
            # Fast path, no need to send it to executor. _do_memmove records
            # any failure and releases the pending slot in its own ``finally``,
            # so there must be no second rollback here: decrementing twice
            # would over-release the semaphore and push _pending_count
            # negative, after which close() would stop waiting for writes.
            self._do_memmove(dest, data_bytes, size)
            fut = concurrent.futures.Future()
            with self._lock:
                local_err = self._error
            if local_err:
                fut.set_exception(local_err)
            else:
                fut.set_result(None)
            return fut

        try:
            # Slow path, schedule it on executor.
            return self.executor.submit(self._do_memmove, dest, data_bytes, size)
        except BaseException as e:
            # submit() itself failed, so _do_memmove never ran and the slot
            # reserved above has to be rolled back here.
            with self._lock:
                self._error = e
            self._decrement_pending()
            raise

    def _do_memmove(self, dest, data_bytes, size):
        """
        Performs one copy and always releases its pending slot afterwards.

        Returns without copying if an error has already been recorded. The
        first exception raised by a copy is stored in ``self._error``.
        """
        try:
            with self._lock:
                if self._error:
                    return

            ctypes.memmove(dest, data_bytes, size)
            with self._lock:
                self._total_bytes_written += size

        except BaseException as e:
            with self._lock:
                if self._error is None:
                    self._error = e
            raise
        finally:
            self._decrement_pending()

    def get_value(self):
        """
        Returns the assembled ``bytes`` object.

        Raises:
            Exception: The stored error, if any write failed.
            RuntimeError: If :meth:`close` has not completed yet.
            BufferError: If fewer than ``expected_size`` bytes were written.
        """
        with self._lock:
            if self._error:
                raise self._error
            if not self._is_closed:
                raise RuntimeError("Buffer is still not closed yet!")
            if self._result_bytes is None and self.expected_size == 0:
                return b""
            if self._total_bytes_written < self.expected_size:
                raise BufferError(
                    f"Buffer incomplete: Expected {self.expected_size} bytes but "
                    f"only populated {self._total_bytes_written}. Returning this "
                    f"payload would leak uninitialized memory."
                )
            return self._result_bytes

    def close(self):
        """
        Locks the buffer preventing further incoming writes, waits for all pending
        write operations to complete, and checks for errors.

        This call blocks the calling thread until the pending writes finish.
        """
        with self._lock:
            self._stop_accepting_writes = True

        self._done_event.wait()
        with self._lock:
            self._is_closed = True
            if self._error:
                raise self._error


async def _close_client(client):
    """Close ``client`` if it has a ``close`` method, awaiting it when it is a coroutine function."""
    if hasattr(client, 'close'):
        if asyncio.iscoroutinefunction(client.close):
            await client.close()
        else:
            client.close()


async def async_worker_task(read_size, mode, warmup_secs, run_secs):
    """
    An individual task that loops and fetches data for a specific duration.

    Each task gets its own gRPC client, its own mrd and its own thread
    executor. In "sequential" mode reads start at offset 0 and wrap around at
    TOTAL_OBJECT_SIZE; in any other mode every read uses a random offset.

    Args:
        read_size: Number of bytes requested per read.
        mode: "sequential" or "random".
        warmup_secs: Seconds at the start during which bytes are not counted.
        run_secs: Seconds of measured reading after the warmup.

    Returns:
        The number of bytes read after the warmup period ended.
    """
    client = AsyncGrpcClient()
    try:
        mrd = await AsyncMultiRangeDownloader.create_mrd(
            client,
            bucket_name=BUCKET_NAME,
            object_name=OBJECT_NAME
        )
        try:
            file_offset = 0
            recorded_bytes = 0

            loop_start = time.monotonic()
            warmup_end = loop_start + warmup_secs
            test_end = warmup_end + run_secs

            loop = asyncio.get_running_loop()
            with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                while True:
                    if time.monotonic() >= test_end:
                        break

                    xread = read_size

                    if mode == "sequential":
                        # Truncate the last read so it stops at the object end.
                        if file_offset + xread > TOTAL_OBJECT_SIZE:
                            xread = TOTAL_OBJECT_SIZE - file_offset
                    else:
                        max_offset = max(0, TOTAL_OBJECT_SIZE - read_size)
                        file_offset = random.randint(0, int(max_offset))
                        xread = read_size

                    buffer = DirectMemmoveBuffer(
                        expected_size=xread,
                        executor=executor
                    )

                    view = buffer.get_view(0, xread)
                    await mrd.download_ranges([(file_offset, xread, view)])
                    view.close()
                    # Wait for pending copies on the default executor so the
                    # event loop is not blocked by buffer.close().
                    await loop.run_in_executor(None, buffer.close)

                    if time.monotonic() >= warmup_end:
                        recorded_bytes += xread

                    if mode == "sequential":
                        file_offset += xread
                        if file_offset >= TOTAL_OBJECT_SIZE:
                            file_offset = 0
        finally:
            # Close the downloader even when a read failed or was cancelled.
            await mrd.close()
    finally:
        # Likewise never leak the client, including when create_mrd fails.
        await _close_client(client)
    return recorded_bytes


async def run_benchmark_step(current_read_size, mode, coro_count, warmup_secs, run_secs):
    """
    Run one test phase with `coro_count` concurrent worker tasks.

    Every worker is started as its own asyncio task and all of them are
    awaited together. The bytes they report are summed and converted to
    GiB per second over `run_secs`.

    Returns:
        The aggregate throughput in GB/s, or 0 when `run_secs` is not positive.
    """
    workers = [
        asyncio.create_task(
            async_worker_task(current_read_size, mode, warmup_secs, run_secs)
        )
        for _ in range(coro_count)
    ]
    bytes_per_worker = await asyncio.gather(*workers)

    gigabytes = sum(bytes_per_worker) / (1024**3)
    if run_secs > 0:
        return gigabytes / run_secs
    return 0


def async_main(args, chunk_size_bytes):
    """
    Run one ignored warmup pass followed by NUM_RUNS measured passes.

    Each pass runs in a fresh event loop via asyncio.run. Progress and the
    throughput of every pass are printed as they complete.

    Args:
        args: Parsed command-line arguments (uses mb, mode, coro_count,
            warmup and time).
        chunk_size_bytes: Size of each read, in bytes.

    Returns:
        A list with the throughput (GB/s) of each measured pass.
    """

    def _run_pass(duration_secs):
        # Every pass counts bytes from its first read (warmup_secs=0).
        return asyncio.run(
            run_benchmark_step(
                chunk_size_bytes, args.mode, args.coro_count, 0, duration_secs
            )
        )

    print(f"Testing: {args.mb} MB chunks")

    print("   Warmup run (establishing TCP connections)...", end="", flush=True)
    ignored_throughput = _run_pass(args.warmup)
    print(f" Done. (Ignored: {ignored_throughput:.2f} GB/s)")

    measured = []
    for completed in range(NUM_RUNS):
        print(f"   Run {completed + 1}/{NUM_RUNS} (Benchmarking)...", end="", flush=True)
        gb_per_sec = _run_pass(args.time)
        measured.append(gb_per_sec)
        print(f" Done. ({gb_per_sec:.2f} GB/s)")

    return measured


def main():
    """
    Command-line entry point for the benchmark.

    Parses the arguments, selects the bucket/object for the chosen bucket
    type, tries to pin the process to CPUs 20-179, runs the benchmark and
    prints a one-row results table with the best measured throughput.
    """
    parser = argparse.ArgumentParser(description="AsyncMultiRangeDownloader Benchmark")
    parser.add_argument(
        "--mode",
        choices=["sequential", "random"],
        default="sequential",
        help="Choose the read pattern: sequential (default) or random"
    )
    parser.add_argument(
        "--mb",
        type=int,
        required=True,
        help="Size of each chunk read in MB (e.g., 1 for 1MB IO size)"
    )
    parser.add_argument(
        "--type",
        choices=["zonal", "regional"],
        default="zonal",
        help="Choose the bucket type"
    )
    parser.add_argument(
        "--coro-count",
        type=int,
        default=4,
        help="Number of concurrent asyncio tasks issuing read ranges"
    )
    parser.add_argument(
        "--time",
        type=int,
        default=30,
        help="Duration of the recorded benchmark in seconds (default: 30)"
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=5,
        help="Duration of the warmup period in seconds before counting bytes (default: 5)"
    )

    args = parser.parse_args()

    global BUCKET_NAME, OBJECT_NAME
    if args.type == 'zonal':
        BUCKET_NAME = "do-not-delete-margubur-zonal-usc1"
        OBJECT_NAME = "20gb-file"
    else:
        BUCKET_NAME = "rahman-bucket"
        OBJECT_NAME = "100gb-file"

    # CPU pinning is best-effort. AttributeError means the platform has no
    # sched_setaffinity; OSError is raised when none of the requested CPUs
    # exist on this machine. Neither should stop the benchmark from running.
    try:
        os.sched_setaffinity(0, set(range(20, 180)))
    except (AttributeError, OSError):
        pass

    chunk_size_bytes = args.mb * 1024 * 1024

    print("Starting AsyncMultiRangeDownloader (gRPC Bidi-Stream) Benchmark")
    print(f"Read Mode: {args.mode.upper()}")
    print(f"Config: {args.time}s test time + {args.warmup}s network warmup | {NUM_RUNS} Runs | {args.coro_count} Tasks")
    print("-" * 80)

    run_values = async_main(args, chunk_size_bytes)

    best_result = max(run_values)
    print(f" >> Best for this config: {best_result:.2f} GB/s\n")

    print("\n" + "=" * 80)
    print(f"FINAL RESULTS TABLE ({args.mode.upper()} - Best of {NUM_RUNS} Runs)")
    print("=" * 80)

    header = f"{'CHUNK SIZE':<15} | {'ASYNC TASKS':<15} | THROUGHPUT"
    print(header)
    print("-" * 80)

    row = f"{args.mb}MB".ljust(15) + f" | {args.coro_count} TASKS".ljust(18) + f" | {best_result:.2f} GB/s"
    print(row)
    print("=" * 80)

# Script entry point: limit logging to errors, then run the benchmark CLI.
if __name__ == "__main__":
    logging.basicConfig(level=logging.ERROR)
    main()
Numbers with loop.run_in_executor(None, master.close)
(temp) margubur@grpc-team-test:~/scripts$ python3 bench-direct.py --coro-count=32 --mb=1
Starting AsyncMultiRangeDownloader (gRPC Bidi-Stream) Benchmark
Read Mode: SEQUENTIAL
Config: 30s test time + 5s network warmup | 5 Runs | 32 Tasks
--------------------------------------------------------------------------------
Testing: 1 MB chunks
   Warmup run (establishing TCP connections)... Done. (Ignored: 1.26 GB/s)
   Run 1/5 (Benchmarking)... Done. (1.23 GB/s)
   Run 2/5 (Benchmarking)... Done. (1.68 GB/s)
   Run 3/5 (Benchmarking)... Done. (1.26 GB/s)
   Run 4/5 (Benchmarking)... Done. (1.43 GB/s)
   Run 5/5 (Benchmarking)... Done. (1.25 GB/s)
 >> Best for this config: 1.68 GB/s
Numbers for master.close
(temp) margubur@grpc-team-test:~/scripts$ python3 bench-direct.py --coro-count=32 --mb=1
Starting AsyncMultiRangeDownloader (gRPC Bidi-Stream) Benchmark
Read Mode: SEQUENTIAL
Config: 30s test time + 5s network warmup | 5 Runs | 32 Tasks
--------------------------------------------------------------------------------
Testing: 1 MB chunks
   Warmup run (establishing TCP connections)... Done. (Ignored: 2.11 GB/s)
   Run 1/5 (Benchmarking)... Done. (1.97 GB/s)
   Run 2/5 (Benchmarking)... Done. (2.01 GB/s)
   Run 3/5 (Benchmarking)... Done. (2.05 GB/s)
   Run 4/5 (Benchmarking)... Done. (2.11 GB/s)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zhixiangli : are you satisfied?

except Exception:
# If we are already handling a network/download exception,
# ignore the exception from buffer (which is just a symptom of the drop).
# If there's no download error, this means the buffer logic
# itself failed, so we must surface the error.
if not has_error:
raise
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are multiple errors possible?


return master_buffer.get_value()

async def _cat_file(
self,
Expand Down
255 changes: 150 additions & 105 deletions gcsfs/tests/test_zb_hns_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import concurrent.futures
import ctypes
import logging
from unittest import mock

Expand Down Expand Up @@ -380,104 +379,6 @@ async def test_mrd_pool_close_with_exceptions(create_mrd_mock, mock_gcsfs):
assert len(pool._all_mrds) == 0


@mock.patch("gcsfs.zb_hns_utils.ctypes.memmove")
def test_direct_memmove_buffer_error_handling(mock_memmove):
size = 20
buffer_array = (ctypes.c_char * size)()
start_address = ctypes.addressof(buffer_array)
end_address = start_address + size

# Simulate an access violation or similar error during memory copy
mock_memmove.side_effect = MemoryError("Segfault simulated")

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2)

# First write triggers the background error
future = buf.write(b"bad data")

# Wait for the background thread to actually fail
with pytest.raises(MemoryError):
future.result()

# Subsequent writes should raise the stored error immediately
with pytest.raises(MemoryError, match="Segfault simulated"):
buf.write(b"more data")

# Close should also raise the stored error.
with pytest.raises(MemoryError, match="Segfault simulated"):
buf.close()

executor.shutdown()


def test_direct_memmove_buffer():
data1 = b"hello"
data2 = b"world"

# Calculate exact size to prevent the new underflow check from failing
size = len(data1) + len(data2)
buffer_array = (ctypes.c_char * size)()
start_address = ctypes.addressof(buffer_array)
end_address = start_address + size

executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2)

future1 = buf.write(data1)
future2 = buf.write(data2)

future1.result()
future2.result()
buf.close()

result_bytes = ctypes.string_at(start_address, len(data1) + len(data2))
assert result_bytes == b"helloworld"

executor.shutdown()


def test_direct_memmove_buffer_overflow():
"""Tests that writing past the allocated end_address raises a BufferError."""
size = 10
buffer_array = (ctypes.c_char * size)()
start_address = ctypes.addressof(buffer_array)
end_address = start_address + size

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2)

# Fill the buffer exactly to capacity
buf.write(b"1234567890")

# Attempting to write even 1 more byte should trigger the overflow protection
with pytest.raises(BufferError, match="Attempted to write"):
buf.write(b"1")

buf.close()
executor.shutdown()


def test_direct_memmove_buffer_underflow():
"""Tests that closing an incompletely filled buffer raises a BufferError."""
size = 10
buffer_array = (ctypes.c_char * size)()
start_address = ctypes.addressof(buffer_array)
end_address = start_address + size

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2)

# Write fewer bytes than the expected capacity
buf.write(b"12345")

# Closing should detect that current_offset (5) < expected size (10)
with pytest.raises(BufferError, match="Buffer contains uninitialized data"):
buf.close()

executor.shutdown()


@pytest.mark.asyncio
async def test_mrd_pool_queue_filled_during_lock_wait(mock_gcsfs):
pool = MRDPool(mock_gcsfs, "bucket", "obj", "123", pool_size=1)
Expand Down Expand Up @@ -547,28 +448,129 @@ async def fake_create_mrd():
assert pool._rr_index == 1


@mock.patch("gcsfs.zb_hns_utils.ctypes.memmove")
def test_direct_memmove_buffer_error_handling(mock_memmove):
# Use a size > 128KB to trigger the executor background path
size = 130 * 1024 + 10
data1 = b"a" * (130 * 1024)
data2 = b"b" * 10

# Simulate an access violation or similar error during memory copy
mock_memmove.side_effect = MemoryError("Segfault simulated")

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
buf = DirectMemmoveBuffer(size, executor, max_pending=2)
view = buf.get_view(0, size)

# First write triggers the background error (slow path)
future = view.write(data1)

# Wait for the background thread to actually fail
with pytest.raises(MemoryError):
future.result()

# Subsequent writes should raise the stored error immediately
with pytest.raises(MemoryError, match="Segfault simulated"):
view.write(data2)

# Close should also raise the stored error.
with pytest.raises(MemoryError, match="Segfault simulated"):
buf.close()

executor.shutdown()


def test_direct_memmove_buffer():
data1 = b"hello"
data2 = b"world"
size = len(data1) + len(data2)

executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
buf = DirectMemmoveBuffer(size, executor, max_pending=2)
view = buf.get_view(0, size)

future1 = view.write(data1)
future2 = view.write(data2)

future1.result()
future2.result()

view.close()
buf.close()

result_bytes = buf.get_value()
assert result_bytes == b"helloworld"

executor.shutdown()


def test_direct_memmove_buffer_overflow():
"""Tests that writing past the view boundaries raises a BufferError."""
size = 10
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
buf = DirectMemmoveBuffer(size, executor, max_pending=2)
view = buf.get_view(0, size)

# Fill the buffer exactly to capacity
view.write(b"1234567890")

# Attempting to write even 1 more byte should trigger the overflow protection
with pytest.raises(BufferError, match="Attempted to write"):
view.write(b"1")

view.close()
buf.close()
executor.shutdown()


def test_direct_memmove_buffer_underflow():
"""Tests that closing an incompletely filled view/buffer raises a BufferError."""
size = 10
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
buf = DirectMemmoveBuffer(size, executor, max_pending=2)
view = buf.get_view(0, size)

# Write fewer bytes than the expected capacity
view.write(b"12345")

# Closing the view should detect that current_offset (5) < expected size (10)
with pytest.raises(BufferError, match="Buffer contains uninitialized data"):
view.close()

# Calling get_value after an incompletely filled buffer should also error
buf.close()
with pytest.raises(BufferError, match="Buffer incomplete"):
buf.get_value()

executor.shutdown()


@mock.patch("gcsfs.zb_hns_utils.ctypes.memmove")
def test_direct_memmove_buffer_submit_failure(mock_memmove):
"""
Tests that if executor.submit fails synchronously (e.g., executor is closed),
the internal locks, semaphores, and events are properly reset, and close()
does not hang.
"""
size = 10
buffer_array = (ctypes.c_char * size)()
start_address = ctypes.addressof(buffer_array)
end_address = start_address + size
# 1. Chunk > 128KB to force executor scheduling (skip the synchronous fast path)
chunk_size = 130 * 1024

# 2. Expected size > chunk_size to skip the Zero-Copy optimization
expected_size = 140 * 1024

data = b"a" * chunk_size

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
buf = DirectMemmoveBuffer(start_address, end_address, executor, max_pending=2)
buf = DirectMemmoveBuffer(expected_size, executor, max_pending=2)
view = buf.get_view(0, expected_size)

# Mock the submit method to simulate a closed executor throwing a RuntimeError
with mock.patch.object(
executor, "submit", side_effect=RuntimeError("Executor closed")
):
# The write operation should raise the simulated RuntimeError
with pytest.raises(RuntimeError, match="Executor closed"):
buf.write(b"12345")
view.write(data)

# Verify that the internal tracking state was correctly rolled back
assert buf._pending_count == 0
Expand All @@ -579,3 +581,46 @@ def test_direct_memmove_buffer_submit_failure(mock_memmove):
buf.close()

executor.shutdown()


def test_direct_memmove_buffer_zero_copy():
"""Tests that a perfect aligned single payload avoids memory allocation completely."""
data = b"exact_size_payload"
size = len(data)

executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
buf = DirectMemmoveBuffer(size, executor, max_pending=2)
view = buf.get_view(0, size)

# Writing a single payload identical to the expected size
future = view.write(data)
future.result()

view.close()
buf.close()

# Should be the EXACT same string object returned without copying
result = buf.get_value()
assert result is data

executor.shutdown()


def test_direct_memmove_buffer_overlapping_views():
"""Tests that getting overlapping views raises a ValueError."""
size = 100
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
buf = DirectMemmoveBuffer(size, executor, max_pending=2)

# Get a view for the first half
_ = buf.get_view(0, 50)

# Attempting to get an overlapping view should fail
with pytest.raises(ValueError, match="Overlapping view requested"):
_ = buf.get_view(25, 50)

# Getting a view for the second half should succeed
_ = buf.get_view(50, 50)

buf.close()
executor.shutdown()
Loading
Loading