Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion alchemiscale/compute/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import json
from urllib.parse import urljoin
from functools import wraps
from datetime import datetime

import requests
from requests.auth import HTTPBasicAuth
Expand Down Expand Up @@ -93,7 +94,12 @@ def get_task_transformation(
)

def set_task_result(
self, task: ScopedKey, protocoldagresult: ProtocolDAGResult
self,
task: ScopedKey,
protocoldagresult: ProtocolDAGResult,
compute_service_id: Optional[ComputeServiceID] = None,
start: Optional[datetime] = None,
end: Optional[datetime] = None,
) -> ScopedKey:
data = dict(
protocoldagresult=json.dumps(
Expand Down
19 changes: 16 additions & 3 deletions alchemiscale/compute/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from pathlib import Path
from threading import Thread
import tempfile
from datetime import datetime

import requests

Expand Down Expand Up @@ -243,15 +244,19 @@ def task_to_protocoldag(
return protocoldag, transformation, extends_protocoldagresult

def push_result(
self, task: ScopedKey, protocoldagresult: ProtocolDAGResult
self, task: ScopedKey, protocoldagresult: ProtocolDAGResult,
start: datetime, end: datetime
) -> ScopedKey:
# TODO: this method should postprocess any paths,
# leaf nodes in DAG for blob results that should go to object store

# TODO: ship paths to object store

# finally, push ProtocolDAGResult
sk: ScopedKey = self.client.set_task_result(task, protocoldagresult)
sk: ScopedKey = self.client.set_task_result(task, protocoldagresult,
compute_service_id=self.compute_service_id,
start=start,
end=end)

return sk

Expand All @@ -271,6 +276,8 @@ def execute(self, task: ScopedKey) -> ScopedKey:
)
shared = Path(shared_tmp.name)

start = datetime.utcnow()

protocoldagresult = execute_DAG(
protocoldag,
shared=shared,
Expand All @@ -279,11 +286,17 @@ def execute(self, task: ScopedKey) -> ScopedKey:
raise_error=False,
)

end = datetime.utcnow()

if not self.keep_shared:
shared_tmp.cleanup()

# push the result (or failure) back to the compute API
result_sk = self.push_result(task, protocoldagresult)
result_sk = self.push_result(task,
protocoldagresult,
start=start,
end=end
)

return result_sk

Expand Down
19 changes: 0 additions & 19 deletions alchemiscale/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,21 +191,6 @@ def _defaults(cls):
return super()._defaults()


class TaskArchive(GufeTokenizable):
...

def _to_dict(self):
return {}

@classmethod
def _from_dict(cls, d):
return cls(**d)

@classmethod
def _defaults(cls):
return super()._defaults()


class ObjectStoreRef(GufeTokenizable):
location: Optional[str]
obj_key: Optional[GufeKey]
Expand Down Expand Up @@ -257,7 +242,3 @@ def _to_dict(self):
"scope": str(self.scope),
"ok": self.ok,
}


class TaskArchive(GufeTokenizable):
...
26 changes: 23 additions & 3 deletions alchemiscale/storage/statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
from .models import (
ComputeServiceID,
ComputeServiceRegistration,
TaskProvenance,
Task,
TaskHub,
TaskArchive,
TaskStatusEnum,
ProtocolDAGResultRef,
)
Expand Down Expand Up @@ -1570,9 +1570,17 @@ def get_task_transformation(
return transformation, protocoldagresultref

def set_task_result(
self, task: ScopedKey, protocoldagresultref: ProtocolDAGResultRef
self,
task: ScopedKey,
protocoldagresultref: ProtocolDAGResultRef,
taskprovenance: Optional[TaskProvenance] = None
) -> ScopedKey:
"""Set a `ProtocolDAGResultRef` pointing to a `ProtocolDAGResult` for the given `Task`."""
"""Set a `ProtocolDAGResultRef` pointing to a `ProtocolDAGResult` for the given `Task`.

If a `TaskProvenance` is given, this will also be associated with the
`ProtocolDAGResultRef` via a RECORDS relationship.

"""

if task.qualname != "Task":
raise ValueError("`task` ScopedKey does not correspond to a `Task`")
Expand All @@ -1595,6 +1603,18 @@ def set_task_result(
_project=scope.project,
)

if taskprovenance is not None:
taskprovenance_node = Node(
"TaskProvenance", **taskprovenance.to_dict()
)
subgraph = subgraph | Relationship.type("RECORDS")(
taskprovenance_node,
protocoldagresultref_node,
_org=scope.org,
_campaign=scope.campaign,
_project=scope.project,
)

with self.transaction() as tx:
tx.merge(subgraph, "GufeTokenizable", "_scoped_key")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,9 @@ def test_get_task_results(
assert isinstance(pdr.extends_key, GufeKey) or pdr.extends_key is None
assert pdr.ok()

import pdb
pdb.set_trace()

def test_get_task_failures(
self,
scope_test,
Expand Down