diff --git a/oqpy/program.py b/oqpy/program.py index 54c6536..e118cc0 100644 --- a/oqpy/program.py +++ b/oqpy/program.py @@ -25,7 +25,7 @@ import warnings from copy import deepcopy -from typing import Any, Hashable, Iterable, Iterator, Optional +from typing import Any, Hashable, Iterable, Iterator, Optional, cast from openpulse import ast from openpulse.printer import dumps @@ -499,8 +499,8 @@ def returns(self, expression: AstConvertible) -> Program: def _create_modifiers_ast( self, - control: quantum_types.Qubit | Iterable[quantum_types.Qubit] | None, - neg_control: quantum_types.Qubit | Iterable[quantum_types.Qubit] | None, + control: AstConvertible | Iterable[AstConvertible] | None, + neg_control: AstConvertible | Iterable[AstConvertible] | None, inv: bool, exp: AstConvertible, ) -> tuple[list[ast.QuantumGateModifier], list[AstConvertible]]: @@ -509,7 +509,9 @@ def _create_modifiers_ast( modifiers: list[ast.QuantumGateModifier] = [] control = control if control is not None else [] - control = {control} if isinstance(control, quantum_types.Qubit) else set(control) + if not isinstance(control, Iterable): + control = [control] + control = set(control) if control: modifiers.append( ast.QuantumGateModifier( @@ -517,12 +519,12 @@ def _create_modifiers_ast( argument=to_ast(self, len(control)) if len(control) > 1 else None, ) ) - used_qubits.extend(sorted(control)) + used_qubits.extend(sorted(cast("Iterable[Any]", control))) neg_control = neg_control if neg_control is not None else [] - neg_control = ( - {neg_control} if isinstance(neg_control, quantum_types.Qubit) else set(neg_control) - ) + if not isinstance(neg_control, Iterable): + neg_control = [neg_control] + neg_control = set(neg_control) if neg_control: modifiers.append( ast.QuantumGateModifier( @@ -530,7 +532,7 @@ def _create_modifiers_ast( argument=to_ast(self, len(neg_control)) if len(neg_control) > 1 else None, ) ) - for qubit in sorted(neg_control): + for qubit in sorted(cast("Iterable[Any]", neg_control)): if qubit in used_qubits: raise ValueError(f"Qubit {qubit} has already been defined as a control qubit.") else: @@ -556,8 +558,8 @@ def gate( qubits: AstConvertible | Iterable[AstConvertible], name: str, *args: Any, - control: quantum_types.Qubit | Iterable[quantum_types.Qubit] | None = None, - neg_control: quantum_types.Qubit | Iterable[quantum_types.Qubit] | None = None, + control: AstConvertible | Iterable[AstConvertible] | None = None, + neg_control: AstConvertible | Iterable[AstConvertible] | None = None, inv: bool = False, exp: AstConvertible = 1, ) -> Program: @@ -568,9 +570,9 @@ def gate( to which the gate will be applied name (str): The gate name *args (Any): A list of parameters passed to the gate - control (quantum_types.Qubit | Iterable[quantum_types.Qubit] | None): The list + control (AstConvertible | Iterable[AstConvertible] | None): The list of control qubits (default: None) - neg_control: (quantum_types.Qubit | Iterable[quantum_types.Qubit] | None): The list + neg_control: (AstConvertible | Iterable[AstConvertible] | None): The list of negative control qubits (default: None) inv (bool): Flag to use the inverse gate (default: False) exp (AstConvertible): The exponent used with `pow` gate modifier @@ -580,9 +582,8 @@ def gate( """ modifiers, used_qubits = self._create_modifiers_ast(control, neg_control, inv, exp) - if isinstance(qubits, (quantum_types.Qubit, quantum_types.IndexedQubitArray)): + if not isinstance(qubits, Iterable): qubits = [qubits] - assert isinstance(qubits, Iterable) for qubit in qubits: if qubit in used_qubits: @@ -602,13 +603,13 @@ def gate( ) return self - def reset(self, qubit: quantum_types.Qubit) -> Program: + def reset(self, qubit: AstConvertible) -> Program: """Reset a particular qubit.""" - self._add_statement(ast.QuantumReset(qubits=qubit.to_ast(self))) + self._add_statement(ast.QuantumReset(qubits=to_ast(self, qubit))) return self def measure( - self, qubit: quantum_types.Qubit, output_location: classical_types.BitVar | None = None + self, qubit: AstConvertible, output_location: classical_types.BitVar | None = None ) -> Program: """Measure a particular qubit. @@ -616,7 +617,7 @@ def measure( """ self._add_statement( ast.QuantumMeasurementStatement( - measure=ast.QuantumMeasurement(ast.Identifier(qubit.name)), + measure=ast.QuantumMeasurement(to_ast(self, qubit)), target=optional_ast(self, output_location), ) ) diff --git a/oqpy/quantum_types.py b/oqpy/quantum_types.py index 7173d96..8cd1095 100644 --- a/oqpy/quantum_types.py +++ b/oqpy/quantum_types.py @@ -107,7 +107,7 @@ def to_ast(self, program: Program) -> ast.IndexedIdentifier: @contextlib.contextmanager def gate( program: Program, - qubits: Union[Qubit, list[Qubit]], + qubits: Union[AstConvertible, list[AstConvertible]], name: str, arguments: Optional[list[AstConvertible]] = None, declare_here: bool = False, @@ -121,7 +121,7 @@ def gate( program.gate(q1, "Rz", theta) program.gate(q1, "H") """ - if isinstance(qubits, Qubit): + if not isinstance(qubits, list): qubits = [qubits] arguments_ast = [] @@ -143,10 +143,11 @@ def gate( yield None state = program._pop() + qubits_ast = [to_ast(program, q) for q in qubits] stmt = ast.QuantumGateDefinition( name=ast.Identifier(name), arguments=arguments_ast, - qubits=[ast.Identifier(q.name) for q in qubits], + qubits=qubits_ast, body=state.body, ) if declare_here: @@ -157,7 +158,7 @@ def gate( @contextlib.contextmanager def defcal( program: Program, - qubits: Union[Qubit, list[Qubit]], + qubits: Union[AstConvertible, list[AstConvertible]], name: str, arguments: Optional[list[AstConvertible]] = None, return_type: Optional[ast.ClassicalType] = None, @@ -169,7 +170,7 @@ def defcal( with defcal(program, q1, "X", [AngleVar(name="theta"), oqpy.pi/2], oqpy.bit) as theta: program.play(frame, waveform) """ - if isinstance(qubits, Qubit): + if not isinstance(qubits, list): qubits = [qubits] assert return_type is None or isinstance(return_type, ast.ClassicalType) @@ -195,16 +196,17 @@ def defcal( yield None state = program._pop() + qubits_ast = [to_ast(program, q) for q in qubits] stmt = ast.CalibrationDefinition( ast.Identifier(name), arguments_ast, - [ast.Identifier(q.name) for q in qubits], + qubits_ast, return_type, state.body, ) program._add_statement(stmt) program._add_defcal( - [qubit.name for qubit in qubits], name, [dumps(a) for a in arguments_ast], stmt + [dumps(q) for q in qubits_ast], name, [dumps(a) for a in arguments_ast], stmt ) diff --git a/tests/test_directives.py b/tests/test_directives.py index 0ef5420..6dbb56a 100644 --- a/tests/test_directives.py +++ b/tests/test_directives.py @@ -2794,3 +2794,62 @@ def test_expr_matches_with_numpy_array_attributes(): # Non-matching tuple lengths should return False. x2.data = ("sweep",) assert not expr_matches(x1, x2) + + +def test_qubit_ast_convertible(): + @dataclass(frozen=True) + class MyQubit: + # Custom user qubit class converted via _to_oqpy_expression. + index: int + + def _to_oqpy_expression(self): + return PhysicalQubits[self.index] + + @dataclass(frozen=True) + class DirectQubit: + # Custom user qubit class converted via to_ast directly. + name: str + + def to_ast(self, program): + return ast.Identifier(self.name) + + q0 = MyQubit(0) + q1 = MyQubit(1) + d = DirectQubit("foo") + + prog = oqpy.Program() + prog.reset(q0) + prog.gate(q0, "h") + prog.gate([q0, q1], "cnot") + prog.gate(q1, "x", control=q0) + prog.gate(q1, "x", control=[q0]) + prog.measure(q0) + prog.reset(d) + prog.gate(d, "h") + + bit = oqpy.BitVar(name="b") + prog.measure(q1, bit) + + with oqpy.defcal(prog, [q0, q1], "mycx"): + prog.gate(q0, "h") + + expected = textwrap.dedent( + """ + OPENQASM 3.0; + bit b; + reset $0; + h $0; + cnot $0, $1; + ctrl @ x $0, $1; + ctrl @ x $0, $1; + measure $0; + reset foo; + h foo; + b = measure $1; + defcal mycx $0, $1 { + h $0; + } + """ + ).strip() + assert prog.to_qasm() == expected + _check_respects_type_hints(prog)