Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gcsfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def project(self):
# cleanup (which can handle cross-thread calls).
@staticmethod
def close_session(loop, session: aiohttp.ClientSession, asynchronous=False):
if session.closed:
if session is None or session.closed:
return
force_close = False
try:
Expand Down
123 changes: 86 additions & 37 deletions gcsfs/extended_gcsfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from enum import Enum
from glob import has_magic

import fsspec
Comment thread
Yonghui-Lee marked this conversation as resolved.
from fsspec import asyn
from fsspec.callbacks import NoOpCallback
from google.api_core import exceptions as api_exceptions
Expand Down Expand Up @@ -89,7 +90,9 @@ class ExtendedGcsFileSystem(GCSFileSystem):
to the parent class GCSFileSystem for default processing.
"""

def __init__(self, *args, finalize_on_close=False, **kwargs):
def __init__(
self, *args, finalize_on_close=False, mrd_pool_cache_size=16, **kwargs
):
"""
Parameters
----------
Expand Down Expand Up @@ -130,6 +133,36 @@ def __init__(self, *args, finalize_on_close=False, **kwargs):
max_workers=kwargs.get("memmove_max_workers", 8)
)
weakref.finalize(self, self._memmove_executor.shutdown)
self._mrd_pool_cache = zb_hns_utils.MRDPoolCache(
self, max_idle_pools=mrd_pool_cache_size
)
weakref.finalize(
self,
self._finalize_mrd_pool_cache,
self.loop,
self._mrd_pool_cache,
)

@staticmethod
def _finalize_mrd_pool_cache(loop, cache):
    """Best-effort teardown of the MRDPoolCache at garbage-collection time.

    Parameters
    ----------
    loop:
        Event loop captured when the finalizer was registered; may be
        ``None`` or no longer running.
    cache:
        The pool cache to close. Nothing happens when it is ``None`` or
        already reports ``_closed``.
    """
    already_closed = cache is None or getattr(cache, "_closed", False)
    if already_closed:
        return

    try:
        running = asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in the calling thread.
        running = None

    # Prefer the loop handed to the finalizer, then the loop running in this
    # thread; in both cases the close is only scheduled, never awaited.
    target = None
    if loop and loop.is_running():
        target = loop
    elif running is not None and running.is_running():
        target = running

    if target is not None:
        asyncio.run_coroutine_threadsafe(cache.close(), target)
        return

    # Last resort: fsspec's shared loop, with a short timeout so the
    # finalizer cannot block for long.
    shared = asyn.loop[0]
    if shared is None or not shared.is_running():
        return
    try:
        asyn.sync(shared, cache.close, timeout=0.1)
    except fsspec.FSTimeoutError:
        pass

@property
def _user_project(self):
Expand Down Expand Up @@ -190,6 +223,26 @@ async def _get_control_plane_client(self):
)
return self._storage_control_client

async def close_resources(self):
    """Close gRPC clients, channels, and other resources.

    Teardown is best-effort: a failure while closing one resource is logged
    as a warning and does not stop the remaining resources from being closed.
    The gRPC client and storage control client references are reset to
    ``None`` after their close attempt, whether or not it succeeded.
    """
    # Close the pooled downloaders first. They are presumably streams opened
    # on the gRPC client, so they should be shut down while its transport is
    # still open rather than after the channel underneath them is gone.
    if self._mrd_pool_cache is not None:
        try:
            await self._mrd_pool_cache.close()
        except Exception as e:
            logger.warning(f"Failed to close MRDPoolCache: {e}")
    if self._grpc_client is not None:
        try:
            await self._grpc_client.grpc_client.transport.close()
        except Exception as e:
            logger.warning(f"Failed to close grpc_client: {e}")
        self._grpc_client = None
    if self._storage_control_client is not None:
        try:
            await self._storage_control_client.transport.close()
        except Exception as e:
            logger.warning(f"Failed to close storage_control_client: {e}")
        self._storage_control_client = None

async def _lookup_bucket_type(self, bucket):
if bucket in self._storage_layout_cache:
return self._storage_layout_cache[bucket]
Expand Down Expand Up @@ -358,10 +411,9 @@ async def _fetch_range_split(
if mrd is None:
# If no mrd is provided, we obtain one with pool size equal to the passed concurrency, capped at the number of chunks.
pool_size = min(len(chunk_lengths), concurrency)
mrd = zb_hns_utils.MRDPool(
self, bucket, object_name, generation, pool_size=pool_size
mrd = await self._mrd_pool_cache.get(
bucket, object_name, generation, pool_size=pool_size
)
await mrd.initialize()
pool_created_here = True

tasks = []
Expand Down Expand Up @@ -511,10 +563,9 @@ async def _cat_file(
)

# Obtain an MRDPool for this call from the pool cache
mrd = zb_hns_utils.MRDPool(
self, bucket, object_name, generation, pool_size=concurrency
mrd = await self._mrd_pool_cache.get(
bucket, object_name, generation, pool_size=concurrency
)
await mrd.initialize()
pool_created_here = True

try:
Expand Down Expand Up @@ -1570,48 +1621,46 @@ async def _get_file(self, rpath, lpath, callback=None, **kwargs):
return
callback = callback or NoOpCallback()

mrd = None
mrd_pool = await self._mrd_pool_cache.get(bucket, key, generation, pool_size=1)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With a pool size of 1, the pool cache initializes a new pool every time, resulting in no performance improvements and added overhead.

Do you have any measurements showing the improvements made?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to this
@googlyrahman @Yonghui-Lee do we have any macro benchmark results confirming that not just this change, but MRD pooling overall, significantly improves performance? It would be good to add them to the description.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a micro benchmark to show the improvement in read throughput. For the current tessellation macro benchmark, I think the pool cache may not improve it, because we don't open the same file multiple times in the same process.

image

try:
await self._get_grpc_client()
mrd = await zb_hns_utils.init_mrd(self.grpc_client, bucket, key, generation)

size = mrd.persisted_size
if size is None:
logger.warning(
f"AsyncMultiRangeDownloader (MRD) for {rpath} has no 'persisted_size'. "
"Falling back to _info() to get the file size. "
"This may result in incorrect behavior for unfinalized objects."
)
size = (await self._info(rpath))["size"]
callback.set_size(size)
async with mrd_pool.get_mrd() as mrd:
size = mrd.persisted_size
Comment thread
Yonghui-Lee marked this conversation as resolved.
if size is None:
logger.warning(
f"AsyncMultiRangeDownloader (MRD) for {rpath} has no 'persisted_size'. "
"Falling back to _info() to get the file size. "
"This may result in incorrect behavior for unfinalized objects."
)
size = (await self._info(rpath))["size"]
callback.set_size(size)

lparent = os.path.dirname(lpath) or os.curdir
os.makedirs(lparent, exist_ok=True)
lparent = os.path.dirname(lpath) or os.curdir
os.makedirs(lparent, exist_ok=True)

chunksize = kwargs.get("chunksize", 4096 * 32) # 128KB default
offset = 0
chunksize = kwargs.get("chunksize", 4096 * 32) # 128KB default
offset = 0

with open(lpath, "wb") as f2:
while True:
if offset >= size:
break
with open(lpath, "wb") as f2:
while True:
if offset >= size:
break

data = await zb_hns_utils.download_range(
offset=offset, length=chunksize, mrd=mrd
)
if not data:
break
data = await zb_hns_utils.download_range(
offset=offset, length=chunksize, mrd=mrd
)
if not data:
break

f2.write(data)
offset += len(data)
callback.relative_update(len(data))
f2.write(data)
offset += len(data)
callback.relative_update(len(data))
except Exception as e:
# Clean up the corrupted file before raising error
if os.path.exists(lpath):
os.remove(lpath)
raise e
finally:
await zb_hns_utils.close_mrd(mrd)
await mrd_pool.close()

async def _do_list_objects(
self,
Expand Down
25 changes: 24 additions & 1 deletion gcsfs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
import pytest_asyncio
import requests
from fsspec import asyn
from google.cloud import storage
from google.cloud.storage.asyncio.async_appendable_object_writer import (
AsyncAppendableObjectWriter,
Expand Down Expand Up @@ -246,6 +247,7 @@ def gcs(gcs_factory, buckets_to_delete, populate_bucket):
yield gcs
finally:
_cleanup_gcs(gcs, bucket_populated=populate_bucket)
_close_gcs(gcs)
# Remove the dynamically added attribute. This prevents state leakage
# into subsequent tests that can share this cached fsspec instance.
if hasattr(gcs, "finalize_on_close"):
Expand All @@ -267,6 +269,7 @@ def factory(**kwargs):

for fs in created_instances:
_cleanup_gcs(fs, bucket_populated=populate_bucket)
_close_gcs(fs)


@pytest.fixture
Expand All @@ -278,6 +281,7 @@ def extended_gcsfs(gcs_factory, buckets_to_delete, populate_bucket):
yield extended_gcsfs
finally:
_cleanup_gcs(extended_gcsfs, bucket_populated=populate_bucket)
_close_gcs(extended_gcsfs)


def _cleanup_gcs(gcs, bucket=TEST_BUCKET, bucket_populated=True):
Expand All @@ -291,6 +295,20 @@ def _cleanup_gcs(gcs, bucket=TEST_BUCKET, bucket_populated=True):
logging.warning(f"Failed to clean up GCS bucket {bucket}: {e}")


def _close_gcs(gcs):
    """Release a filesystem instance's resources from a synchronous fixture.

    If the instance exposes ``close_resources``, that coroutine is run to
    completion on the instance's loop via ``asyn.sync``. The HTTP session is
    then closed through ``GCSFileSystem.close_session``.
    """
    supports_resource_cleanup = hasattr(gcs, "close_resources")
    if supports_resource_cleanup:
        asyn.sync(gcs.loop, gcs.close_resources)
    GCSFileSystem.close_session(
        gcs.loop, gcs._session, asynchronous=gcs.asynchronous
    )


async def _close_gcs_async(gcs):
    """Release a filesystem instance's resources from an async fixture.

    If the instance exposes ``close_resources``, it is awaited directly. The
    HTTP session is then closed through ``GCSFileSystem.close_session``.
    """
    supports_resource_cleanup = hasattr(gcs, "close_resources")
    if supports_resource_cleanup:
        await gcs.close_resources()
    GCSFileSystem.close_session(
        gcs.loop, gcs._session, asynchronous=gcs.asynchronous
    )


@pytest.fixture(scope="session", autouse=True)
def final_cleanup(gcs_factory, buckets_to_delete):
"""
Expand Down Expand Up @@ -351,6 +369,7 @@ def gcs_versioned(gcs_factory, buckets_to_delete):
logging.warning(
f"Failed to clean up versioned bucket {TEST_VERSIONED_BUCKET} after test: {e}"
)
_close_gcs(gcs)


def cleanup_versioned_bucket(gcs, bucket_name, prefix=None):
Expand Down Expand Up @@ -447,6 +466,7 @@ def gcs_hns(gcs_factory, buckets_to_delete):
yield gcs
finally:
_cleanup_gcs(gcs, bucket=TEST_HNS_BUCKET)
_close_gcs(gcs)


@pytest.fixture
Expand Down Expand Up @@ -516,7 +536,10 @@ async def async_gcs():
token = "anon" if not os.getenv("STORAGE_EMULATOR_HOST") else None
GCSFileSystem.clear_instance_cache()
gcs = GCSFileSystem(asynchronous=True, token=token)
yield gcs
try:
yield gcs
finally:
await _close_gcs_async(gcs)


def pytest_addoption(parser):
Expand Down
Loading
Loading