Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 30 additions & 14 deletions cirq-google/cirq_google/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import random
import string
from collections.abc import Mapping, Sequence
from http import HTTPStatus
from typing import TYPE_CHECKING, TypeVar

import duet
Expand Down Expand Up @@ -391,6 +392,33 @@ async def run_sweep_async(
`processor_id` is empty.
"""

async def create_program_and_job() -> engine_job.EngineJob:
try:
engine_program = await self.create_program_async(
program, program_id, description=program_description, labels=program_labels
)
except engine_client.EngineException as ee:
if ee.code == HTTPStatus.CONFLICT:
if not program_id:
# Randomly-assigned ID collided with existing
raise
# If the program was already created, move on to job creation.
engine_program = self.get_program(program_id)
else:
raise

return await engine_program.run_sweep_async(
job_id=job_id,
params=params,
repetitions=repetitions,
processor_id=processor_id,
description=job_description,
labels=job_labels,
run_name=run_name,
snapshot_id=snapshot_id,
device_config_name=device_config_name,
)

if self.context.enable_streaming:
if not program_id:
program_id = _make_random_id('prog-')
Expand Down Expand Up @@ -424,22 +452,10 @@ async def run_sweep_async(
str(job_id),
self.context,
job_result_future=job_result_future,
recreate_job=create_program_and_job,
)

engine_program = await self.create_program_async(
program, program_id, description=program_description, labels=program_labels
)
return await engine_program.run_sweep_async(
job_id=job_id,
params=params,
repetitions=repetitions,
processor_id=processor_id,
description=job_description,
labels=job_labels,
run_name=run_name,
snapshot_id=snapshot_id,
device_config_name=device_config_name,
)
return await create_program_and_job()

run_sweep = duet.sync(run_sweep_async)

Expand Down
10 changes: 8 additions & 2 deletions cirq-google/cirq_google/engine/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,15 @@ def _fix_deprecated_allowlisted_users_args(


class EngineException(Exception):
def __init__(self, message):
def __init__(self, code, message):
# Call the base class constructor with the parameters it needs
super().__init__(message)
self._code = code

@property
def code(self):
"""The HTTP status code which caused this exception."""
return self._code


RETRYABLE_ERROR_CODES = [500, 503]
Expand Down Expand Up @@ -148,7 +154,7 @@ async def _run_retry_async(self, func: Callable[[_M], Awaitable[_R]], request: _
# Raise RuntimeError for exceptions that are not retryable.
# Otherwise, pass through to retry.
if err.code not in RETRYABLE_ERROR_CODES:
raise EngineException(message) from err
raise EngineException(err.code, message) from err

if current_delay > self.max_retry_delay_seconds:
raise TimeoutError(f'Reached max retry attempts for error: {message}')
Expand Down
39 changes: 33 additions & 6 deletions cirq-google/cirq_google/engine/engine_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from __future__ import annotations

import datetime
from collections.abc import Sequence
from collections.abc import Awaitable, Callable, Sequence
from http import HTTPStatus
from typing import TYPE_CHECKING

import duet
Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(
job_result_future: (
duet.AwaitableFuture[quantum.QuantumResult | quantum.QuantumJob] | None
) = None,
recreate_job: Callable[[], Awaitable[EngineJob]] | None = None,
) -> None:
"""A job submitted to the engine.

Expand All @@ -94,6 +96,7 @@ def __init__(
self._results: Sequence[EngineResult] | None = None
self._batched_results: Sequence[Sequence[EngineResult]] | None = None
self._job_result_future = job_result_future
self._recreate_job = recreate_job

def id(self) -> str:
"""Returns the job id."""
Expand Down Expand Up @@ -368,17 +371,41 @@ async def _await_result_async(self) -> quantum.QuantumResult:
# If the stream has disconnected, attempt to retrieve the result without it.
pass

try:
self._job = await self._await_completion_by_polling()
except engine_client.EngineException as e:
if e.code == HTTPStatus.NOT_FOUND and self._recreate_job:
# If the program/job was not created successfully, attempt to recreate once.
new_job = await self._recreate_job()

self.project_id = new_job.project_id
self.program_id = new_job.program_id
self.job_id = new_job.job_id
self.context = new_job.context
self._job = new_job._job
self._results = new_job._results
self._batched_results = new_job._batched_results
self._job_result_future = new_job._job_result_future
self._recreate_job = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there potential for a race condition here? The variable self._recreate_job is tested, then a job is started, some assignments happen, and only afterwards is the tested variable set to None. If more than one asynchronous task hits this code at the same time, there seems to be potential for multiple jobs to be created.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catch. I've moved the reassignment to None to the earliest point possible.


self._job = await self._await_completion_by_polling()
else:
raise

_raise_on_failure(self._job)
response = await self.context.client.get_job_results_async(
self.project_id, self.program_id, self.job_id
)
return response

async def _await_completion_by_polling(self) -> quantum.QuantumJob:
async with duet.timeout_scope(self.context.timeout): # type: ignore[arg-type]
while True:
job = await self._refresh_job_async()
if job.execution_status.state in TERMINAL_STATES:
break
await duet.sleep(1)
_raise_on_failure(job)
response = await self.context.client.get_job_results_async(
self.project_id, self.program_id, self.job_id
)
return response
return job

def _get_job_results_v1(self, result: v1.program_pb2.Result) -> Sequence[EngineResult]:
job_id = self.id()
Expand Down
75 changes: 74 additions & 1 deletion cirq-google/cirq_google/engine/engine_job_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import datetime
from http import HTTPStatus
from unittest import mock

import duet
Expand All @@ -26,7 +27,7 @@
import cirq_google as cg
from cirq_google.api import v1, v2
from cirq_google.cloud import quantum
from cirq_google.engine import util
from cirq_google.engine import EngineException, util
from cirq_google.engine.engine import EngineContext
from cirq_google.engine.stream_manager import StreamError

Expand Down Expand Up @@ -799,6 +800,78 @@ def test_on_stream_failure_retrieves_results_using_get_job_results(get_job_resul
get_job_results.assert_called_once_with('a', 'b', 'steve')


@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async')
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async')
def test_recreate_job_if_not_found(get_job_results, get_job):
project_id = 'a'
program_id = 'b'
job_id = 'steve'
context = EngineContext(timeout=60, enable_streaming=False)

get_job.side_effect = EngineException(HTTPStatus.NOT_FOUND, 'job not found')

async def recreate_job():
qjob = quantum.QuantumJob(
execution_status=quantum.ExecutionStatus(state=quantum.ExecutionStatus.State.SUCCESS),
update_time=UPDATE_TIME,
)
get_job.side_effect = None
get_job.return_value = qjob
get_job_results.return_value = RESULTS
return cg.EngineJob(
project_id=project_id,
program_id=program_id,
job_id=job_id,
context=context,
_job=qjob,
job_result_future=None,
recreate_job=None,
)

job = cg.EngineJob(
project_id=project_id,
program_id=program_id,
job_id=job_id,
context=context,
_job=None,
job_result_future=None,
recreate_job=recreate_job,
)
data = job.results()

assert len(data) == 2
assert str(data[0]) == 'q=0110'
assert str(data[1]) == 'q=1010'
get_job.assert_has_calls((mock.call(project_id, program_id, job_id, False),))
get_job_results.assert_called_once_with(project_id, program_id, job_id)


@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_async')
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async')
def test_receive_results_get_job_error_propagated(get_job_results, get_job):
project_id = 'a'
program_id = 'b'
job_id = 'steve'
context = EngineContext(timeout=60, enable_streaming=False)

get_job.side_effect = EngineException(HTTPStatus.INTERNAL_SERVER_ERROR, 'internal error')

job = cg.EngineJob(
project_id=project_id,
program_id=program_id,
job_id=job_id,
context=context,
_job=None,
job_result_future=None,
)

try:
job.results()
except Exception as e:
assert isinstance(e, EngineException)
assert e.code == HTTPStatus.INTERNAL_SERVER_ERROR

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In cases like this, this should probably make use of Pytest's more idiomatic pytest.raises. E.g.,

Suggested change
try:
job.results()
except Exception as e:
assert isinstance(e, EngineException)
assert e.code == HTTPStatus.INTERNAL_SERVER_ERROR
with pytest.raises(EngineException) as exc_info:
job.results()
assert exc_info.value.code == HTTPStatus.INTERNAL_SERVER_ERROR

There is more than one try-except case like this in the PR, so if you agree with this change, please also check the other places.

@hoisinberg hoisinberg Jun 23, 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I thought there was some way to do this but I couldn't remember at the time. Changed.



@mock.patch('cirq_google.engine.engine_client.EngineClient.get_job_results_async')
def test_results_len(get_job_results):
qjob = quantum.QuantumJob(
Expand Down
81 changes: 80 additions & 1 deletion cirq-google/cirq_google/engine/engine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import datetime
import time
from http import HTTPStatus
from unittest import mock

import duet
Expand All @@ -31,7 +32,7 @@
import cirq_google as cg
from cirq_google.api import v1, v2
from cirq_google.cloud import quantum
from cirq_google.engine import util
from cirq_google.engine import EngineException, util
from cirq_google.engine.engine import EngineContext
from cirq_google.engine.processor_config import Run, Snapshot

Expand Down Expand Up @@ -573,6 +574,84 @@ def test_run_sweep_params_with_unary_rpcs(client):
client().get_job_results_async.assert_called_once()


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_sweep_program_already_exists(client):
program_id = 'prog'
client().create_program_async.side_effect = [
EngineException(HTTPStatus.CONFLICT, "program already exists"),
(program_id, quantum.QuantumProgram(name=f"projects/proj/programs/{program_id}")),

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If an exception does occur in create_program_async, is this line ever executed?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, and in fact this test even asserts that it is only called once. Removed.

]
client().create_job_async.return_value = (
'job-id',
quantum.QuantumJob(
name=f"projects/proj/programs/{program_id}/jobs/job-id",
execution_status={'state': 'READY'},
),
)
client().get_job_async.return_value = quantum.QuantumJob(
execution_status={'state': 'SUCCESS'}, update_time=_DT
)
client().get_job_results_async.return_value = quantum.QuantumResult(result=_RESULTS)

engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))
job = engine.run_sweep(
program=_CIRCUIT,
program_id=program_id,
processor_id='processor0',
params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})],
)
results = job.results()

assert len(results) == 2
for i, v in enumerate([1, 2]):
assert results[i].repetitions == 1
assert results[i].params.param_dict == {'a': v}
assert results[i].measurements == {'q': np.array([[0]], dtype='uint8')}

client().create_program_async.assert_called_once()
client().create_job_async.assert_called_once()
client().get_job_async.assert_called_once()
client().get_job_results_async.assert_called_once()


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_sweep_program_with_implicit_id_already_exists(client):
client().create_program_async.side_effect = EngineException(
HTTPStatus.CONFLICT, "program already exists"
)
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))

try:
engine.run_sweep(
program=_CIRCUIT,
processor_id='processor0',
params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})],
)
except Exception as e:
assert isinstance(e, EngineException)
assert e.code == HTTPStatus.CONFLICT


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_sweep_unable_to_create_program_raises_error(client):
program_id = 'prog'
client().create_program_async.side_effect = EngineException(
HTTPStatus.INTERNAL_SERVER_ERROR, "internal error"
)
engine = cg.Engine(project_id='proj', context=EngineContext(enable_streaming=False))

try:
engine.run_sweep(
program=_CIRCUIT,
program_id=program_id,
processor_id='processor0',
params=[cirq.ParamResolver({'a': 1}), cirq.ParamResolver({'a': 2})],
)
except Exception as e:
assert isinstance(e, EngineException)
assert e.code == HTTPStatus.INTERNAL_SERVER_ERROR


@mock.patch('cirq_google.engine.engine_client.EngineClient', autospec=True)
def test_run_sweep_params_with_stream_rpcs(client):
setup_run_circuit_with_result_(client, _RESULTS)
Expand Down
Loading