Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ jobs:
run: poetry run make check-format

- name: Run mypy
# TODO: Remove continue-on-error after fixing openpulse 1.0.0 type incompatibilities
continue-on-error: true

@ajberdy ajberdy Jan 15, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if this is a change we'd want to merge to main. If so, I'd recommend linking a tracking issue in this TODO

And I'd be more tempted to pin the mypy version if possible rather than making a successful type check optional-- this creates a blind spot to new type errors not introduced by the upgrade.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason to update mypy? It should already be "pinned" by the poetry.lock file. We should be able to update openpulse without updating mypy.

run: poetry run make check-mypy

- name: Run pylint
Expand Down
110 changes: 108 additions & 2 deletions oqpy/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,17 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Iterable, Iterator, Optional, TypeVar, overload
from typing import (
TYPE_CHECKING,
Any,
ContextManager,
Iterable,
Iterator,
Literal,
Optional,
TypeVar,
overload,
)

from openpulse import ast

Expand All @@ -38,7 +48,7 @@
from oqpy.program import Program


__all__ = ["If", "Else", "ForIn", "While", "Range"]
__all__ = ["If", "Else", "ForIn", "While", "Range", "Switch", "Case", "Default"]


@contextlib.contextmanager
Expand Down Expand Up @@ -176,3 +186,99 @@ def While(program: Program, condition: OQPyExpression) -> Iterator[None]:
yield
state = program._pop()
program._add_statement(ast.WhileLoop(to_ast(program, condition), state.body))


class Switch(ContextManager["Switch"]):
"""Context manager for switch statement control flow.

.. code-block:: python

selector = IntVar(0)
with Switch(program, selector) as switch:
with Case(switch, 0):
program.increment(result, 1)
with Case(switch, 1, 2): # Multiple values in one case
program.increment(result, 2)
with Default(switch):
program.increment(result, 100)

"""

def __init__(self, program: "Program", target: OQPyExpression):
self.program = program
self.target = target
self.cases: list[tuple[list[ast.Expression], list[ast.Statement]]] = []
self.default: list[ast.Statement] | None = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes sense for Oqpy in general to allow an empty default, to give full flexibility as per the OpenQASM spec. However, the OpenQASM ast implementation notes

# Note that `None` is quite different to `[]` in this case; the latter is
# an explicitly empty body, whereas the absence of a default might mean
# that the switch is inexhaustive, and a linter might want to complain.

Do we want to add an optional flag to this class to toggle whether a None default is allowed? Or should that be handled by oqpy's consumers?

Two other options on the table are defaulting to an empty block (perhaps risky) or raising an error/warning if no default is given (perhaps annoying)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None default seems like sane behavior since the produced openqasm is closest to what was written. I'm not sure openqasm needs to be the one enforcing default behavior.


def __enter__(self) -> "Switch":
return self

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: Any,
) -> Literal[False]:
if exc_type is not None:
return False
Comment on lines +225 to +227

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to google what the return value of a context manager's __exit__ function means (whether to suppress exceptions within the block), so perhaps a comment here would help future maintainers, but maybe it's relatively common knowledge/easy to track down.

On the other hand, the default behavior (i.e. when returning None) is to not suppress errors, so maybe we don't need a return value/annotation at all here. If we're raising an exception anyways, we probably don't care whether the statement gets added to the program, so we could remove this branch entirely, though theoretically if we wanted to exit a couple clock cycles early, we could keep this line without the return value

# Build the case tuples as (list of expressions, CompoundStatement)
case_tuples = [(values, ast.CompoundStatement(body)) for values, body in self.cases]
default_stmt = ast.CompoundStatement(self.default) if self.default else None
stmt = ast.SwitchStatement(
to_ast(self.program, self.target),
case_tuples,
default_stmt,
)
self.program._add_statement(stmt)
return False


@contextlib.contextmanager
def Case(switch: Switch, *values: AstConvertible) -> Iterator[None]:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing in the switch leaves some room for shenanigans, where the case statement is not directly within the switch context. My suggestion is we

  1. In Switch.__enter__ do a Program._push()
  2. update the pushed ProgramState to add a active_switch_statement field.
  3. Pass the program into Case
  4. In Case, get the switch from the program state (fail if not present).
  5. In Switch.__exit__ do Program._pop

So the usage would look more like:

with Switch(program, selector):
    with Case(program, 0):
        ...

We can also enforce that only case statements appear within a Switch context by raising an error if active_switch_statement is not None in ProgramState.add_statement. This would prevent usages like:

with Switch(program, selector):
    program.play(frame, waveform) # or any command outside of a case statement

"""Context manager for a case within a switch statement.

Must be used inside a Switch context. Multiple values can be provided
for a single case block.

.. code-block:: python

with Switch(program, selector) as switch:
with Case(switch, 0):
# Handle case 0
program.increment(result, 1)
with Case(switch, 1, 2):
# Handle cases 1 and 2
program.increment(result, 2)

"""
if not values:
raise ValueError("Case requires at least one value")
switch.program._push()
yield
state = switch.program._pop()
case_values = [to_ast(switch.program, v) for v in values]
switch.cases.append((case_values, state.body))


@contextlib.contextmanager
def Default(switch: Switch) -> Iterator[None]:
"""Context manager for the default case within a switch statement.

Must be used inside a Switch context.

.. code-block:: python

with Switch(program, selector) as switch:
with Case(switch, 0):
program.increment(result, 1)
with Default(switch):
# Handle all other cases
program.increment(result, 100)

"""
if switch.default is not None:
raise RuntimeError("Switch statement can only have one default case")
switch.program._push()
yield
state = switch.program._pop()
switch.default = state.body
19 changes: 16 additions & 3 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ def __iadd__(self, other: Program) -> Program:
self.defcals.update(other.defcals)
for name, subroutine_stmt in other.subroutines.items():
self._add_subroutine(
name, subroutine_stmt, needs_declaration=name not in other.declared_subroutines
name,
subroutine_stmt,
needs_declaration=name not in other.declared_subroutines,
)
for name, gate_stmt in other.gates.items():
self._add_gate(name, gate_stmt, needs_declaration=name not in other.declared_gates)
Expand Down Expand Up @@ -418,7 +420,9 @@ def declare(
return self

def delay(
self, time: AstConvertible, qubits_or_frames: AstConvertible | Iterable[AstConvertible] = ()
self,
time: AstConvertible,
qubits_or_frames: AstConvertible | Iterable[AstConvertible] = (),
Comment on lines +431 to +433

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like your formatter is set to 80 lines, but pyproject.toml specifies 100, perhaps we can revert this and the below changes

) -> Program:
"""Apply a delay to a set of qubits or frames."""
if not isinstance(qubits_or_frames, Iterable):
Expand Down Expand Up @@ -608,7 +612,9 @@ def reset(self, qubit: quantum_types.Qubit) -> Program:
return self

def measure(
self, qubit: quantum_types.Qubit, output_location: classical_types.BitVar | None = None
self,
qubit: quantum_types.Qubit,
output_location: classical_types.BitVar | None = None,
) -> Program:
"""Measure a particular qubit.

Expand Down Expand Up @@ -709,6 +715,13 @@ def visit_BranchingStatement(self, node: ast.BranchingStatement, context: None =
node.else_block = self.process_statement_list(node.else_block)
self.generic_visit(node, context)

def visit_SwitchStatement(self, node: ast.SwitchStatement, context: None = None) -> None:
for _, case_block in node.cases:
case_block.statements = self.process_statement_list(case_block.statements)
if node.default:
node.default.statements = self.process_statement_list(node.default.statements)
self.generic_visit(node, context)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smallest nit (feel free to ignore) - if node.default is not None feels like a tighter condition here in terms of symmetry with the non-default cases. If node.default.statements was explicitly given as [], a symmetric approach would trivially process that, as it would for an empty case block.

Of course, the result is the same, but being explicit about why we would skip processing the default (it's set to None, vs happens to be empty) ever so slightly reduces mental load for developers reading the code.


def visit_CalibrationStatement(
self, node: ast.CalibrationStatement, context: None = None
) -> None:
Expand Down
Loading
Loading