Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ repos:
# Check and update the uv lockfile
- id: uv-lock

- repo: https://github.com/kynan/nbstripout
rev: 0.8.1
hooks:
- id: nbstripout
files: \.ipynb$

- repo: local

hooks:
Expand Down
4 changes: 4 additions & 0 deletions docs/source/tutorials/basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ If you are working on your own machine, note that the tutorial assumes that you
If this is not the case, consult the {ref}`getting started page<installation>`.
:::

:::{important}
If you are running this tutorial in a Jupyter notebook, make sure to call `load_profile()` in a **separate cell** before running any AiiDA engine processes (e.g. calculation functions or work chains).
:::

:::{tip}
This tutorial can be downloaded and run as a Jupyter Notebook: {nb-download}`basic.ipynb` {octicon}`download`
:::
Expand Down
4 changes: 2 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ dependencies:
- docstring_parser
- get-annotations~=0.1
- python-graphviz~=0.19
- plumpy~=0.25.0
- plumpy@ git+https://github.com/khsrali/plumpy.git@3.14
- ipython>=7.6
- jedi<0.19
- jinja2~=3.0
- kiwipy[rmq]~=0.8.4
- kiwipy[rmq]~=0.9.0
- importlib-metadata~=6.0
- numpy<3,>=1.21
- paramiko~=3.0
Expand Down
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@ dependencies = [
'docstring-parser',
'get-annotations~=0.1;python_version<"3.10"',
'graphviz~=0.19',
'plumpy~=0.25.0',
'plumpy @ git+https://github.com/khsrali/plumpy.git@3.14',
'ipython>=7.6',
'jedi<0.19',
'jinja2~=3.0',
'kiwipy[rmq]~=0.8.4',
'kiwipy[rmq]~=0.9.0',
'importlib-metadata~=6.0',
'numpy>=1.21,<3',
'paramiko~=3.0',
Expand Down Expand Up @@ -274,6 +274,9 @@ ssh_kerberos = [
tests = [
'aiida-core[atomic_tools,rest]',
'aiida-export-migration-tests==0.9.0',
'ipykernel~=6.9',
'nbclient~=0.10',
'nbformat~=5.10',
'pg8000~=1.13',
'pgtest~=1.3,>=1.3.1',
'pytest~=7.0',
Expand Down Expand Up @@ -412,7 +415,6 @@ filterwarnings = [
'ignore:The `Code` class is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning',
# https://github.com/aiidateam/plumpy/issues/283
'ignore:There is no current event loop:DeprecationWarning:plumpy',
'ignore:There is no current event loop:DeprecationWarning:nest_asyncio',
# spglib deprecation
'ignore:dict interface is deprecated:DeprecationWarning',
# https://github.com/aiidateam/archive-path/issues/21
Expand Down
4 changes: 3 additions & 1 deletion src/aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,9 @@ async def run(self) -> 'ExitCode' | None:
# The remaining inputs have to be keyword arguments.
kwargs.update(**inputs)

result = self._func(*args, **kwargs)
from plumpy import run_with_portal

result = await run_with_portal(self._func, *args, **kwargs)

if result is None or isinstance(result, ExitCode): # type: ignore[redundant-expr]
return result # type: ignore[unreachable]
Expand Down
3 changes: 2 additions & 1 deletion src/aiida/engine/processes/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Optional, Union

import kiwipy
from plumpy import get_or_create_event_loop

from aiida.orm import Node, load_node

Expand Down Expand Up @@ -43,7 +44,7 @@ def __init__(
from .process import ProcessState

# create future in specified event loop
loop = loop if loop is not None else asyncio.get_event_loop()
loop = loop if loop is not None else get_or_create_event_loop()
super().__init__(loop=loop)

assert not (poll_interval is None and communicator is None), 'Must poll or have a communicator to use'
Expand Down
3 changes: 2 additions & 1 deletion src/aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import plumpy.persistence
import plumpy.processes
from kiwipy.communications import UnroutableError
from plumpy import run_until_complete
from plumpy.process_states import Finished, ProcessState
from plumpy.processes import ConnectionClosed # type: ignore[attr-defined]
from plumpy.processes import Process as PlumpyProcess
Expand Down Expand Up @@ -361,7 +362,7 @@ def kill(self, msg_text: str | None = None, force_kill: bool = False) -> Union[b
coro = self._launch_task(task_kill_job, self.node, self.runner.transport)
self._cancelling_scheduler_job = asyncio.create_task(coro)
try:
self.loop.run_until_complete(self._cancelling_scheduler_job)
run_until_complete(self.loop, self._cancelling_scheduler_job)
except Exception as exc:
self.node.logger.error(f'While cancelling the scheduler job an error was raised: {exc}')
return False
Expand Down
3 changes: 2 additions & 1 deletion src/aiida/engine/processes/workchains/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import typing as t

from plumpy import run_with_portal
from plumpy.persistence import auto_persist
from plumpy.process_states import Continue, Wait
from plumpy.processes import ProcessStateMachineMeta
Expand Down Expand Up @@ -299,7 +300,7 @@ def _update_process_status(self) -> None:
@Protect.final
async def run(self) -> t.Any:
self._stepper = self.spec().get_outline().create_stepper(self) # type: ignore[arg-type]
return self._do_step()
return await run_with_portal(self._do_step)

def _do_step(self) -> t.Any:
"""Execute the next step in the outline and return the result.
Expand Down
11 changes: 6 additions & 5 deletions src/aiida/engine/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type, Union

import kiwipy
from plumpy import run_until_complete
from plumpy.communications import wrap_communicator
from plumpy.events import reset_event_loop_policy, set_event_loop_policy
from plumpy.events import get_or_create_event_loop
from plumpy.persistence import Persister
from plumpy.process_comms import RemoteProcessThreadController

Expand Down Expand Up @@ -81,8 +82,7 @@ def __init__(
broker_submit and persister is None
), 'Must supply a persister if you want to submit using communicator'

set_event_loop_policy()
self._loop = loop if loop is not None else asyncio.get_event_loop()
self._loop = loop if loop else get_or_create_event_loop()
self._poll_interval = poll_interval
self._broker_submit = broker_submit
self._transport = transports.TransportQueue(self._loop)
Expand Down Expand Up @@ -156,16 +156,17 @@ def stop(self) -> None:

def run_until_complete(self, future: asyncio.Future) -> Any:
"""Run the loop until the future has finished and return the result."""

with utils.loop_scope(self._loop):
return self._loop.run_until_complete(future)
return run_until_complete(self._loop, future)

def close(self) -> None:
"""Close the runner by stopping the loop."""
assert not self._closed
self.stop()
if not self._loop.is_running():
self._loop.close()
reset_event_loop_policy()
# asyncio.set_event_loop_policy(None)
self._closed = True

def instantiate_process(self, process: TYPE_RUN_PROCESS, **inputs):
Expand Down
15 changes: 13 additions & 2 deletions src/aiida/engine/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import traceback
from typing import TYPE_CHECKING, AsyncIterator, Awaitable, Dict, Hashable, Optional

from plumpy import get_or_create_event_loop

from aiida.orm import AuthInfo

if TYPE_CHECKING:
Expand Down Expand Up @@ -44,8 +46,8 @@ class TransportQueue:
"""

def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
""":param loop: An asyncio event, will use `asyncio.get_event_loop()` if not supplied"""
self._loop = loop if loop is not None else asyncio.get_event_loop()
""":param loop: An asyncio event, will use `get_or_create_event_loop()` if not supplied"""
self._loop = loop if loop else get_or_create_event_loop()
self._transport_requests: Dict[Hashable, TransportRequest] = {}

@property
Expand All @@ -67,6 +69,15 @@ async def transport_task(transport_queue, authinfo):
:param authinfo: The authinfo to be used to get transport
:return: A future that can be yielded to give the transport
"""

from plumpy import ensure_portal

# NOTE: We need to ensure the portal here only because
# our scheduler has only a sync interface and _get_jobs_from_scheduler is using that
# if we ever provide a fully async scheduler interface then we can remove this here
# An issue is opened to reference this https://github.com/aiidateam/aiida-core/issues/7222
await ensure_portal()

open_callback_handle = None
transport_request = self._transport_requests.get(authinfo.pk, None)

Expand Down
8 changes: 5 additions & 3 deletions src/aiida/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from datetime import datetime
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator, List, Optional, Tuple, Type, Union

from plumpy import get_or_create_event_loop

if TYPE_CHECKING:
from aiida.orm import ProcessNode

Expand Down Expand Up @@ -125,10 +127,10 @@ def interruptable_task(
"""Turn the given coroutine into an interruptable task by turning it into an InterruptableFuture and returning it.

:param coro: the coroutine that should be made interruptable with object of InterutableFuture as last paramenter
:param loop: the event loop in which to run the coroutine, by default uses asyncio.get_event_loop()
:param loop: the event loop in which to run the coroutine, by default uses get_or_create_event_loop()
:return: an InterruptableFuture
"""
loop = loop or asyncio.get_event_loop()
loop = loop or get_or_create_event_loop()
future = InterruptableFuture()

async def execute_coroutine():
Expand Down Expand Up @@ -252,7 +254,7 @@ def loop_scope(loop) -> Iterator[None]:

:param loop: The event loop to make current for the duration of the scope
"""
current = asyncio.get_event_loop()
current = get_or_create_event_loop()

try:
asyncio.set_event_loop(loop)
Expand Down
53 changes: 53 additions & 0 deletions src/aiida/manage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,62 @@ def load_profile(self, profile: Union[None, str, 'Profile'] = None, allow_switch
# Check whether a development version is being run. Note that needs to be called after ``configure_logging``
# because this function relies on the logging being properly configured for the warning to show.
self.check_version()
self._setup_event_loop_in_ipython()

return self._profile

def _setup_event_loop_in_ipython(self) -> None:
"""Monkey-patch ``IPythonKernel.do_execute`` to ensure a portal.

When running inside an environment with an already-running event loop
(e.g. a Jupyter notebook kernel), this patches the kernel's
``do_execute`` to open a portal before executing each cell.
The portal is opended on whichever asyncio task ipykernel uses for that cell.

The patch takes effect from the **next cell**.
``load_profile()`` must therefore be called in a prior cell before
synchronous process execution.

This is a no-op if no event loop is running (scripts, CLI, daemon).
"""
import asyncio

try:
asyncio.get_running_loop()
except RuntimeError:
return # No running loop — not in Jupyter, nothing to do

self._patch_kernel_do_execute()

def _patch_kernel_do_execute(self) -> None:
"""Patch ``IPythonKernel.do_execute`` to ensure a portal before each cell."""
try:
from ipykernel.ipkernel import IPythonKernel

if getattr(IPythonKernel, '_aiida_portal_patched', False):
return # Already patched

from plumpy import ensure_portal

_orig_do_execute = IPythonKernel.do_execute

async def _patched_do_execute(self, code, silent, *args, **kwargs):
await ensure_portal()
return await _orig_do_execute(self, code, silent, *args, **kwargs)

IPythonKernel.do_execute = _patched_do_execute # type: ignore[method-assign]
IPythonKernel._aiida_portal_patched = True # type: ignore[attr-defined]
self.logger.debug(
'Patched IPythonKernel.do_execute for portal. '
'This should occur in a Jupyter kernel, and only once per kernel session.'
)
print(
'Warning: Synchronous process execution is available only from the next cell. '
'All aiida engine methods is only availabe from the next cell.'
)
except Exception:
self.logger.debug('Could not patch IPythonKernel for portal.', exc_info=True)

def reset_profile(self) -> None:
"""Close and reset any associated resources for the current profile."""
self.reset_broker()
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/manage/tests/pytest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def clear_database_before_test_class(aiida_profile):
@pytest.fixture(scope='function')
def temporary_event_loop():
"""Create a temporary loop for independent test case"""
current = asyncio.get_event_loop()
current = plumpy.get_or_create_event_loop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
Expand Down
7 changes: 4 additions & 3 deletions src/aiida/transports/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,12 +1883,13 @@ class AsyncTransport(Transport):
"""

def run_command_blocking(self, func, *args, **kwargs):
"""The event loop must be the one of manager."""
"""Run an async transport method synchronously."""
from plumpy import run_until_complete

from aiida.manage import get_manager

loop = get_manager().get_runner()
return loop.run_until_complete(func(*args, **kwargs))
loop = get_manager().get_runner().loop
return run_until_complete(loop, func(*args, **kwargs))

def open(self):
return self.run_command_blocking(self.open_async)
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ class TestDbBackend(Enum):
PSQL = 'psql'


@pytest.fixture(autouse=True)
def _reset_runner(request):
yield
get_manager().reset_runner()


def pytest_collection_modifyitems(items, config):
"""Automatically generate markers for certain tests.

Expand Down
5 changes: 3 additions & 2 deletions tests/engine/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import time

import pytest
from plumpy import get_or_create_event_loop

from aiida.engine.processes.calcjobs.manager import JobManager, JobsList
from aiida.engine.transports import TransportQueue
Expand All @@ -24,7 +25,7 @@ class TestJobManager:
@pytest.fixture(autouse=True)
def init_profile(self, aiida_localhost):
"""Initialize the profile."""
self.loop = asyncio.get_event_loop()
self.loop = get_or_create_event_loop()
self.transport_queue = TransportQueue(self.loop)
self.user = User.collection.get_default()
self.computer = aiida_localhost
Expand Down Expand Up @@ -54,7 +55,7 @@ class TestJobsList:
@pytest.fixture(autouse=True)
def init_profile(self, aiida_localhost):
"""Initialize the profile."""
self.loop = asyncio.get_event_loop()
self.loop = get_or_create_event_loop()
self.transport_queue = TransportQueue(self.loop)
self.user = User.collection.get_default()
self.computer = aiida_localhost
Expand Down
Loading
Loading