Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ class CircuitToPauliStringsMeasurementResult:
results: list[PauliStringMeasurementResult]


@attrs.frozen
class TRexMetadata:
Comment thread
ddddddanni marked this conversation as resolved.
"""Metadata required to compute T-REX mitigated expectation values later."""
Comment thread
ddddddanni marked this conversation as resolved.
Outdated

# The Pauli string that is being measured.
pauli_str: ops.PauliString
# A 2D boolean array of shape (num_twirls, num_qubits) indicating
# the random twirl choices
twirl_choices: np.ndarray
# A 2D boolean array of shape (num_readout_circuits, num_qubits) indicating
# the random choices for generating readout calibration circuits.
readout_choices: np.ndarray


def _flatten_pauli_objs(
pauli_objs: Sequence[ops.PauliString | ops.PauliSum],
) -> Iterator[ops.PauliString]:
Expand Down Expand Up @@ -389,6 +403,77 @@ def _extract_readout_qubits(
return sorted(all_qubits)


def _get_trex_twirled_basis_gate(basis: ops.Pauli, flip: bool) -> cirq.Gate:
if basis == ops.Z and not flip:
return ops.I
elif basis == ops.Z and flip:
return ops.X
elif basis == ops.X and not flip:
return ops.ry(rads=-np.pi / 2)
elif basis == ops.X and flip:
return ops.ry(rads=np.pi / 2)
elif basis == ops.Y and not flip:
return ops.rx(rads=np.pi / 2)
elif basis == ops.Y and flip:
return ops.rx(rads=-np.pi / 2)
else:
Comment thread
ddddddanni marked this conversation as resolved.
raise ValueError(f"Unsupported basis: {basis}. Expected X, Y, Z, or I.") # pragma: no cover


def _generate_random_boolean_choices(
num_choices: int, num_qubits: int, rng: np.random.Generator
) -> np.ndarray:
"""Generates a 2D boolean array for random circuit choices.

Args:
num_choices: The number of circuits/rows to generate choices for.
num_qubits: The number of qubits.
rng: A numpy random generator.

Returns:
A boolean array of shape (num_choices, num_qubits).
"""
return rng.integers(0, 2, (num_choices, num_qubits), dtype=bool) # pragma: no cover


def _build_trex_twirled_pauli_circuits(
base_circuit: circuits.Circuit,
Comment thread
ddddddanni marked this conversation as resolved.
basis_ps: ops.PauliString,
twirl_choices: np.ndarray,
insert_strategy: circuits.InsertStrategy = circuits.InsertStrategy.INLINE,
) -> list[circuits.Circuit]:
"""Builds a list of twirled circuits for measuring the given Pauli strings.
Comment thread
ddddddanni marked this conversation as resolved.

Args:
base_circuit: The original circuit to be twirled.
basis_ps: A PauliString representing the target measurement basis for each qubit.
twirl_choices: A 2D boolean array of shape (num_twirls, len(qubits)) indicating
whether to apply a 180-degree twirl to each qubit.
Comment thread
ddddddanni marked this conversation as resolved.
Outdated
insert_strategy: The strategy for inserting twirling gates and measurements into
the base circuit.

Returns:
A list of twirled circuits.
"""
twirl_circuits = []
qubits = sorted(basis_ps.qubits)

for twirl_choice in twirl_choices:
# Map each qubit to its specific twirl operation for this shot
moment_ops = [
_get_trex_twirled_basis_gate(basis_ps[q], flip)(q)
for flip, q in zip(twirl_choice, qubits)
]
Comment thread
ddddddanni marked this conversation as resolved.

# Append the twirls and the final measurement to the base circuit
twirled_circuit = circuits.Circuit(
base_circuit, moment_ops, ops.M(*qubits, key='m'), strategy=insert_strategy
Comment thread
ddddddanni marked this conversation as resolved.
Outdated
)
twirl_circuits.append(twirled_circuit)

return twirl_circuits
Comment thread
ddddddanni marked this conversation as resolved.


def _pauli_objs_to_basis_change_ops(
pauli_objs: Sequence[ops.PauliString | ops.PauliSum], qid_list: Sequence[ops.Qid]
):
Expand Down Expand Up @@ -1058,6 +1143,34 @@ def _measure_pauli_strings_with_confusion_matrices(
return results


def generate_trex_and_readout_circuits(
circuit_to_pauli: CircuitToPauliStringsParameters,
num_twirls: int,
num_readout_circuits: int,
rng: np.random.Generator,
) -> tuple[list[circuits.Circuit], TRexMetadata]:
"""Generates a list of circuits for TREX benchmarking and readout calibration.

This function generates `num_twirls` circuits by applying random Pauli twirls
to the input circuit. It also generates `num_readout_circuits` for readout
error calibration. Each circuit is appended with measurement operations on the
targeted qubits.

Args:
circuit_to_pauli: A CircuitToPauliStringsParameters object containing the original
circuit and its associated Pauli strings.
num_twirls: The number of twirled circuits to generate for each original circuit.
num_readout_circuits: The number of readout calibration circuits to generate.
rng: A NumPy random number generator for generating random Pauli twirls.

Returns:
A tuple containing:
- A combined list of the twirled Pauli circuits followed by the readout circuits.
- A TRexMetadata object containing the random choices needed for post-processing.
"""
raise NotImplementedError("T-REX error mitigation is not yet implemented.")


def measure_pauli_strings(
circuits_to_pauli: (
Mapping[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
import cirq
from cirq.contrib.paulistring import CircuitToPauliStringsParameters, measure_pauli_strings
from cirq.contrib.paulistring.pauli_string_measurement_with_readout_mitigation import (
_build_trex_twirled_pauli_circuits,
PostFilteringSymmetryCalibrationResult as PostFilteringSymmetryCalibrationResult,
TRexMetadata,
)
from cirq.experiments import SingleQubitReadoutCalibrationResult
from cirq.experiments.single_qubit_readout_calibration_test import NoisySingleQubitReadoutSampler
Expand Down Expand Up @@ -935,3 +937,73 @@ def test_sampler_receives_correct_circuits(use_sweep: bool) -> None:
for q in op.qubits
}
assert measured == expected_qubits


def test_build_trex_twirled_pauli_circuits_multiple_twirls():
"""Test generating multiple circuits from a multi-row twirl_choices array."""
q0, q1, q2 = cirq.LineQubit.range(3)
base_circuit = cirq.Circuit(cirq.H(q0), cirq.CNOT(q0, q1), cirq.CNOT(q0, q2))
basis_ps = cirq.X(q0) * cirq.Y(q1) * cirq.Z(q2)

# 3 different twirl choices
twirl_choices = np.array(
[
[False, True, True], # q0(no flip), q1(flip), q2(flip)]
[True, False, False], # q0(flip), q1(no flip), q2(no flip)]
[False, False, False], # q0(no flip), q1(no flip), q2(no flip)]
]
)

circuits = _build_trex_twirled_pauli_circuits(base_circuit, basis_ps, twirl_choices)

assert len(circuits) == 3

q0_no_flip = cirq.Ry(rads=-np.pi / 2)(q0)
q0_flip = cirq.Ry(rads=np.pi / 2)(q0)

q1_no_flip = cirq.Rx(rads=np.pi / 2)(q1)
q1_flip = cirq.Rx(rads=-np.pi / 2)(q1)

q2_no_flip = cirq.I(q2)
q2_flip = cirq.X(q2)

# Verify Circuit 0: row [False, True, True]
assert q0_no_flip in circuits[0].moments[-2].operations
assert q1_flip in circuits[0].moments[-2].operations
assert q2_flip in circuits[0].moments[-2].operations

# Verify Circuit 1: row [True, False, False]
assert q0_flip in circuits[1].moments[-2].operations
assert q1_no_flip in circuits[1].moments[-2].operations
assert q2_no_flip in circuits[1].moments[-2].operations

# Verify Circuit 2: row [False, False, False]
assert q0_no_flip in circuits[2].moments[-2].operations
assert q1_no_flip in circuits[2].moments[-2].operations
assert q2_no_flip in circuits[2].moments[-2].operations

# Verify that every generated circuit ends with the correct joint measurement
for circuit in circuits:
meas_op = circuit.moments[-1].operations[0]
assert isinstance(meas_op.gate, cirq.MeasurementGate)
assert meas_op.qubits == (q0, q1, q2)
assert meas_op.gate.key == 'm'


def test_trex_metadata_instantiation() -> None:
"""Test the instantiation and attributes of TRexMetadata."""
q0, q1 = cirq.LineQubit.range(2)
pauli_str = cirq.X(q0) * cirq.Z(q1)

# 2D boolean arrays of shape (num_readout_circuits, num_qubits)
twirl_choices = np.array([[True, False], [False, True], [True, True]])

readout_choices = np.array([[False, False], [True, True], [False, True]])

metadata = TRexMetadata(
pauli_str=pauli_str, twirl_choices=twirl_choices, readout_choices=readout_choices
)

assert metadata.pauli_str == pauli_str
np.testing.assert_array_equal(metadata.twirl_choices, twirl_choices)
np.testing.assert_array_equal(metadata.readout_choices, readout_choices)
Loading