diff --git a/alchemiscale/interface/api.py b/alchemiscale/interface/api.py index 341069e9..d5edf34f 100644 --- a/alchemiscale/interface/api.py +++ b/alchemiscale/interface/api.py @@ -31,7 +31,7 @@ from ..settings import get_base_api_settings from ..storage.statestore import Neo4jStore from ..storage.objectstore import S3ObjectStore -from ..storage.models import TaskStatusEnum, StrategyState +from ..storage.models import NetworkStateEnum, TaskStatusEnum, StrategyState from ..models import Scope, ScopedKey from ..security.models import TokenData, CredentialedUserIdentity @@ -138,6 +138,55 @@ async def create_network( return an_sk +@router.post("/networks/merge", response_model=ScopedKey) +async def merge_networks( + *, + networks: list[str] = Body(embed=True), + name: str = Body(embed=True), + scope: dict = Body(embed=True), + state: str = Body(embed=True, default=NetworkStateEnum.active.value), + n4js: Neo4jStore = Depends(get_n4js_depends), + token: TokenData = Depends(get_token_data_depends), +): + # validate the destination scope first + try: + target_scope = Scope(**scope) + except (TypeError, ValidationError) as e: + raise HTTPException( + status_code=http_status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=str(e), + ) + validate_scopes(target_scope, token) + + # validate each source network's scope is accessible to the token + network_sks = [] + for network in networks: + try: + network_sk = ScopedKey.from_str(network) + except ValueError as e: + raise HTTPException( + status_code=http_status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=e.args[0], + ) + validate_scopes(network_sk.scope, token) + network_sks.append(network_sk) + + try: + an_sk = n4js.merge_networks( + network_scoped_keys=network_sks, + name=name, + scope=target_scope, + state=state, + ) + except ValueError as e: + raise HTTPException( + status_code=http_status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=e.args[0], + ) + + return an_sk + + @router.post("/bulk/networks/state/set") def set_networks_state( *, diff --git a/alchemiscale/interface/client.py b/alchemiscale/interface/client.py index f53859ba..c0a0a0a3 100644 --- a/alchemiscale/interface/client.py +++ b/alchemiscale/interface/client.py @@ -187,6 +187,103 @@ def post(): return ScopedKey.from_dict(scoped_key) + def merge_networks( + self, + networks: list[ScopedKey], + name: str, + scope: Scope, + state: NetworkStateEnum | str = NetworkStateEnum.active, + visualize: bool = True, + ) -> ScopedKey: + """Merge multiple existing AlchemicalNetworks into a new AlchemicalNetwork. + + The resulting AlchemicalNetwork contains the union of all + Transformations and NonTransformations from the source networks. + Existing Tasks for those transformations that are in ``complete`` or + ``error`` state are cloned into the new network's scope along with + their associated ProtocolDAGResultRefs, so previously-computed results + do not need to be re-run. + + Cloned Tasks are wired to their Transformations via ``PERFORMS`` and + are reachable through standard network traversals + (``get_network_tasks``, ``get_network_results``, etc.). They are + intentionally **not** actioned to the new network's TaskHub; to + retry errored Tasks on the merged network, call + :meth:`action_tasks` with the merged network's ScopedKey after the + merge completes. + + Parameters + ---------- + networks + The ScopedKeys of the AlchemicalNetworks to merge. The source + networks may live in different Scopes; the caller must have access + to each. + name + The name of the new AlchemicalNetwork. + scope + The Scope in which to create the new AlchemicalNetwork. + This must be a *specific* Scope; it must not contain wildcards. + state + The starting state of the new AlchemicalNetwork in the database. + See :meth:`AlchemiscaleClient.set_network_state` for valid states. + Defaults to ``"active"``. + visualize + If ``True``, show submission progress indicator. + + Returns + ------- + ScopedKey + The ScopedKey of the new, merged AlchemicalNetwork. + """ + if not scope.specific(): + raise ValueError( + f"`scope` '{scope}' contains wildcards ('*'); `scope` must be *specific*" + ) + + if not networks: + raise ValueError("`networks` must contain at least one ScopedKey") + + network_sks = [ + sk if isinstance(sk, ScopedKey) else ScopedKey.from_str(sk) + for sk in networks + ] + + for network_sk in network_sks: + if network_sk.qualname not in ("AlchemicalNetwork",): + raise ValueError( + f"ScopedKey '{network_sk}' does not refer to an AlchemicalNetwork" + ) + + state = NetworkStateEnum(state) + + data = dict( + networks=[str(sk) for sk in network_sks], + name=name, + scope=scope.to_dict(), + state=state.value, + ) + + def post(): + return self._post_resource("/networks/merge", data) + + if visualize: + from rich.progress import Progress + + with Progress(*self._rich_waiting_columns(), transient=False) as progress: + task = progress.add_task( + f"Merging [bold]{len(network_sks)}[/bold] networks into " + f"[bold]'{name}'[/bold] in scope [bold]'{scope}'[/bold]...", + total=None, + ) + + scoped_key = post() + progress.start_task(task) + progress.update(task, total=1, completed=1) + else: + scoped_key = post() + + return ScopedKey.from_dict(scoped_key) + def set_network_state( self, network: ScopedKey, state: NetworkStateEnum | str ) -> ScopedKey | None: diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 343bdcde..5f94f83f 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -7,6 +7,7 @@ import abc import bisect import datetime +from dataclasses import dataclass, field from contextlib import contextmanager import json import re @@ -24,7 +25,12 @@ Protocol, ) from gufe.settings import SettingsBaseModel -from gufe.tokenization import GufeTokenizable, GufeKey, JSON_HANDLER, KeyedChain +from gufe.tokenization import ( + GufeTokenizable, + GufeKey, + JSON_HANDLER, + KeyedChain, +) from gufe.protocols import ProtocolUnitFailure from neo4j import Transaction, GraphDatabase, Driver, NotificationDisabledClassification @@ -78,6 +84,152 @@ def get_n4js(settings: Neo4jStoreSettings): class Neo4JStoreError(Exception): ... +def _is_transformation_keyed_dict(keyed_dict: dict) -> bool: + """Return ``True`` if ``keyed_dict`` represents a ``Transformation`` or ``NonTransformation``. + + Used as the predicate for ``KeyedChain.decode_subchains`` when walking + an ``AlchemicalNetwork``'s keyed chain inside :meth:`Neo4jStore.merge_networks`. + """ + return keyed_dict.get("__qualname__") in ("Transformation", "NonTransformation") + + +@dataclass +class _TransformationData: + """Bookkeeping for one Transformation as it is reconstructed during + :meth:`Neo4jStore.merge_networks`. + + Attributes + ---------- + transformation + The decoded ``Transformation`` (or ``NonTransformation``) + ``GufeTokenizable``. + task_tree + Flat list of Neo4j records, one per ``Task`` associated with this + ``Transformation`` in any of the source networks. Each record carries: + + - ``tf_key``: this ``Transformation``'s decoded gufe key (string) + - ``task``: the full ``Task`` node + - ``extended_task``: optional ``Task`` node that this ``Task`` extends + - ``pdrrs``: list of ``ProtocolDAGResultRef`` nodes for this ``Task`` + known_scoped_keys + All ``ScopedKey``\\ s representing this ``Transformation`` across the + source networks. Multiple ``ScopedKey``\\ s can map to a single + decoded ``Transformation`` if the same content was serialized under + different gufe versions. + """ + + transformation: Transformation + task_tree: list = field(default_factory=list) + known_scoped_keys: list = field(default_factory=list) + + def add_known_scoped_key(self, key, scope): + self.known_scoped_keys.append( + ScopedKey(gufe_key=GufeKey(key), **scope.to_dict()) + ) + + @staticmethod + def update_task_trees(transformation_data: list, statestore): + """Given a list of ``_TransformationData``, extract all necessary + info from Neo4j and load the task trees. + """ + key_to_data_map = {str(td.transformation.key): td for td in transformation_data} + # prepare for unwind clause, include transformation key + # for updating each entry of transformation_data for + # each scoped key + transformation_sk_pairs = [ + [str(td.transformation.key), str(sk)] + for td in transformation_data + for sk in td.known_scoped_keys + ] + query = """ + UNWIND $tf_sk_pairs as pairs + WITH pairs[0] AS tf_key, pairs[1] AS tf_scoped_key + MATCH (task:Task)-[:PERFORMS]->(:Transformation|NonTransformation {`_scoped_key`: tf_scoped_key}) + WHERE task.status IN ["complete", "error"] + OPTIONAL MATCH (task)-[:EXTENDS]->(extended_task:Task) + OPTIONAL MATCH (task)-[:RESULTS_IN]->(pdrr:ProtocolDAGResultRef) + RETURN tf_key, task, extended_task as extended_task, collect(pdrr) as pdrrs + """ + results = statestore.execute_query( + query, tf_sk_pairs=transformation_sk_pairs + ).records + + for record in results: + key_to_data_map[record["tf_key"]].task_tree.append(record) + + def to_subgraph(self, target_scope, statestore, subchain_cache): + """Create a subgraph anchored at the Transformation node and + iteratively add Task and PDRR nodes with their corresponding + relationships. + + Each cloned Task is wired back to the Transformation via a + ``PERFORMS`` edge so the standard + ``(:AlchemicalNetwork)-[:DEPENDS_ON]->(:Transformation)<-[:PERFORMS]-(:Task)`` + traversal keeps working on the merged network. Note that cloned + Tasks are intentionally **not** actioned to the merged network's + TaskHub; see :meth:`Neo4jStore.merge_networks` for the rationale. + """ + # if there are no tasks, return an empty subgraph; the + # Transformation node already exists in the surrounding + # merged-network subgraph + if not self.task_tree: + return Subgraph() + + # build the Transformation node as the PERFORMS anchor for cloned + # Tasks. ``subchain_cache`` holds the re-serialized chain for the + # decoded ``Transformation``; the resulting Node's ``_scoped_key`` + # matches the one already produced by the combined + # AlchemicalNetwork's keyed chain and will dedupe against it + # during ``merge_subgraph``. + _, tf_node, _ = statestore._keyed_chain_to_subgraph( + subchain_cache[self.transformation], target_scope + ) + subgraph = Subgraph() | tf_node + + scope_props = { + "_org": target_scope.org, + "_campaign": target_scope.campaign, + "_project": target_scope.project, + } + + def record_to_node(record): + # create node from a neo4j record with updated scoped key + scoped_key = ScopedKey( + gufe_key=record["_gufe_key"], **target_scope.to_dict() + ) + return Node( + *record.labels, + **record._properties | scope_props | {"_scoped_key": str(scoped_key)}, + ) + + # process each task found. Each record represents a single task. + for record in self.task_tree: + # update task node to have new scoped key + task_node = record_to_node(record["task"]) + # wire the cloned Task back to its Transformation; without + # this edge the Task is unreachable from get_network_tasks + # and every other PERFORMS-based traversal + subgraph |= Relationship.type("PERFORMS")(task_node, tf_node, **scope_props) + # create the task node this task extends if it exists + etask_node = ( + None + if not record["extended_task"] + else record_to_node(record["extended_task"]) + ) + if etask_node: + subgraph |= Relationship.type("EXTENDS")( + etask_node, task_node, **scope_props + ) + # clone all result refs for the task + for pdrr_record in record["pdrrs"]: + pdrr_node = record_to_node(pdrr_record) + subgraph |= Relationship.type("RESULTS_IN")( + task_node, pdrr_node, **scope_props + ) + + return subgraph + + class AlchemiscaleStateStore(abc.ABC): ... @@ -876,6 +1028,118 @@ def delete_network( """ raise NotImplementedError + def merge_networks( + self, + network_scoped_keys: list[ScopedKey], + name: str, + scope: Scope, + state: NetworkStateEnum | str = NetworkStateEnum.active, + ) -> ScopedKey: + """Merge multiple ``AlchemicalNetwork`` nodes into a new ``AlchemicalNetwork``. + + Each ``Transformation`` / ``NonTransformation`` in the input + networks is included exactly once in the new network. ``Task``\\ s + on the source networks that are in ``complete`` or ``error`` + state are cloned into the new network's ``Scope`` along with + their ``ProtocolDAGResultRef``\\ s and ``EXTENDS`` relationships, + and are wired to their ``Transformation`` via ``PERFORMS`` so + they are reachable from the standard network traversals. + + Cloned ``Task``\\ s are intentionally **not** actioned to the new + network's ``TaskHub``. Users wanting to retry errored tasks on + the merged network should call :meth:`action_tasks` themselves + with the merged network's ``TaskHub`` ``ScopedKey``. + + Parameters + ---------- + network_scoped_keys + List of ``AlchemicalNetwork`` ``ScopedKey`` objects to merge. + name + The name of the new ``AlchemicalNetwork``. + scope + The ``Scope`` of the new ``AlchemicalNetwork``. + state + The starting state for the new ``AlchemicalNetwork``'s + ``NetworkMark``. Defaults to ``NetworkStateEnum.active``. + + Returns + ------- + The ``ScopedKey`` of the new ``AlchemicalNetwork`` in the database. + """ + # Collect keyed chain representation for all alchemical networks, + # gathering every missing ScopedKey up front so callers passing + # many SKs get a single, complete error. + network_keyed_chains: list[tuple[Scope, KeyedChain]] = [] + missing: list[ScopedKey] = [] + for network_scoped_key in network_scoped_keys: + try: + keyed_chain = self.get_keyed_chain(network_scoped_key) + except KeyError: + missing.append(network_scoped_key) + continue + network_keyed_chains.append((network_scoped_key.scope, keyed_chain)) + if missing: + joined = ", ".join(str(sk) for sk in missing) + raise ValueError( + f"The following ScopedKey(s) were not found in the database: {joined}" + ) + + # Map decoded Transformation / NonTransformation objects to all of + # their original database GufeKeys (potentially across multiple + # source networks, and including duplicates introduced by minor + # serialization-version drift). These original keys are needed to + # locate Tasks associated with each Transformation in the source + # networks. We key the dedup map on ``transformation.key`` (the + # decoded GufeKey) to avoid O(N^2) GufeTokenizable equality checks + # for large networks. + transformation_data: dict[GufeKey, _TransformationData] = {} + subchain_cache: dict[GufeTokenizable, KeyedChain] = {} + for network_scope, network_keyed_chain in network_keyed_chains: + # database keys for Transformations / NonTransformations in this + # source network's chain, in chain order + database_keys = [ + gufe_key + for gufe_key, keyed_dict in network_keyed_chain + if _is_transformation_keyed_dict(keyed_dict) + ] + # decoded Transformation / NonTransformation objects, yielded + # by decode_subchains in the same chain order as the predicate + # selects above (gufe contract); decode_subchains also shares + # a tokenizable_map across yields so common dependencies are + # decoded only once per source network. + transformations = network_keyed_chain.decode_subchains( + _is_transformation_keyed_dict + ) + for database_key, transformation in zip(database_keys, transformations): + subchain_cache[transformation] = KeyedChain.from_gufe(transformation) + data = transformation_data.get(transformation.key) + if data is None: + data = _TransformationData(transformation) + transformation_data[transformation.key] = data + data.add_known_scoped_key(database_key, network_scope) + + # Collect all transformation gufe objects and collect into a new set of edges + _TransformationData.update_task_trees(list(transformation_data.values()), self) + new_edges = [td.transformation for td in transformation_data.values()] + # Make new alchemiscale network with these edges + combined_alchemical_network = AlchemicalNetwork(edges=new_edges, name=name) + an_subgraph, an_node, an_sk = self._keyed_chain_to_subgraph( + KeyedChain.from_gufe(combined_alchemical_network), scope + ) + # create and fold in taskhub and network mark supporting nodes + an_subgraph |= ( + self.create_network_mark_subgraph(an_node, state=state)[0] + | self.create_taskhub_subgraph(an_node)[0] + ) + # create and fold in all task and results data + for td in transformation_data.values(): + an_subgraph |= td.to_subgraph(scope, self, subchain_cache) + + # merge the new network into neo4j + with self.transaction() as tx: + merge_subgraph(tx, an_subgraph, "GufeTokenizable", "_scoped_key") + return an_sk + def get_network_state(self, networks: list[ScopedKey]) -> list[str | None]: """Get the states of a group of networks. @@ -2853,7 +3117,10 @@ def task_count(task_dict: dict): ## tasks - def _validate_extends_tasks(self, task_list) -> dict[str, tuple[Node, str]]: + def _validate_extends_tasks( + self, + task_list, + ) -> dict[str, tuple[Node, str]]: if not task_list: return {} @@ -2943,7 +3210,7 @@ def create_tasks( transformation_map[transformation.qualname][1].append(extends[i]) extends_nodes = self._validate_extends_tasks( - [_extends for _extends in extends if _extends is not None] + [_extends for _extends in extends if _extends is not None], ) subgraph = Subgraph() diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py index 390acb71..1fa56b1a 100644 --- a/alchemiscale/tests/integration/interface/client/test_client.py +++ b/alchemiscale/tests/integration/interface/client/test_client.py @@ -2,6 +2,7 @@ import datetime from time import sleep import os +import uuid from pathlib import Path from itertools import chain import json @@ -199,6 +200,230 @@ def test_create_network( # common with an existing network # user_client.create_network( + def test_merge_networks( + self, + scope_test, + multiple_scopes, + n4js_preloaded, + user_client: client.AlchemiscaleClient, + network_tyk2, + ): + # gather source AlchemicalNetwork ScopedKeys across all scopes + # n4js_preloaded creates `network_tyk2` and a trimmed copy named + # "incomplete" in each of `multiple_scopes` + source_sks = user_client.query_networks(state=None) + + # destination scope: a new project under the existing org/campaign + merge_scope = Scope( + org=scope_test.org, + campaign=scope_test.campaign, + project="merged_project", + ) + + merged_sk = user_client.merge_networks( + networks=source_sks, + name="merged_tyk2", + scope=merge_scope, + visualize=False, + ) + + assert isinstance(merged_sk, ScopedKey) + assert merged_sk.scope == merge_scope + assert merged_sk.qualname == "AlchemicalNetwork" + + # the merged network should exist + assert user_client.check_exists(merged_sk) + + # the merged network should appear in queries (defaults to active state) + all_active_sks = user_client.query_networks() + assert merged_sk in all_active_sks + + # the merged network should contain the union of all source edges; + # since `network_tyk2` is a superset of `incomplete`, the union equals + # `network_tyk2.edges` + merged_network = user_client.get_network(merged_sk) + assert merged_network.name == "merged_tyk2" + assert len(merged_network.edges) == len(network_tyk2.edges) + assert {t.key for t in merged_network.edges} == { + t.key for t in network_tyk2.edges + } + + @pytest.mark.parametrize("state", ["active", "inactive"]) + def test_merge_networks_respects_state( + self, + state, + scope_test, + n4js_preloaded, + user_client: client.AlchemiscaleClient, + ): + """The state parameter must control the NetworkMark on the merged + AlchemicalNetwork the same way it does on create_network.""" + source_sks = user_client.query_networks(scope=scope_test, state=None) + assert source_sks + + merge_scope = Scope( + org=scope_test.org, + campaign=scope_test.campaign, + project=f"merged_state_{state}", + ) + merged_sk = user_client.merge_networks( + networks=source_sks, + name=f"merged_state_{state}", + scope=merge_scope, + state=state, + visualize=False, + ) + + assert user_client.get_network_state(merged_sk) == state + + def test_merge_networks_rejects_wildcard_scope( + self, + n4js_preloaded, + user_client: client.AlchemiscaleClient, + ): + source_sks = user_client.query_networks(state=None) + with pytest.raises(ValueError, match="wildcards"): + user_client.merge_networks( + networks=source_sks, + name="should_fail", + scope=Scope(org="test_org"), + visualize=False, + ) + + def test_merge_networks_rejects_empty_list( + self, + scope_test, + n4js_preloaded, + user_client: client.AlchemiscaleClient, + ): + with pytest.raises(ValueError, match="at least one"): + user_client.merge_networks( + networks=[], + name="should_fail", + scope=scope_test, + visualize=False, + ) + + def test_merge_networks_rejects_non_network_scoped_key( + self, + scope_test, + n4js_preloaded, + user_client: client.AlchemiscaleClient, + ): + # pass a Transformation ScopedKey rather than an AlchemicalNetwork one + tf_sks = user_client.query_transformations(scope=scope_test) + assert tf_sks + with pytest.raises(ValueError, match="does not refer to an AlchemicalNetwork"): + user_client.merge_networks( + networks=[tf_sks[0]], + name="should_fail", + scope=scope_test, + visualize=False, + ) + + def test_merge_networks_preserves_tasks_and_results( + self, + scope_test, + n4js_preloaded, + user_client: client.AlchemiscaleClient, + network_tyk2, + ): + """The merged network must carry over Tasks in complete/error state + along with their ProtocolDAGResultRefs, cloned into the new scope.""" + # pick a source network in scope_test + source_sks = user_client.query_networks(scope=scope_test, state=None) + assert source_sks + source_sk = source_sks[0] + + # create Tasks on two of its Transformations directly through n4js so + # we can drive them to completed/errored states with PDRRs without + # actually executing protocols + transformation_sks = n4js_preloaded.get_network_transformations(source_sk) + assert len(transformation_sks) >= 2 + + task_sks = n4js_preloaded.create_tasks(transformation_sks[:2]) + + # task 0: complete, ok result + n4js_preloaded.set_task_running(task_sks[:1]) + n4js_preloaded.set_task_complete(task_sks[:1]) + ok_pdrr = ProtocolDAGResultRef( + obj_key=f"ProtocolDAGResult-{uuid.uuid4()}", + scope=task_sks[0].scope, + ok=True, + ) + n4js_preloaded.set_task_result(task_sks[0], ok_pdrr) + + # task 1: error, failure result + n4js_preloaded.set_task_running(task_sks[1:2]) + n4js_preloaded.set_task_error(task_sks[1:2]) + err_pdrr = ProtocolDAGResultRef( + obj_key=f"ProtocolDAGResult-{uuid.uuid4()}", + scope=task_sks[1].scope, + ok=False, + ) + n4js_preloaded.set_task_result(task_sks[1], err_pdrr) + + # merge into a fresh project under the same org/campaign + merge_scope = Scope( + org=scope_test.org, + campaign=scope_test.campaign, + project="merged_with_results", + ) + merged_sk = user_client.merge_networks( + networks=[source_sk], + name="merged_with_results", + scope=merge_scope, + visualize=False, + ) + assert user_client.check_exists(merged_sk) + + # both Tasks should appear in the new scope, with their original statuses + task_records = n4js_preloaded.execute_query( + """ + MATCH (t:Task {`_project`: $project}) + RETURN t.status AS status + """, + project=merge_scope.project, + ).records + assert sorted(r["status"] for r in task_records) == ["complete", "error"] + + # one ok and one not-ok PDRR should appear in the new scope, with their + # original object keys preserved + pdrr_records = n4js_preloaded.execute_query( + """ + MATCH (pdrr:ProtocolDAGResultRef {`_project`: $project}) + RETURN pdrr.ok AS ok, pdrr.obj_key AS obj_key + """, + project=merge_scope.project, + ).records + assert len(pdrr_records) == 2 + oks = sorted(r["ok"] for r in pdrr_records) + assert oks == [False, True] + obj_keys = {r["obj_key"] for r in pdrr_records} + assert obj_keys == {ok_pdrr.obj_key, err_pdrr.obj_key} + + # cloned PDRRs must be wired to the cloned Tasks in the new scope + linked = n4js_preloaded.execute_query( + """ + MATCH (t:Task {`_project`: $project})-[:RESULTS_IN]-> + (pdrr:ProtocolDAGResultRef {`_project`: $project}) + RETURN t.status AS status, pdrr.ok AS ok + """, + project=merge_scope.project, + ).records + pairs = sorted((r["status"], r["ok"]) for r in linked) + assert pairs == [("complete", True), ("error", False)] + + # the cloned Tasks must be reachable from the merged AlchemicalNetwork + # via the standard PERFORMS traversal that the user-facing API uses; + # this catches any case where Tasks are written to the new scope but + # not wired back to their Transformations in the merged network + merged_task_sks = user_client.get_network_tasks(merged_sk) + assert len(merged_task_sks) == 2 + assert all(sk.scope == merge_scope for sk in merged_task_sks) + statuses = sorted(user_client.get_tasks_status(merged_task_sks)) + assert statuses == ["complete", "error"] + def test_check_exists( self, scope_test, diff --git a/alchemiscale/tests/integration/interface/test_api.py b/alchemiscale/tests/integration/interface/test_api.py index 21322c02..f1427147 100644 --- a/alchemiscale/tests/integration/interface/test_api.py +++ b/alchemiscale/tests/integration/interface/test_api.py @@ -103,6 +103,97 @@ def test_create_network_bad_scope( assert str(bad_scope) in details assert str(scope_test) in details + def test_merge_networks( + self, n4js_preloaded, test_client, network_tyk2, scope_test + ): + n4js = n4js_preloaded + + # source networks in scope_test (pre-loaded by n4js_preloaded) + source_sks = n4js.query_networks(scope=scope_test) + assert len(source_sks) >= 2 + + # destination scope: a new project under the same org/campaign so the + # test_client's scope_test token has access + merge_scope_dict = { + "org": scope_test.org, + "campaign": scope_test.campaign, + "project": scope_test.project, + } + + headers = {"Content-type": "application/json"} + data = dict( + networks=[str(sk) for sk in source_sks], + name="api_merged", + scope=merge_scope_dict, + ) + jsondata = json.dumps(data, cls=JSON_HANDLER.encoder) + + response = test_client.post("/networks/merge", data=jsondata, headers=headers) + assert response.status_code == 200 + + merged_sk = ScopedKey(**response.json()) + assert merged_sk.scope == scope_test + assert merged_sk.gufe_key.startswith("AlchemicalNetwork-") + + # network should now be present in the database + assert n4js.check_existence(merged_sk) + + # merged network's union-of-edges equals network_tyk2's edge set, + # since one of the sources is network_tyk2 and the other is a strict subset + merged_network = n4js.get_gufe(merged_sk) + assert merged_network.name == "api_merged" + assert {t.key for t in merged_network.edges} == { + t.key for t in network_tyk2.edges + } + + def test_merge_networks_bad_scope( + self, n4js_preloaded, test_client, scope_test, multiple_scopes + ): + # destination scope the test_client's token does not have access to + bad_scope = multiple_scopes[1] + assert bad_scope != scope_test + + source_sks = n4js_preloaded.query_networks(scope=scope_test) + assert source_sks + + headers = {"Content-type": "application/json"} + data = dict( + networks=[str(sk) for sk in source_sks], + name="should_fail", + scope=bad_scope.to_dict(), + ) + jsondata = json.dumps(data, cls=JSON_HANDLER.encoder) + + response = test_client.post("/networks/merge", data=jsondata, headers=headers) + assert response.status_code == 401 + details = response.json() + assert "detail" in details + assert str(bad_scope) in details["detail"] + + def test_merge_networks_bad_source_scope( + self, n4js_preloaded, test_client, scope_test, multiple_scopes + ): + # source network in a scope the test_client's token does not authorize + unauth_scope = multiple_scopes[1] + assert unauth_scope != scope_test + + unauth_source_sks = n4js_preloaded.query_networks(scope=unauth_scope) + assert unauth_source_sks + + headers = {"Content-type": "application/json"} + data = dict( + networks=[str(sk) for sk in unauth_source_sks], + name="should_fail", + scope=scope_test.to_dict(), + ) + jsondata = json.dumps(data, cls=JSON_HANDLER.encoder) + + response = test_client.post("/networks/merge", data=jsondata, headers=headers) + assert response.status_code == 401 + details = response.json() + assert "detail" in details + assert str(unauth_scope) in details["detail"] + def test_get_network(self, prepared_network, test_client): network, scoped_key = prepared_network response = test_client.get(f"/networks/{scoped_key}") diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index fa35617b..4af8addd 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -138,6 +138,149 @@ def test_create_overlapping_networks(self, n4js, network_tyk2, scope_test): def test_delete_network(self): raise NotImplementedError + def test_merge_networks(self, n4js, network_tyk2, scope_test): + network_sks = [] + + scope_args = scope_test.to_dict() + all_transformations = list(network_tyk2.edges) + all_transformations.sort(key=lambda x: x.key) + + def project_scope(iteration): + return Scope(**(scope_args | {"project": f"project{iteration}"})) + + transformations_common = all_transformations[:3] + + # note the nonlocal task_sks and transformation_sks to reduce + # clutter down below + def set_result(_slice, *, ok: bool): + nonlocal task_sks, transformation_sks + for task_sk in task_sks[_slice]: + pdrr = ProtocolDAGResultRef( + obj_key=f"ProtocolDAGResult-{uuid.uuid4()}", + scope=task_sk.scope, + ok=ok, + ) + n4js.set_task_result(task_sk, pdrr) + + # NETWORK 1---JUST TYK2 + # 5 tasks: [completed, completed, running, error, waiting] + # no extends + an = network_tyk2.copy_with_replacements( + name=network_tyk2.name + f"_test_network_clone_1" + ) + scope = project_scope(1) + sk, th_sk, _ = n4js.assemble_network(an, scope) + network_sks.append(sk) + # add two more of the transformations + transformation_sks = [ + n4js.get_scoped_key(transformation, scope) + for transformation in transformations_common + all_transformations[3:5] + ] + task_sks = n4js.create_tasks(transformation_sks) + n4js.set_task_running(task_sks[2:3]) + n4js.action_tasks(task_sks[2:3], th_sk) + + n4js.set_task_waiting(task_sks[5:]) + n4js.action_tasks(task_sks[5:], th_sk) + + n4js.set_task_running(task_sks[4:5]) + n4js.set_task_error(task_sks[4:5]) + n4js.action_tasks(task_sks[4:5], th_sk) + + n4js.set_task_running(task_sks[:2]) + n4js.set_task_complete(task_sks[:2]) + + set_result(slice(None, 2), ok=True) + set_result(slice(4, 5), ok=False) + + # NETWORK 2---TYK2 WITH TRIMMED EDGES + # 6 tasks: [complete, complete, complete] + # ^ ^ ^ + # | | | + # extends: [waiting, running, complete] + an = network_tyk2.copy_with_replacements( + name=network_tyk2.name + f"_test_network_clone_2_fewer_transformations", + edges=all_transformations[:-2], + ) + scope = project_scope(2) + sk, _, _ = n4js.assemble_network(an, scope) + network_sks.append(sk) + transformation_sks = [ + n4js.get_scoped_key(transformation, scope) + for transformation in transformations_common + ] + task_sks = n4js.create_tasks(transformation_sks) + n4js.set_task_running(task_sks) + n4js.set_task_complete(task_sks) + set_result(slice(None), ok=True) + + # create extending tasks + task_sks = n4js.create_tasks(transformation_sks, extends=task_sks) + + n4js.set_task_waiting(task_sks[0:1]) + n4js.set_task_running(task_sks[1:2]) + n4js.set_task_running(task_sks[2:3]) + n4js.set_task_complete(task_sks[2:3]) + + set_result(slice(2, 3), ok=True) + + # network 3---tyk2 + # 6 tasks: [waiting, waiting, waiting, waiting, waiting, complete] + # no extends + # last task is for a transformation missing from NETWORK 2 + an = network_tyk2.copy_with_replacements( + name=network_tyk2.name + f"_test_network_clone_2_name_only" + ) + scope = project_scope(3) + sk, _, _ = n4js.assemble_network(an, scope) + network_sks.append(sk) + + transformation_sks = [ + n4js.get_scoped_key(transformation, scope) + for transformation in transformations_common + all_transformations[-2:] + ] + + task_sks = n4js.create_tasks(transformation_sks) + n4js.set_task_waiting(task_sks[:-1]) + n4js.set_task_running(task_sks[-1:]) + n4js.set_task_complete(task_sks[-1:]) + set_result(slice(-1, None), ok=True) + + scope_dict = scope_test.to_dict() + scope_dict["project"] = "mergedproject" + new_scope = Scope(**scope_dict) + sk_merged = n4js.merge_networks( + network_sks, f"{network_tyk2.name}_combined", new_scope + ) + assert len(n4js.get_gufe(sk_merged).edges) == len( + n4js.get_gufe(network_sks[0]).edges + ) + assert len(n4js.get_gufe(sk_merged).edges) != len( + n4js.get_gufe(network_sks[1]).edges + ) + + # we expect 7 pdrrs from the completed tasks + results = n4js.execute_query( + """ + MATCH (pdrr: ProtocolDAGResultRef {`_project`: $project}) + WHERE pdrr.ok = True + RETURN pdrr + """, + project=new_scope.project, + ) + assert len(results.records) == 7 + + # we expect 1 pdrrs from the errored task + results = n4js.execute_query( + """ + MATCH (pdrr: ProtocolDAGResultRef {`_project`: $project}) + WHERE pdrr.ok = False + RETURN pdrr + """, + project=new_scope.project, + ) + assert len(results.records) == 1 + def test_set_network_state(self, n4js, network_tyk2, scope_test): valid_states = [state.value for state in NetworkStateEnum] network_sks = [] diff --git a/devtools/conda-envs/alchemiscale-client.yml b/devtools/conda-envs/alchemiscale-client.yml index 6ef6d801..8a536682 100644 --- a/devtools/conda-envs/alchemiscale-client.yml +++ b/devtools/conda-envs/alchemiscale-client.yml @@ -8,7 +8,7 @@ dependencies: - cuda-version >=12 # alchemiscale dependencies - - gufe=1.6.1 + - gufe=1.8.0 - openfe=1.6.1 - requests - click diff --git a/devtools/conda-envs/alchemiscale-compute.yml b/devtools/conda-envs/alchemiscale-compute.yml index 92b9d4be..6d8cd255 100644 --- a/devtools/conda-envs/alchemiscale-compute.yml +++ b/devtools/conda-envs/alchemiscale-compute.yml @@ -8,7 +8,7 @@ dependencies: - cuda-version >=12 # alchemiscale dependencies - - gufe=1.6.1 + - gufe=1.8.0 - openfe=1.6.1 - requests - click diff --git a/devtools/conda-envs/alchemiscale-server.yml b/devtools/conda-envs/alchemiscale-server.yml index fa6f22b3..6fea2592 100644 --- a/devtools/conda-envs/alchemiscale-server.yml +++ b/devtools/conda-envs/alchemiscale-server.yml @@ -8,7 +8,7 @@ dependencies: - cuda-version >=12 # alchemiscale dependencies - - gufe=1.6.1 + - gufe=1.8.0 - openfe=1.6.1 - requests - click diff --git a/devtools/conda-envs/docs.yml b/devtools/conda-envs/docs.yml index f080a0af..69f4c6d2 100644 --- a/devtools/conda-envs/docs.yml +++ b/devtools/conda-envs/docs.yml @@ -10,6 +10,6 @@ dependencies: - myst-parser>=0.14 - docutils - sphinx-notfound-page - - gufe=1.3.0 + - gufe=1.8.0 - py2neo - stratocaster diff --git a/devtools/conda-envs/test.yml b/devtools/conda-envs/test.yml index 9e0441c3..f4f6b5e7 100644 --- a/devtools/conda-envs/test.yml +++ b/devtools/conda-envs/test.yml @@ -7,7 +7,7 @@ dependencies: - cuda-version >=12 # alchemiscale dependencies - - gufe =1.7.1 + - gufe =1.8.0 - openfe =1.8.0 - pydantic >2 - pydantic-settings diff --git a/news/issue-221.rst b/news/issue-221.rst new file mode 100644 index 00000000..75a23c91 --- /dev/null +++ b/news/issue-221.rst @@ -0,0 +1,24 @@ +**Added:** + +* ``AlchemiscaleClient.merge_networks`` for combining multiple existing ``AlchemicalNetwork``\s into a new ``AlchemicalNetwork``, preserving completed and errored ``Task`` results so they do not need to be re-run. Backed by ``Neo4jStore.merge_networks`` and a new ``POST /networks/merge`` endpoint on the user API. + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* +