diff --git a/src/aiida/engine/processes/calcjobs/manager.py b/src/aiida/engine/processes/calcjobs/manager.py index 4db2c97b0c..21c77ff958 100644 --- a/src/aiida/engine/processes/calcjobs/manager.py +++ b/src/aiida/engine/processes/calcjobs/manager.py @@ -96,7 +96,7 @@ async def _get_jobs_from_scheduler(self) -> Dict[str, 'JobInfo']: :return: a mapping of job ids to :py:class:`~aiida.schedulers.datastructures.JobInfo` instances """ - with self._transport_queue.request_transport(self._authinfo) as request: + async with self._transport_queue.request_transport(self._authinfo) as request: self.logger.info('waiting for transport') transport = await request diff --git a/src/aiida/engine/processes/calcjobs/tasks.py b/src/aiida/engine/processes/calcjobs/tasks.py index 4c9e5c8b8e..ce6569961e 100644 --- a/src/aiida/engine/processes/calcjobs/tasks.py +++ b/src/aiida/engine/processes/calcjobs/tasks.py @@ -83,7 +83,7 @@ async def task_upload_job(process: 'CalcJob', transport_queue: TransportQueue, c authinfo = node.get_authinfo() async def do_upload(): - with transport_queue.request_transport(authinfo) as request: + async with transport_queue.request_transport(authinfo) as request: transport = await cancellable.with_interrupt(request) with SandboxFolder(filepath_sandbox) as folder: @@ -144,7 +144,7 @@ async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, ca authinfo = node.get_authinfo() async def do_submit(): - with transport_queue.request_transport(authinfo) as request: + async with transport_queue.request_transport(authinfo) as request: transport = await cancellable.with_interrupt(request) return execmanager.submit_calculation(node, transport) @@ -252,7 +252,7 @@ async def task_monitor_job( authinfo = node.get_authinfo() async def do_monitor(): - with transport_queue.request_transport(authinfo) as request: + async with transport_queue.request_transport(authinfo) as request: transport = await cancellable.with_interrupt(request) return monitors.process(node, transport) @@ -298,7 +298,7 @@ async def task_retrieve_job( authinfo = node.get_authinfo() async def do_retrieve(): - with transport_queue.request_transport(authinfo) as request: + async with transport_queue.request_transport(authinfo) as request: transport = await cancellable.with_interrupt(request) # Perform the job accounting and set it on the node if successful. If the scheduler does not implement this # still set the attribute but set it to `None`. This way we can distinguish calculation jobs for which the @@ -366,7 +366,7 @@ async def task_stash_job(node: CalcJobNode, transport_queue: TransportQueue, can authinfo = node.get_authinfo() async def do_stash(): - with transport_queue.request_transport(authinfo) as request: + async with transport_queue.request_transport(authinfo) as request: transport = await cancellable.with_interrupt(request) logger.info(f'stashing calculation<{node.pk}>') @@ -405,7 +405,7 @@ async def task_unstash_job(node: CalcJobNode, transport_queue: TransportQueue, c authinfo = node.get_authinfo() async def do_unstash(): - with transport_queue.request_transport(authinfo) as request: + async with transport_queue.request_transport(authinfo) as request: transport = await cancellable.with_interrupt(request) logger.info(f'unstashing calculation<{node.pk}>') @@ -454,7 +454,7 @@ async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, canc authinfo = node.get_authinfo() async def do_kill(): - with transport_queue.request_transport(authinfo) as request: + async with transport_queue.request_transport(authinfo) as request: transport = await cancellable.with_interrupt(request) return execmanager.kill_calculation(node, transport) diff --git a/src/aiida/engine/transports.py b/src/aiida/engine/transports.py index 9b7e14fe83..e71b4ab102 100644 --- a/src/aiida/engine/transports.py +++ b/src/aiida/engine/transports.py @@ -13,7 +13,7 @@ import contextvars import logging import traceback -from typing import TYPE_CHECKING, Awaitable, Dict, Hashable, Iterator, Optional +from typing import TYPE_CHECKING, AsyncIterator, Awaitable, Dict, Hashable, Optional from aiida.orm import AuthInfo @@ -53,14 +53,14 @@ def loop(self) -> asyncio.AbstractEventLoop: """Get the loop being used by this transport queue""" return self._loop - @contextlib.contextmanager - def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable['Transport']]: + @contextlib.asynccontextmanager + async def request_transport(self, authinfo: AuthInfo) -> AsyncIterator[Awaitable['Transport']]: """Request a transport from an authinfo. Because the client is not allowed to request a transport immediately they will instead be given back a future that can be awaited to get the transport:: async def transport_task(transport_queue, authinfo): - with transport_queue.request_transport(authinfo) as request: + async with transport_queue.request_transport(authinfo) as request: transport = await request # Do some work with the transport @@ -78,13 +78,14 @@ async def transport_task(transport_queue, authinfo): transport = authinfo.get_transport() safe_open_interval = transport.get_safe_open_interval() - def do_open(): - """Actually open the transport""" + async def do_open(): + """Wait for safe interval, then open the transport.""" + await asyncio.sleep(safe_open_interval) if transport_request.count > 0: # The user still wants the transport so open it _LOGGER.debug('Transport request opening transport for %s', authinfo) try: - transport.open() + await transport.open_async() except Exception as exception: _LOGGER.error('exception occurred while trying to open transport:\n %s', exception) transport_request.future.set_exception(exception) @@ -99,8 +100,9 @@ def do_open(): # passed around to many places, including outside aiida-core (e.g. paramiko). Anyone keeping a reference # to this handle would otherwise keep the Process context (and thus the process itself) in memory. # See https://github.com/aiidateam/aiida-core/issues/4698 - open_callback_handle = self._loop.call_later(safe_open_interval, do_open, context=contextvars.Context()) - + empty_ctx = contextvars.Context() + open_callback_handle = empty_ctx.run(self._loop.create_task, do_open()) + # self._loop.create_task supports passing a context but only after Python 3.11+ try: transport_request.count += 1 yield transport_request.future @@ -118,18 +120,13 @@ def do_open(): # Check if there are no longer any users that want the transport if transport_request.count == 0: # IMPORTANT: Pop from _transport_requests BEFORE closing the transport. - # This prevents a race condition with async transports where: - # 1. close() is called, which for AsyncTransport uses run_until_complete(close_async) - # 2. With nest_asyncio (used by plumpy), this call yields back to the event loop - # 3. The event loop schedules close_async, then continues running another tasks - # - for example one that requests the transport which is scheduled to be closed - # 4. The task now using the transport to do some operation awaits, - # next the close_async task closes the transport while still in use -> error - # By poping first, new tasks will create a fresh transport request. + # This prevents a race condition where a new task could get a reference + # to a transport that is being closed. By popping first, new tasks will + # create a fresh transport request. self._transport_requests.pop(authinfo.pk, None) if transport_request.future.done(): _LOGGER.debug('Transport request closing transport for %s', authinfo) - transport_request.future.result().close() + await transport_request.future.result().close_async() elif open_callback_handle is not None: open_callback_handle.cancel() diff --git a/src/aiida/transports/transport.py b/src/aiida/transports/transport.py index cbe4c38db7..e60034bf33 100644 --- a/src/aiida/transports/transport.py +++ b/src/aiida/transports/transport.py @@ -9,6 +9,7 @@ """Transport interface.""" import abc +import asyncio import fnmatch import os import re @@ -127,6 +128,7 @@ def __init__(self, *args, **kwargs): self._logger_extra = None self._is_open = False self._enters = 0 + self._open_lock = asyncio.Lock() # for accessing the identity of the underlying machine self.hostname = kwargs.get('machine') @@ -171,21 +173,58 @@ def __enter__(self): """ # Keep track of how many times enter has been called + if self._track_enter(): + self.open() + return self + + def __exit__(self, type_, value, traceback): + """Closes connections, if needed (used in 'with' statements).""" + if self._track_exit(): + self.close() + + async def __aenter__(self): + """Async context manager entry. Opens the transport connection. + + For sync transports, this just calls the sync open() method. + AsyncTransport subclasses override this to use async open. + """ + async with self._open_lock: + if self._track_enter(): + self.open() + return self + + async def __aexit__(self, type_, value, traceback): + """Async context manager exit. Closes the transport connection if needed. + + For sync transports, this just calls the sync close() method. + AsyncTransport subclasses override this to use async close. + """ + async with self._open_lock: + if self._track_exit(): + self.close() + + def _track_enter(self) -> bool: + """Track a context manager entry and return whether open() needs to be called. + + Manages the ``_enters`` reference counter. If the transport is already open + (e.g. opened externally), an extra count is added so the final exit won't close it. + + Note: The caller must call ``open()``/``open_async()`` immediately when this returns True. + In async context, use ``self._open_lock`` to prevent potential concurrent open/close races. + """ + need_open = False if self._enters == 0: if self.is_open: - # Already open, so just add one to the entered counter - # this way on the final exit we will not close self._enters += 1 else: - self.open() + need_open = True self._enters += 1 - return self + return need_open - def __exit__(self, type_, value, traceback): - """Closes connections, if needed (used in 'with' statements).""" + def _track_exit(self) -> bool: + """Track a context manager exit and return whether close() needs to be called.""" self._enters -= 1 - if self._enters == 0: - self.close() + return self._enters == 0 @property def is_open(self): @@ -1857,6 +1896,19 @@ def open(self): def close(self): return self.run_command_blocking(self.close_async) + async def __aenter__(self): + """Async context manager entry. Opens the transport connection.""" + async with self._open_lock: + if self._track_enter(): + await self.open_async() + return self + + async def __aexit__(self, type_, value, traceback): + """Async context manager exit. Closes the transport connection if needed.""" + async with self._open_lock: + if self._track_exit(): + await self.close_async() + def get(self, *args, **kwargs): return self.run_command_blocking(self.get_async, *args, **kwargs) diff --git a/tests/engine/daemon/test_execmanager.py b/tests/engine/daemon/test_execmanager.py index a062f71247..caed13542c 100644 --- a/tests/engine/daemon/test_execmanager.py +++ b/tests/engine/daemon/test_execmanager.py @@ -179,7 +179,7 @@ async def test_upload_local_copy_list( node, calc_info = node_and_calc_info calc_info.local_copy_list = [[folder.uuid] + local_copy_list] - with node.computer.get_transport() as transport: + async with node.computer.get_transport() as transport: await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) # Check that none of the files were written to the repository of the calculation node, since they were communicated @@ -217,7 +217,7 @@ async def test_upload_local_copy_list_files_folders( (inputs['folder'].uuid, None, '.'), ] - with node.computer.get_transport() as transport: + async with node.computer.get_transport() as transport: await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) # Check that none of the files were written to the repository of the calculation node, since they were communicated @@ -249,7 +249,7 @@ async def test_upload_remote_symlink_list( (node.computer.uuid, str(tmp_path / 'file_a.txt'), 'file_a.txt'), ] - with node.computer.get_transport() as transport: + async with node.computer.get_transport() as transport: await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) filepath_workdir = pathlib.Path(node.get_remote_workdir()) @@ -314,7 +314,7 @@ async def test_upload_file_copy_operation_order(node_and_calc_info, tmp_path, or if order is not None: calc_info.file_copy_operation_order = order - with node.computer.get_transport() as transport: + async with node.computer.get_transport() as transport: await execmanager.upload_calculation(node, transport, calc_info, sandbox, inputs) filepath = pathlib.Path(node.get_remote_workdir()) / 'file.txt' assert filepath.is_file() @@ -614,14 +614,14 @@ async def test_upload_combinations( ) if expected_exception is not None: with pytest.raises(expected_exception): - with node.computer.get_transport() as transport: + async with node.computer.get_transport() as transport: await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) filepath_workdir = pathlib.Path(node.get_remote_workdir()) assert serialize_file_hierarchy(filepath_workdir, read_bytes=False) == expected_hierarchy else: - with node.computer.get_transport() as transport: + async with node.computer.get_transport() as transport: await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox) filepath_workdir = pathlib.Path(node.get_remote_workdir()) @@ -650,7 +650,7 @@ async def test_upload_calculation_portable_code(fixture_sandbox, node_and_calc_i code_info.code_uuid = code.uuid calc_info.codes_info = [code_info] - with node.computer.get_transport() as transport: + async with node.computer.get_transport() as transport: await execmanager.upload_calculation( node, transport, diff --git a/tests/engine/test_transport.py b/tests/engine/test_transport.py index 2d11eafc6b..786ec39d3d 100644 --- a/tests/engine/test_transport.py +++ b/tests/engine/test_transport.py @@ -32,7 +32,7 @@ def test_simple_request(self): async def test(): trans = None - with queue.request_transport(self.authinfo) as request: + async with queue.request_transport(self.authinfo) as request: trans = await request assert trans.is_open assert not trans.is_open @@ -45,10 +45,10 @@ def test_get_transport_nested(self): loop = transport_queue.loop async def nested(queue, authinfo): - with queue.request_transport(authinfo) as request1: + async with queue.request_transport(authinfo) as request1: trans1 = await request1 assert trans1.is_open - with queue.request_transport(authinfo) as request2: + async with queue.request_transport(authinfo) as request2: trans2 = await request2 assert trans1 is trans2 assert trans2.is_open @@ -61,7 +61,7 @@ def test_get_transport_interleaved(self): loop = transport_queue.loop async def interleaved(authinfo): - with transport_queue.request_transport(authinfo) as trans_future: + async with transport_queue.request_transport(authinfo) as trans_future: await trans_future loop.run_until_complete(asyncio.gather(interleaved(self.authinfo), interleaved(self.authinfo))) @@ -72,7 +72,7 @@ def test_return_from_context(self): loop = queue.loop async def test(): - with queue.request_transport(self.authinfo) as request: + async with queue.request_transport(self.authinfo) as request: trans = await request return trans.is_open @@ -85,7 +85,7 @@ def test_open_fail(self): loop = queue.loop async def test(): - with queue.request_transport(self.authinfo) as request: + async with queue.request_transport(self.authinfo) as request: await request def broken_open(trans): @@ -119,7 +119,7 @@ def test_safe_interval(self): async def test(iteration): trans = None - with queue.request_transport(self.authinfo) as request: + async with queue.request_transport(self.authinfo) as request: trans = await request time_current = time.time() time_elapsed = time_current - time_start @@ -133,15 +133,15 @@ async def test(iteration): transport_class._DEFAULT_SAFE_OPEN_INTERVAL = original_interval def test_request_removed_before_close(self): - """Test that transport_request is removed from dict before close() is called. + """Test that transport_request is removed from dict before close_async() is called. This is a regression test for a race condition with async transports where: - 1. close() is called, which for AsyncTransport uses run_until_complete() - 2. With nest_asyncio (used by plumpy), this can yield to the event loop + 1. close_async() is called during context manager cleanup + 2. This can yield back to the event loop 3. Another task might enter and get the same transport_request 4. That task tries to use the transport that's being closed -> error - The fix ensures transport_request is removed BEFORE close(), so new tasks + The fix ensures transport_request is removed BEFORE close_async(), so new tasks create fresh transport requests. """ queue = TransportQueue() @@ -150,24 +150,24 @@ def test_request_removed_before_close(self): events = [] # Track order of operations async def test(): - with queue.request_transport(self.authinfo) as request: + async with queue.request_transport(self.authinfo) as request: trans = await request - # Patch close to track when it's called - original_close = trans.close + # Patch close_async to track when it's called + original_close_async = trans.close_async - def mock_close(): + async def mock_close_async(): # Check if request was already removed from dict if self.authinfo.pk not in queue._transport_requests: events.append('pop_before_close') events.append('close') - return original_close() + return await original_close_async() - trans.close = mock_close + trans.close_async = mock_close_async loop.run_until_complete(test()) - assert 'close' in events, 'Transport close() should have been called' - assert 'pop_before_close' in events, 'Transport request should be removed before close() is called' + assert 'close' in events, 'Transport close_async() should have been called' + assert 'pop_before_close' in events, 'Transport request should be removed before close_async() is called' assert events.index('pop_before_close') < events.index('close'), 'pop should happen before close' def test_new_request_during_close_gets_fresh_transport(self): @@ -192,7 +192,7 @@ def test_new_request_during_close_gets_fresh_transport(self): async def use_transport(task_id): # Before requesting, check if there's an existing request in the queue had_existing_request = self.authinfo.pk in queue._transport_requests - with queue.request_transport(self.authinfo) as request: + async with queue.request_transport(self.authinfo) as request: trans = await request transport_states.append( { diff --git a/tests/transports/test_asyncssh_plugin.py b/tests/transports/test_asyncssh_plugin.py index 7177d5dfb2..fa6feac54d 100644 --- a/tests/transports/test_asyncssh_plugin.py +++ b/tests/transports/test_asyncssh_plugin.py @@ -34,7 +34,7 @@ async def test_semaphore_released_after_errors(self, tmp_path_factory): } async_transport = AsyncSshTransport(**transport_params) - with async_transport as transport: + async with async_transport as transport: # Each operation should fail but release the semaphore with pytest.raises(OSError, match='Error while downloading file'): await transport.getfile_async('non_existing', local_dir) @@ -46,7 +46,7 @@ async def test_semaphore_released_after_errors(self, tmp_path_factory): with pytest.raises(OSError, match='Error while downloading file'): await transport.getfile_async('non_existing', local_dir) - assert transport._semaphore._value == 1, 'Semaphore should be fully released' + assert async_transport._semaphore._value == 1, 'Semaphore should be fully released' @pytest.mark.asyncio async def test_semaphore_limits_concurrent_operations(self):