Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 20 additions & 19 deletions oqpy/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import warnings
from copy import deepcopy
from typing import Any, Hashable, Iterable, Iterator, Optional
from typing import Any, Hashable, Iterable, Iterator, Optional, cast

from openpulse import ast
from openpulse.printer import dumps
Expand Down Expand Up @@ -499,8 +499,8 @@ def returns(self, expression: AstConvertible) -> Program:

def _create_modifiers_ast(
self,
control: quantum_types.Qubit | Iterable[quantum_types.Qubit] | None,
neg_control: quantum_types.Qubit | Iterable[quantum_types.Qubit] | None,
control: AstConvertible | Iterable[AstConvertible] | None,
neg_control: AstConvertible | Iterable[AstConvertible] | None,
inv: bool,
exp: AstConvertible,
) -> tuple[list[ast.QuantumGateModifier], list[AstConvertible]]:
Expand All @@ -509,28 +509,30 @@ def _create_modifiers_ast(
modifiers: list[ast.QuantumGateModifier] = []

control = control if control is not None else []
control = {control} if isinstance(control, quantum_types.Qubit) else set(control)
if not isinstance(control, Iterable):
control = [control]
control = set(control)
if control:
modifiers.append(
ast.QuantumGateModifier(
modifier=ast.GateModifierName.ctrl,
argument=to_ast(self, len(control)) if len(control) > 1 else None,
)
)
used_qubits.extend(sorted(control))
used_qubits.extend(sorted(cast("Iterable[Any]", control)))

neg_control = neg_control if neg_control is not None else []
neg_control = (
{neg_control} if isinstance(neg_control, quantum_types.Qubit) else set(neg_control)
)
if not isinstance(neg_control, Iterable):
neg_control = [neg_control]
neg_control = set(neg_control)
if neg_control:
modifiers.append(
ast.QuantumGateModifier(
modifier=ast.GateModifierName.negctrl,
argument=to_ast(self, len(neg_control)) if len(neg_control) > 1 else None,
)
)
for qubit in sorted(neg_control):
for qubit in sorted(cast("Iterable[Any]", neg_control)):
if qubit in used_qubits:
raise ValueError(f"Qubit {qubit} has already been defined as a control qubit.")
else:
Expand All @@ -556,8 +558,8 @@ def gate(
qubits: AstConvertible | Iterable[AstConvertible],
name: str,
*args: Any,
control: quantum_types.Qubit | Iterable[quantum_types.Qubit] | None = None,
neg_control: quantum_types.Qubit | Iterable[quantum_types.Qubit] | None = None,
control: AstConvertible | Iterable[AstConvertible] | None = None,
neg_control: AstConvertible | Iterable[AstConvertible] | None = None,
inv: bool = False,
exp: AstConvertible = 1,
) -> Program:
Expand All @@ -568,9 +570,9 @@ def gate(
to which the gate will be applied
name (str): The gate name
*args (Any): A list of parameters passed to the gate
control (quantum_types.Qubit | Iterable[quantum_types.Qubit] | None): The list
control (AstConvertible | Iterable[AstConvertible] | None): The list
of control qubits (default: None)
neg_control: (quantum_types.Qubit | Iterable[quantum_types.Qubit] | None): The list
neg_control: (AstConvertible | Iterable[AstConvertible] | None): The list
of negative control qubits (default: None)
inv (bool): Flag to use the inverse gate (default: False)
exp (AstConvertible): The exponent used with `pow` gate modifier
Expand All @@ -580,9 +582,8 @@ def gate(
"""
modifiers, used_qubits = self._create_modifiers_ast(control, neg_control, inv, exp)

if isinstance(qubits, (quantum_types.Qubit, quantum_types.IndexedQubitArray)):
if not isinstance(qubits, Iterable):
qubits = [qubits]
assert isinstance(qubits, Iterable)

for qubit in qubits:
if qubit in used_qubits:
Expand All @@ -602,21 +603,21 @@ def gate(
)
return self

def reset(self, qubit: quantum_types.Qubit) -> Program:
def reset(self, qubit: AstConvertible) -> Program:
"""Reset a particular qubit."""
self._add_statement(ast.QuantumReset(qubits=qubit.to_ast(self)))
self._add_statement(ast.QuantumReset(qubits=to_ast(self, qubit)))
return self

def measure(
self, qubit: quantum_types.Qubit, output_location: classical_types.BitVar | None = None
self, qubit: AstConvertible, output_location: classical_types.BitVar | None = None
) -> Program:
"""Measure a particular qubit.

If provided, store the result in the given output location.
"""
self._add_statement(
ast.QuantumMeasurementStatement(
measure=ast.QuantumMeasurement(ast.Identifier(qubit.name)),
measure=ast.QuantumMeasurement(to_ast(self, qubit)),
target=optional_ast(self, output_location),
)
)
Expand Down
16 changes: 9 additions & 7 deletions oqpy/quantum_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def to_ast(self, program: Program) -> ast.IndexedIdentifier:
@contextlib.contextmanager
def gate(
program: Program,
qubits: Union[Qubit, list[Qubit]],
qubits: Union[AstConvertible, list[AstConvertible]],
name: str,
arguments: Optional[list[AstConvertible]] = None,
declare_here: bool = False,
Expand All @@ -121,7 +121,7 @@ def gate(
program.gate(q1, "Rz", theta)
program.gate(q1, "H")
"""
if isinstance(qubits, Qubit):
if not isinstance(qubits, list):
qubits = [qubits]

arguments_ast = []
Expand All @@ -143,10 +143,11 @@ def gate(
yield None
state = program._pop()

qubits_ast = [to_ast(program, q) for q in qubits]
stmt = ast.QuantumGateDefinition(
name=ast.Identifier(name),
arguments=arguments_ast,
qubits=[ast.Identifier(q.name) for q in qubits],
qubits=qubits_ast,
body=state.body,
)
if declare_here:
Expand All @@ -157,7 +158,7 @@ def gate(
@contextlib.contextmanager
def defcal(
program: Program,
qubits: Union[Qubit, list[Qubit]],
qubits: Union[AstConvertible, list[AstConvertible]],
name: str,
arguments: Optional[list[AstConvertible]] = None,
return_type: Optional[ast.ClassicalType] = None,
Expand All @@ -169,7 +170,7 @@ def defcal(
with defcal(program, q1, "X", [AngleVar(name="theta"), oqpy.pi/2], oqpy.bit) as theta:
program.play(frame, waveform)
"""
if isinstance(qubits, Qubit):
if not isinstance(qubits, list):
qubits = [qubits]
assert return_type is None or isinstance(return_type, ast.ClassicalType)

Expand All @@ -195,16 +196,17 @@ def defcal(
yield None
state = program._pop()

qubits_ast = [to_ast(program, q) for q in qubits]
stmt = ast.CalibrationDefinition(
ast.Identifier(name),
arguments_ast,
[ast.Identifier(q.name) for q in qubits],
qubits_ast,
return_type,
state.body,
)
program._add_statement(stmt)
program._add_defcal(
[qubit.name for qubit in qubits], name, [dumps(a) for a in arguments_ast], stmt
[dumps(q) for q in qubits_ast], name, [dumps(a) for a in arguments_ast], stmt
)


Expand Down
59 changes: 59 additions & 0 deletions tests/test_directives.py
Original file line number Diff line number Diff line change
Expand Up @@ -2794,3 +2794,62 @@ def test_expr_matches_with_numpy_array_attributes():
# Non-matching tuple lengths should return False.
x2.data = ("sweep",)
assert not expr_matches(x1, x2)


def test_qubit_ast_convertible():
@dataclass(frozen=True)
class MyQubit:
# Custom user qubit class converted via _to_oqpy_expression.
index: int

def _to_oqpy_expression(self):
return PhysicalQubits[self.index]

@dataclass(frozen=True)
class DirectQubit:
# Custom user qubit class converted via to_ast directly.
name: str

def to_ast(self, program):
return ast.Identifier(self.name)

q0 = MyQubit(0)
q1 = MyQubit(1)
d = DirectQubit("foo")

prog = oqpy.Program()
prog.reset(q0)
prog.gate(q0, "h")
prog.gate([q0, q1], "cnot")
prog.gate(q1, "x", control=q0)
prog.gate(q1, "x", control=[q0])
prog.measure(q0)
prog.reset(d)
prog.gate(d, "h")

bit = oqpy.BitVar(name="b")
prog.measure(q1, bit)

with oqpy.defcal(prog, [q0, q1], "mycx"):
prog.gate(q0, "h")

expected = textwrap.dedent(
"""
OPENQASM 3.0;
bit b;
reset $0;
h $0;
cnot $0, $1;
ctrl @ x $0, $1;
ctrl @ x $0, $1;
measure $0;
reset foo;
h foo;
b = measure $1;
defcal mycx $0, $1 {
h $0;
}
"""
).strip()
assert prog.to_qasm() == expected
_check_respects_type_hints(prog)
Loading