diff --git a/simple_sfm/utils/geometry.py b/simple_sfm/utils/geometry.py index 719f76e..b9f5710 100644 --- a/simple_sfm/utils/geometry.py +++ b/simple_sfm/utils/geometry.py @@ -354,7 +354,7 @@ def average_quaternions(quaternions: torch.Tensor, # B x 4 X 4 mat_a = torch.einsum('bni,bnj->bij', [quaternions * weights, quaternions]) # compute eigenvalues and -vectors - _, eigen_vectors = torch.symeig(mat_a, eigenvectors=True) + _, eigen_vectors = torch.linalg.eigh(mat_a) max_value_eigenvector = eigen_vectors[..., -1].reshape(*batch_shape, 4) return max_value_eigenvector.unsqueeze(-2) if keepdim else max_value_eigenvector