Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ repos:
# Check and update the uv lockfile
- id: uv-lock

- repo: https://github.com/kynan/nbstripout
rev: 0.8.1
hooks:
- id: nbstripout
files: \.ipynb$

- repo: local

hooks:
Expand Down
37 changes: 37 additions & 0 deletions docs/source/howto/interact.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,43 @@ It is also possible to run ``verdi`` commands inside the notebook, for example:
%verdi status


Running AiiDA engine processes in notebooks
-------------------------------------------

AiiDA supports running engine processes (such as calculation functions and work chains) directly in Jupyter notebooks.
When :meth:`~aiida.manage.configuration.load_profile` is called inside a Jupyter notebook, AiiDA automatically sets up the necessary infrastructure to allow synchronous process execution within the notebook's event loop.

.. important::

``load_profile()`` must be called in a **separate cell** before any AiiDA engine processes can be executed.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@khsrali is this a new requirement or has this been the case even before?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's new..
We could really not find a better way to do this.
The thing is notebooks have their own running event loop. Since nest_asyncio is dropped, the only way for us to use their loop (since you can have only one running event loop in) is to open a greenback portal.
And that has to be called when the loop has started but before engine calls.
The most practical place to stuff this logic in, was in load_profile.
However, there's a technical issue with that: the greenback portals are only usable when you are back in a the async context. Basically that means either we had to changes it to something like await load_profile_async() --which defies the efforts of aiida to not expose async syntax to users-- Or to register that call on each cell execution. After many brainstorming we decided to go with the second solution.
The interface remains the same load_profile() but greenback portals become useful from the next execution cell. A minimum "backward incompatible" price that we'll had to pay

The setup takes effect from the **next cell** after loading the profile.

For example, load the profile in the first cell:

.. code-block:: ipython

In [1]: from aiida import load_profile
...: load_profile()

Then, in a subsequent cell, you can run engine processes as usual:

.. code-block:: ipython

In [2]: from aiida.engine import calcfunction
...: from aiida import orm
...:
...: @calcfunction
...: def add(x, y):
...: return orm.Int(x.value + y.value)
...:
...: result = add(orm.Int(3), orm.Int(4))
...: print(result)

.. warning::

Attempting to run engine processes in the **same cell** where ``load_profile()`` is called will raise an error.
Always ensure the profile is loaded in a separate cell.


.. _how-to:interact-restapi:

Expand Down
4 changes: 4 additions & 0 deletions docs/source/tutorials/basic.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ If you are working on your own machine, note that the tutorial assumes that you
If this is not the case, consult the {ref}`getting started page<installation>`.
:::

:::{important}
If you are running this tutorial in a Jupyter notebook, make sure to call `load_profile()` in a **separate cell** before running any AiiDA engine processes (e.g. calculation functions or work chains).
:::

:::{tip}
This tutorial can be downloaded and run as a Jupyter Notebook: {nb-download}`basic.ipynb` {octicon}`download`
:::
Expand Down
6 changes: 3 additions & 3 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ dependencies:
- circus~=0.19.0
- click-spinner~=0.1.8
- click<8.3,>=8.1.0
- disk-objectstore~=1.4.0
- disk-objectstore~=1.5.0
- docstring_parser
- get-annotations~=0.1
- python-graphviz~=0.19
- plumpy~=0.25.0
- plumpy~=0.26.0
- ipython>=7.6
- jedi<0.19
- jinja2~=3.0
- kiwipy[rmq]~=0.8.4
- kiwipy[rmq]~=0.9.0
- importlib-metadata~=6.0
- numpy<3,>=1.21
- paramiko~=3.0
Expand Down
10 changes: 6 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ dependencies = [
'circus~=0.19.0',
'click-spinner~=0.1.8',
'click>=8.1.0,<8.3',
'disk-objectstore~=1.4.0',
'disk-objectstore~=1.5.0',
'docstring-parser',
'get-annotations~=0.1;python_version<"3.10"',
'graphviz~=0.19',
'plumpy~=0.25.0',
'plumpy~=0.26.0',
'ipython>=7.6',
'jedi<0.19',
'jinja2~=3.0',
'kiwipy[rmq]~=0.8.4',
'kiwipy[rmq]~=0.9.0',
'importlib-metadata~=6.0',
'numpy>=1.21,<3',
'paramiko~=3.0',
Expand Down Expand Up @@ -274,6 +274,9 @@ ssh_kerberos = [
tests = [
'aiida-core[atomic_tools,rest]',
'aiida-export-migration-tests==0.9.0',
'ipykernel~=6.9',
'nbclient~=0.10',
'nbformat~=5.10',
'pg8000~=1.13',
'pgtest~=1.3,>=1.3.1',
'pytest~=7.0',
Expand Down Expand Up @@ -413,7 +416,6 @@ filterwarnings = [
'ignore:The `Code` class is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning',
# https://github.com/aiidateam/plumpy/issues/283
'ignore:There is no current event loop:DeprecationWarning:plumpy',
'ignore:There is no current event loop:DeprecationWarning:nest_asyncio',
# spglib deprecation
'ignore:dict interface is deprecated:DeprecationWarning',
# https://github.com/aiidateam/archive-path/issues/21
Expand Down
4 changes: 3 additions & 1 deletion src/aiida/engine/processes/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,9 @@ async def run(self) -> 'ExitCode' | None:
# The remaining inputs have to be keyword arguments.
kwargs.update(**inputs)

result = self._func(*args, **kwargs)
from plumpy import run_with_portal

result = await run_with_portal(self._func, *args, **kwargs)

if result is None or isinstance(result, ExitCode): # type: ignore[redundant-expr]
return result # type: ignore[unreachable]
Expand Down
3 changes: 2 additions & 1 deletion src/aiida/engine/processes/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Optional, Union

import kiwipy
from plumpy import get_or_create_event_loop

from aiida.orm import Node, load_node

Expand Down Expand Up @@ -43,7 +44,7 @@ def __init__(
from .process import ProcessState

# create future in specified event loop
loop = loop if loop is not None else asyncio.get_event_loop()
loop = loop if loop is not None else get_or_create_event_loop()
super().__init__(loop=loop)

assert not (poll_interval is None and communicator is None), 'Must poll or have a communicator to use'
Expand Down
3 changes: 2 additions & 1 deletion src/aiida/engine/processes/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import plumpy.persistence
import plumpy.processes
from kiwipy.communications import UnroutableError
from plumpy import run_until_complete
from plumpy.process_states import Finished, ProcessState
from plumpy.processes import ConnectionClosed # type: ignore[attr-defined]
from plumpy.processes import Process as PlumpyProcess
Expand Down Expand Up @@ -361,7 +362,7 @@ def kill(self, msg_text: str | None = None, force_kill: bool = False) -> Union[b
coro = self._launch_task(task_kill_job, self.node, self.runner.transport)
self._cancelling_scheduler_job = asyncio.create_task(coro)
try:
self.loop.run_until_complete(self._cancelling_scheduler_job)
run_until_complete(self.loop, self._cancelling_scheduler_job)
except Exception as exc:
self.node.logger.error(f'While cancelling the scheduler job an error was raised: {exc}')
return False
Expand Down
3 changes: 2 additions & 1 deletion src/aiida/engine/processes/workchains/workchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import logging
import typing as t

from plumpy import run_with_portal
from plumpy.persistence import auto_persist
from plumpy.process_states import Continue, Wait
from plumpy.processes import ProcessStateMachineMeta
Expand Down Expand Up @@ -299,7 +300,7 @@ def _update_process_status(self) -> None:
@Protect.final
async def run(self) -> t.Any:
self._stepper = self.spec().get_outline().create_stepper(self) # type: ignore[arg-type]
return self._do_step()
return await run_with_portal(self._do_step)

def _do_step(self) -> t.Any:
"""Execute the next step in the outline and return the result.
Expand Down
10 changes: 5 additions & 5 deletions src/aiida/engine/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple, Type, Union

import kiwipy
from plumpy import run_until_complete
from plumpy.communications import wrap_communicator
from plumpy.events import reset_event_loop_policy, set_event_loop_policy
from plumpy.events import get_or_create_event_loop
from plumpy.persistence import Persister
from plumpy.process_comms import RemoteProcessThreadController

Expand Down Expand Up @@ -81,8 +82,7 @@ def __init__(
broker_submit and persister is None
), 'Must supply a persister if you want to submit using communicator'

set_event_loop_policy()
self._loop = loop if loop is not None else asyncio.get_event_loop()
self._loop = loop if loop else get_or_create_event_loop()
self._poll_interval = poll_interval
self._broker_submit = broker_submit
self._transport = transports.TransportQueue(self._loop)
Expand Down Expand Up @@ -156,16 +156,16 @@ def stop(self) -> None:

def run_until_complete(self, future: asyncio.Future) -> Any:
"""Run the loop until the future has finished and return the result."""

with utils.loop_scope(self._loop):
return self._loop.run_until_complete(future)
return run_until_complete(self._loop, future)

def close(self) -> None:
"""Close the runner by stopping the loop."""
assert not self._closed
self.stop()
if not self._loop.is_running():
self._loop.close()
reset_event_loop_policy()
self._closed = True

def instantiate_process(self, process: TYPE_RUN_PROCESS, **inputs):
Expand Down
15 changes: 13 additions & 2 deletions src/aiida/engine/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import traceback
from typing import TYPE_CHECKING, AsyncIterator, Awaitable, Dict, Hashable, Optional

from plumpy import get_or_create_event_loop

from aiida.orm import AuthInfo

if TYPE_CHECKING:
Expand Down Expand Up @@ -44,8 +46,8 @@ class TransportQueue:
"""

def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None):
""":param loop: An asyncio event, will use `asyncio.get_event_loop()` if not supplied"""
self._loop = loop if loop is not None else asyncio.get_event_loop()
""":param loop: An asyncio event, will use `get_or_create_event_loop()` if not supplied"""
self._loop = loop if loop else get_or_create_event_loop()
self._transport_requests: Dict[Hashable, TransportRequest] = {}

@property
Expand All @@ -67,6 +69,15 @@ async def transport_task(transport_queue, authinfo):
:param authinfo: The authinfo to be used to get transport
:return: A future that can be yielded to give the transport
"""

from plumpy import ensure_portal

# NOTE: We need to ensure the portal here only because
# our scheduler has only a sync interface and _get_jobs_from_scheduler is using that
# if we ever provide a fully async scheduler interface then we can remove this here
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully understand. We need to call ensure_portal, when we switch to a different task than the task in execute (because this one has an open portal), and when we require a nested sync->async call. So in scheduler we have such a case? But why do we need to open it in transport then? Aren't the two classes decoupled?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So when a CalcJob polls the scheduler, the update gets scheduled as a new asyncio task (via this call_later). And this new task doesn't inherit the portal that the original process had from execute(). The main problem is that scheduler.get_jobs() is sync, and internally it calls transport.exec_command_wait(). For which it uses run_until_complete().

Once scheduler interface is async, as well we can get rid of the ensure portal here

# An issue is opened to reference this https://github.com/aiidateam/aiida-core/issues/7222
await ensure_portal()

open_callback_handle = None
transport_request = self._transport_requests.get(authinfo.pk, None)

Expand Down
8 changes: 5 additions & 3 deletions src/aiida/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from datetime import datetime
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterator, List, Optional, Tuple, Type, Union

from plumpy import get_or_create_event_loop

if TYPE_CHECKING:
from aiida.orm import ProcessNode

Expand Down Expand Up @@ -125,10 +127,10 @@ def interruptable_task(
"""Turn the given coroutine into an interruptable task by turning it into an InterruptableFuture and returning it.

:param coro: the coroutine that should be made interruptable with object of InterutableFuture as last paramenter
:param loop: the event loop in which to run the coroutine, by default uses asyncio.get_event_loop()
:param loop: the event loop in which to run the coroutine, by default uses get_or_create_event_loop()
:return: an InterruptableFuture
"""
loop = loop or asyncio.get_event_loop()
loop = loop or get_or_create_event_loop()
future = InterruptableFuture()

async def execute_coroutine():
Expand Down Expand Up @@ -252,7 +254,7 @@ def loop_scope(loop) -> Iterator[None]:

:param loop: The event loop to make current for the duration of the scope
"""
current = asyncio.get_event_loop()
current = get_or_create_event_loop()

try:
asyncio.set_event_loop(loop)
Expand Down
49 changes: 49 additions & 0 deletions src/aiida/manage/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,58 @@ def load_profile(self, profile: Union[None, str, 'Profile'] = None, allow_switch
# Check whether a development version is being run. Note that needs to be called after ``configure_logging``
# because this function relies on the logging being properly configured for the warning to show.
self.check_version()
self._setup_event_loop_in_ipython()

return self._profile

def _setup_event_loop_in_ipython(self) -> None:
"""Monkey-patch ``IPythonKernel.do_execute`` to ensure a portal.

When running inside an environment with an already-running event loop
(e.g. a Jupyter notebook kernel), this patches the kernel's
``do_execute`` to open a portal before executing each cell.
The portal is opended on whichever asyncio task ipykernel uses for that cell.

The patch takes effect from the **next cell**.
``load_profile()`` must therefore be called in a separate cell before
synchronous process execution.

This is a no-op if no event loop is running (scripts, CLI, daemon).
"""
import asyncio

try:
asyncio.get_running_loop()
except RuntimeError:
return # No running loop — not in Jupyter, nothing to do

self._patch_kernel_do_execute()

def _patch_kernel_do_execute(self) -> None:
"""Patch ``IPythonKernel.do_execute`` to ensure a portal before each cell."""
try:
from ipykernel.ipkernel import IPythonKernel

if getattr(IPythonKernel, '_aiida_portal_patched', False):
return # Already patched

from plumpy import ensure_portal

_orig_do_execute = IPythonKernel.do_execute

async def _patched_do_execute(self, code, silent, *args, **kwargs):
await ensure_portal()
return await _orig_do_execute(self, code, silent, *args, **kwargs)

IPythonKernel.do_execute = _patched_do_execute # type: ignore[method-assign]
IPythonKernel._aiida_portal_patched = True # type: ignore[attr-defined]
self.logger.debug(
'Patched IPythonKernel.do_execute for portal. '
'This should occur in a Jupyter kernel, and only once per kernel session.'
)
except Exception:
self.logger.debug('Could not patch IPythonKernel for portal.', exc_info=True)

def reset_profile(self) -> None:
"""Close and reset any associated resources for the current profile."""
self.reset_broker()
Expand Down
2 changes: 1 addition & 1 deletion src/aiida/manage/tests/pytest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def clear_database_before_test_class(aiida_profile):
@pytest.fixture(scope='function')
def temporary_event_loop():
"""Create a temporary loop for independent test case"""
current = asyncio.get_event_loop()
current = plumpy.get_or_create_event_loop()
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
Expand Down
6 changes: 3 additions & 3 deletions src/aiida/repository/backend/disk_object_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def open(self, key: str) -> t.Iterator[t.BinaryIO]:
yield t.cast(t.BinaryIO, handle)

def iter_object_streams(self, keys: t.Iterable[str]) -> t.Iterator[t.Tuple[str, t.BinaryIO]]:
with self._container.get_objects_stream_and_meta(keys) as triplets: # type: ignore[arg-type]
with self._container.get_objects_stream_and_meta(keys) as triplets:
for key, stream, _ in triplets:
assert stream is not None
yield key, stream # type: ignore[misc]
Expand Down Expand Up @@ -199,7 +199,7 @@ def maintain(
if not dry_run:
with get_progress_reporter()(total=1) as progress:
callback = create_callback(progress)
container.pack_all_loose(compress=compress_mode, callback=callback) # type: ignore[arg-type]
container.pack_all_loose(compress=compress_mode, callback=callback)

if do_repack:
files_numb = container.count_objects().packed
Expand All @@ -208,7 +208,7 @@ def maintain(
if not dry_run:
with get_progress_reporter()(total=1) as progress:
callback = create_callback(progress)
container.repack(callback=callback) # type: ignore[arg-type]
container.repack(callback=callback)

if clean_storage:
logger.report(f'Cleaning the repository database (with `vacuum={do_vacuum}`) ...')
Expand Down
Loading
Loading