Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/aiida/engine/processes/calcjobs/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def _get_jobs_from_scheduler(self) -> Dict[str, 'JobInfo']:
:return: a mapping of job ids to :py:class:`~aiida.schedulers.datastructures.JobInfo` instances

"""
with self._transport_queue.request_transport(self._authinfo) as request:
async with self._transport_queue.request_transport(self._authinfo) as request:
self.logger.info('waiting for transport')
transport = await request

Expand Down
14 changes: 7 additions & 7 deletions src/aiida/engine/processes/calcjobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ async def task_upload_job(process: 'CalcJob', transport_queue: TransportQueue, c
authinfo = node.get_authinfo()

async def do_upload():
with transport_queue.request_transport(authinfo) as request:
async with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)

with SandboxFolder(filepath_sandbox) as folder:
Expand Down Expand Up @@ -144,7 +144,7 @@ async def task_submit_job(node: CalcJobNode, transport_queue: TransportQueue, ca
authinfo = node.get_authinfo()

async def do_submit():
with transport_queue.request_transport(authinfo) as request:
async with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
return execmanager.submit_calculation(node, transport)

Expand Down Expand Up @@ -252,7 +252,7 @@ async def task_monitor_job(
authinfo = node.get_authinfo()

async def do_monitor():
with transport_queue.request_transport(authinfo) as request:
async with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
return monitors.process(node, transport)

Expand Down Expand Up @@ -298,7 +298,7 @@ async def task_retrieve_job(
authinfo = node.get_authinfo()

async def do_retrieve():
with transport_queue.request_transport(authinfo) as request:
async with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
# Perform the job accounting and set it on the node if successful. If the scheduler does not implement this
# still set the attribute but set it to `None`. This way we can distinguish calculation jobs for which the
Expand Down Expand Up @@ -366,7 +366,7 @@ async def task_stash_job(node: CalcJobNode, transport_queue: TransportQueue, can
authinfo = node.get_authinfo()

async def do_stash():
with transport_queue.request_transport(authinfo) as request:
async with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)

logger.info(f'stashing calculation<{node.pk}>')
Expand Down Expand Up @@ -405,7 +405,7 @@ async def task_unstash_job(node: CalcJobNode, transport_queue: TransportQueue, c
authinfo = node.get_authinfo()

async def do_unstash():
with transport_queue.request_transport(authinfo) as request:
async with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)

logger.info(f'unstashing calculation<{node.pk}>')
Expand Down Expand Up @@ -454,7 +454,7 @@ async def task_kill_job(node: CalcJobNode, transport_queue: TransportQueue, canc
authinfo = node.get_authinfo()

async def do_kill():
with transport_queue.request_transport(authinfo) as request:
async with transport_queue.request_transport(authinfo) as request:
transport = await cancellable.with_interrupt(request)
return execmanager.kill_calculation(node, transport)

Expand Down
33 changes: 15 additions & 18 deletions src/aiida/engine/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import contextvars
import logging
import traceback
from typing import TYPE_CHECKING, Awaitable, Dict, Hashable, Iterator, Optional
from typing import TYPE_CHECKING, AsyncIterator, Awaitable, Dict, Hashable, Optional

from aiida.orm import AuthInfo

Expand Down Expand Up @@ -53,14 +53,14 @@ def loop(self) -> asyncio.AbstractEventLoop:
"""Get the loop being used by this transport queue"""
return self._loop

@contextlib.contextmanager
def request_transport(self, authinfo: AuthInfo) -> Iterator[Awaitable['Transport']]:
@contextlib.asynccontextmanager
async def request_transport(self, authinfo: AuthInfo) -> AsyncIterator[Awaitable['Transport']]:
"""Request a transport from an authinfo. Because the client is not allowed to
request a transport immediately they will instead be given back a future
that can be awaited to get the transport::

async def transport_task(transport_queue, authinfo):
with transport_queue.request_transport(authinfo) as request:
async with transport_queue.request_transport(authinfo) as request:
transport = await request
# Do some work with the transport

Expand All @@ -78,13 +78,14 @@ async def transport_task(transport_queue, authinfo):
transport = authinfo.get_transport()
safe_open_interval = transport.get_safe_open_interval()

def do_open():
"""Actually open the transport"""
async def do_open():
"""Wait for safe interval, then open the transport."""
await asyncio.sleep(safe_open_interval)
if transport_request.count > 0:
# The user still wants the transport so open it
_LOGGER.debug('Transport request opening transport for %s', authinfo)
try:
transport.open()
await transport.open_async()
except Exception as exception:
_LOGGER.error('exception occurred while trying to open transport:\n %s', exception)
transport_request.future.set_exception(exception)
Expand All @@ -99,8 +100,9 @@ def do_open():
# passed around to many places, including outside aiida-core (e.g. paramiko). Anyone keeping a reference
# to this handle would otherwise keep the Process context (and thus the process itself) in memory.
# See https://github.com/aiidateam/aiida-core/issues/4698
open_callback_handle = self._loop.call_later(safe_open_interval, do_open, context=contextvars.Context())

empty_ctx = contextvars.Context()
open_callback_handle = empty_ctx.run(self._loop.create_task, do_open())
# self._loop.create_task supports passing a context but only after Python 3.11+
try:
transport_request.count += 1
yield transport_request.future
Expand All @@ -118,18 +120,13 @@ def do_open():
# Check if there are no longer any users that want the transport
if transport_request.count == 0:
# IMPORTANT: Pop from _transport_requests BEFORE closing the transport.
# This prevents a race condition with async transports where:
# 1. close() is called, which for AsyncTransport uses run_until_complete(close_async)
# 2. With nest_asyncio (used by plumpy), this call yields back to the event loop
# 3. The event loop schedules close_async, then continues running another tasks
# - for example one that requests the transport which is scheduled to be closed
# 4. The task now using the transport to do some operation awaits,
# next the close_async task closes the transport while still in use -> error
# By poping first, new tasks will create a fresh transport request.
# This prevents a race condition where a new task could get a reference
# to a transport that is being closed. By popping first, new tasks will
# create a fresh transport request.
self._transport_requests.pop(authinfo.pk, None)

if transport_request.future.done():
_LOGGER.debug('Transport request closing transport for %s', authinfo)
transport_request.future.result().close()
await transport_request.future.result().close_async()
elif open_callback_handle is not None:
open_callback_handle.cancel()
68 changes: 60 additions & 8 deletions src/aiida/transports/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"""Transport interface."""

import abc
import asyncio
import fnmatch
import os
import re
Expand Down Expand Up @@ -127,6 +128,7 @@ def __init__(self, *args, **kwargs):
self._logger_extra = None
self._is_open = False
self._enters = 0
self._open_lock = asyncio.Lock()

# for accessing the identity of the underlying machine
self.hostname = kwargs.get('machine')
Expand Down Expand Up @@ -171,21 +173,58 @@ def __enter__(self):

"""
# Keep track of how many times enter has been called
if self._track_enter():
self.open()
return self

def __exit__(self, type_, value, traceback):
"""Closes connections, if needed (used in 'with' statements)."""
if self._track_exit():
self.close()

async def __aenter__(self):
"""Async context manager entry. Opens the transport connection.

For sync transports, this just calls the sync open() method.
AsyncTransport subclasses override this to use async open.
"""
async with self._open_lock:
if self._track_enter():
self.open()
return self

async def __aexit__(self, type_, value, traceback):
"""Async context manager exit. Closes the transport connection if needed.

For sync transports, this just calls the sync close() method.
AsyncTransport subclasses override this to use async close.
"""
async with self._open_lock:
if self._track_exit():
self.close()

def _track_enter(self) -> bool:
"""Track a context manager entry and return whether open() needs to be called.

Manages the ``_enters`` reference counter. If the transport is already open
(e.g. opened externally), an extra count is added so the final exit won't close it.

Note: The caller must call ``open()``/``open_async()`` immediately when this returns True.
In async context, use ``self._open_lock`` to prevent potential concurrent open/close races.
"""
need_open = False
Copy link
Copy Markdown
Collaborator

@agoscinski agoscinski Feb 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not entirely sure about decoupling it. Its usage can now be:

self._track_enter()
await something # another self._track_enter() happens that will also want to open it
self.open()

Yes, it is internal, but we developers also make mistakes, and you don't save many lines of code by reusing this function.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, valid point.
I don't know what the best practice is here. Also, repeating this code three times is itself error-prone. 🤔

Copy link
Copy Markdown
Collaborator Author

@khsrali khsrali Feb 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK @agoscinski, I added a clear docstring explaining how `_track_enter` should be used.
I also added an `asyncio.Lock` to guard against any potential race condition when opening -- although I believe there is no double opening in the engine. But better to be safe than sorry.

if self._enters == 0:
if self.is_open:
# Already open, so just add one to the entered counter
# this way on the final exit we will not close
self._enters += 1
else:
self.open()
need_open = True
self._enters += 1
return self
return need_open

def __exit__(self, type_, value, traceback):
"""Closes connections, if needed (used in 'with' statements)."""
def _track_exit(self) -> bool:
"""Track a context manager exit and return whether close() needs to be called."""
self._enters -= 1
if self._enters == 0:
self.close()
return self._enters == 0

@property
def is_open(self):
Expand Down Expand Up @@ -1857,6 +1896,19 @@ def open(self):
def close(self):
return self.run_command_blocking(self.close_async)

async def __aenter__(self):
"""Async context manager entry. Opens the transport connection."""
async with self._open_lock:
if self._track_enter():
await self.open_async()
return self

async def __aexit__(self, type_, value, traceback):
"""Async context manager exit. Closes the transport connection if needed."""
async with self._open_lock:
if self._track_exit():
await self.close_async()

def get(self, *args, **kwargs):
return self.run_command_blocking(self.get_async, *args, **kwargs)

Expand Down
14 changes: 7 additions & 7 deletions tests/engine/daemon/test_execmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def test_upload_local_copy_list(
node, calc_info = node_and_calc_info
calc_info.local_copy_list = [[folder.uuid] + local_copy_list]

with node.computer.get_transport() as transport:
async with node.computer.get_transport() as transport:
await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)

# Check that none of the files were written to the repository of the calculation node, since they were communicated
Expand Down Expand Up @@ -217,7 +217,7 @@ async def test_upload_local_copy_list_files_folders(
(inputs['folder'].uuid, None, '.'),
]

with node.computer.get_transport() as transport:
async with node.computer.get_transport() as transport:
await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)

# Check that none of the files were written to the repository of the calculation node, since they were communicated
Expand Down Expand Up @@ -249,7 +249,7 @@ async def test_upload_remote_symlink_list(
(node.computer.uuid, str(tmp_path / 'file_a.txt'), 'file_a.txt'),
]

with node.computer.get_transport() as transport:
async with node.computer.get_transport() as transport:
await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)

filepath_workdir = pathlib.Path(node.get_remote_workdir())
Expand Down Expand Up @@ -314,7 +314,7 @@ async def test_upload_file_copy_operation_order(node_and_calc_info, tmp_path, or
if order is not None:
calc_info.file_copy_operation_order = order

with node.computer.get_transport() as transport:
async with node.computer.get_transport() as transport:
await execmanager.upload_calculation(node, transport, calc_info, sandbox, inputs)
filepath = pathlib.Path(node.get_remote_workdir()) / 'file.txt'
assert filepath.is_file()
Expand Down Expand Up @@ -614,14 +614,14 @@ async def test_upload_combinations(
)
if expected_exception is not None:
with pytest.raises(expected_exception):
with node.computer.get_transport() as transport:
async with node.computer.get_transport() as transport:
await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)

filepath_workdir = pathlib.Path(node.get_remote_workdir())

assert serialize_file_hierarchy(filepath_workdir, read_bytes=False) == expected_hierarchy
else:
with node.computer.get_transport() as transport:
async with node.computer.get_transport() as transport:
await execmanager.upload_calculation(node, transport, calc_info, fixture_sandbox)

filepath_workdir = pathlib.Path(node.get_remote_workdir())
Expand Down Expand Up @@ -650,7 +650,7 @@ async def test_upload_calculation_portable_code(fixture_sandbox, node_and_calc_i
code_info.code_uuid = code.uuid
calc_info.codes_info = [code_info]

with node.computer.get_transport() as transport:
async with node.computer.get_transport() as transport:
await execmanager.upload_calculation(
node,
transport,
Expand Down
Loading