|
44 | 44 | from bqskit.ir.region import CircuitRegion |
45 | 45 | from bqskit.ir.region import CircuitRegionLike |
46 | 46 | from bqskit.qis.graph import CouplingGraph |
47 | | -from bqskit.qis.permutation import PermutationMatrix |
48 | 47 | from bqskit.qis.state.state import StateLike |
49 | 48 | from bqskit.qis.state.state import StateVector |
50 | 49 | from bqskit.qis.state.statemap import StateVectorMap |
@@ -2578,29 +2577,19 @@ def get_unitary_and_grad( |
2578 | 2577 | # Calculate gradient |
2579 | 2578 | left = UnitaryBuilder(self.num_qudits, self.radixes) |
2580 | 2579 | right = UnitaryBuilder(self.num_qudits, self.radixes) |
2581 | | - full_gards = [] |
| 2580 | + full_grads = [] |
2582 | 2581 |
|
2583 | 2582 | for M, loc in zip(matrices, locations): |
2584 | 2583 | right.apply_right(M, loc) |
2585 | 2584 |
|
2586 | 2585 | for M, dM, loc in zip(matrices, grads, locations): |
2587 | | - perm = PermutationMatrix.from_qubit_location(self.num_qudits, loc) |
2588 | | - permT = perm.T |
2589 | | - iden = np.identity(2 ** (self.num_qudits - len(loc))) |
2590 | | - |
2591 | 2586 | right.apply_left(M, loc, inverse=True) |
2592 | 2587 | right_utry = right.get_unitary() |
2593 | | - left_utry = left.get_unitary() |
2594 | 2588 | for grad in dM: |
2595 | | - # TODO: use tensor contractions here instead of mm |
2596 | | - # Should work fine with non unitary gradients |
2597 | | - # TODO: Fix for non qubits |
2598 | | - full_grad = np.kron(grad, iden) |
2599 | | - full_grad = permT @ full_grad @ perm |
2600 | | - full_gards.append(right_utry @ full_grad @ left_utry) |
| 2589 | + full_grads.append(right_utry @ left.eval_apply_right(grad, loc)) |
2601 | 2590 | left.apply_right(M, loc) |
2602 | 2591 |
|
2603 | | - return left.get_unitary(), np.array(full_gards) |
| 2592 | + return left.get_unitary(), np.array(full_grads) |
2604 | 2593 |
|
2605 | 2594 | def instantiate( |
2606 | 2595 | self, |
|
0 commit comments