Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fd7e2b4
Sienna & Thibaut works
thibaut-germain Feb 9, 2026
9b49c76
add sgot file
thibaut-germain Feb 9, 2026
5e775a5
Merge branch 'master' into sgot
rflamary Feb 11, 2026
ba33032
first draft of sgot.py
osheasienna Feb 15, 2026
2eaf221
Merge branch 'master' into sgot
rflamary Feb 16, 2026
1c73452
Merge branch 'master' into sgot
rflamary Feb 17, 2026
3f10111
rewrite backend and refactor OT metric
osheasienna Feb 20, 2026
d5fef5b
refactor sgot, and implement tests for sgot and backend
osheasienna Feb 22, 2026
c61c398
Merge branch 'master' into sgot
rflamary Feb 23, 2026
630e359
fix astype in backend
osheasienna Feb 23, 2026
6b4d17b
Merge branch 'master' into sgot
rflamary Feb 23, 2026
055c5e4
Merge branch 'master' into sgot
rflamary Feb 24, 2026
57c7a0c
edits as per PR #792
osheasienna Feb 25, 2026
a0e74be
cost & metric fixed in test_sgot
osheasienna Feb 27, 2026
e994303
correct issues on test_sgot
osheasienna Feb 27, 2026
e9f8be7
fixing test failures
osheasienna Feb 27, 2026
37aa721
sgot documentation
osheasienna Mar 14, 2026
c0c7acf
Merge branch 'master' into sgot
rflamary Mar 19, 2026
7641f7f
fix backend error and example notebook
osheasienna Mar 31, 2026
b8a6cbe
removed notebook, added python version
osheasienna Mar 31, 2026
0b90ab7
fix SGOT example and relax SGOT test tolerances
osheasienna Mar 31, 2026
9cffb45
gaussian metric on graphs
osheasienna Apr 8, 2026
dae7b75
updated rotation paragraph across graphs
osheasienna Apr 8, 2026
579cce4
equation colours
osheasienna Apr 8, 2026
2882f32
fixed sgot distance vs frequency (omega)
osheasienna Apr 8, 2026
9b3cdc7
fix RELEASES
osheasienna Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Releases


## 0.9.7.dev0

This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.
Expand All @@ -12,8 +13,13 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
<<<<<<< HEAD
- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
=======
- Add cost functions between linear operators following
[A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920) (PR #792)
>>>>>>> 8d13c55 (edits as per PR #792)

#### Closed issues

Expand Down
110 changes: 106 additions & 4 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,46 @@ def clip(self, a, a_min=None, a_max=None):
"""
raise NotImplementedError()

def real(self, a):
"""
Return the real part of the tensor element-wise.

This function follows the api from :any:`numpy.real`

See: https://numpy.org/doc/stable/reference/generated/numpy.real.html
"""
raise NotImplementedError()

def imag(self, a):
"""
Return the imaginary part of the tensor element-wise.

This function follows the api from :any:`numpy.imag`

See: https://numpy.org/doc/stable/reference/generated/numpy.imag.html
"""
raise NotImplementedError()

def conj(self, a):
"""
Return the complex conjugate, element-wise.

This function follows the api from :any:`numpy.conj`

See: https://numpy.org/doc/stable/reference/generated/numpy.conj.html
"""
raise NotImplementedError()

def arccos(self, a):
"""
Trigonometric inverse cosine, element-wise.

This function follows the api from :any:`numpy.arccos`

See: https://numpy.org/doc/stable/reference/generated/numpy.arccos.html
"""
raise NotImplementedError()

def repeat(self, a, repeats, axis=None):
r"""
Repeats elements of a tensor.
Expand Down Expand Up @@ -1193,7 +1233,7 @@ def _from_numpy(self, a, type_as=None):
elif isinstance(a, float):
return a
else:
return a.astype(type_as.dtype)
return np.asarray(a, dtype=type_as.dtype)

def set_gradients(self, val, inputs, grads):
# No gradients for numpy
Expand Down Expand Up @@ -1313,6 +1353,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return np.clip(a, a_min, a_max)

def real(self, a):
return np.real(a)

def imag(self, a):
return np.imag(a)

def conj(self, a):
return np.conj(a)

def arccos(self, a):
return np.arccos(a)

def repeat(self, a, repeats, axis=None):
return np.repeat(a, repeats, axis)

Expand Down Expand Up @@ -1604,7 +1656,7 @@ def _from_numpy(self, a, type_as=None):
if type_as is None:
return jnp.array(a)
else:
return self._change_device(jnp.array(a).astype(type_as.dtype), type_as)
return self._change_device(jnp.asarray(a, dtype=type_as.dtype), type_as)

def set_gradients(self, val, inputs, grads):
from jax.flatten_util import ravel_pytree
Expand Down Expand Up @@ -1730,6 +1782,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return jnp.clip(a, a_min, a_max)

def real(self, a):
return jnp.real(a)

def imag(self, a):
return jnp.imag(a)

def conj(self, a):
return jnp.conj(a)

def arccos(self, a):
return jnp.arccos(a)

def repeat(self, a, repeats, axis=None):
return jnp.repeat(a, repeats, axis)

Expand Down Expand Up @@ -1803,7 +1867,9 @@ def randperm(self, size, type_as=None):
if not isinstance(size, int):
raise ValueError("size must be an integer")
if type_as is not None:
return jax.random.permutation(subkey, size).astype(type_as.dtype)
return jnp.asarray(
jax.random.permutation(subkey, size), dtype=type_as.dtype
)
else:
return jax.random.permutation(subkey, size)

Expand Down Expand Up @@ -2227,6 +2293,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return torch.clamp(a, a_min, a_max)

def real(self, a):
return torch.real(a)

def imag(self, a):
return torch.imag(a)

def conj(self, a):
return torch.conj(a)

def arccos(self, a):
return torch.acos(a)

def repeat(self, a, repeats, axis=None):
return torch.repeat_interleave(a, repeats, dim=axis)

Expand Down Expand Up @@ -2728,6 +2806,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return cp.clip(a, a_min, a_max)

def real(self, a):
return cp.real(a)

def imag(self, a):
return cp.imag(a)

def conj(self, a):
return cp.conj(a)

def arccos(self, a):
return cp.arccos(a)

def repeat(self, a, repeats, axis=None):
return cp.repeat(a, repeats, axis)

Expand Down Expand Up @@ -2819,7 +2909,7 @@ def randperm(self, size, type_as=None):
return self.rng_.permutation(size)
else:
with cp.cuda.Device(type_as.device):
return self.rng_.permutation(size).astype(type_as.dtype)
return cp.asarray(self.rng_.permutation(size), dtype=type_as.dtype)

def coo_matrix(self, data, rows, cols, shape=None, type_as=None):
data = self.from_numpy(data)
Expand Down Expand Up @@ -3162,6 +3252,18 @@ def outer(self, a, b):
def clip(self, a, a_min=None, a_max=None):
return tnp.clip(a, a_min, a_max)

def real(self, a):
return tnp.real(a)

def imag(self, a):
return tnp.imag(a)

def conj(self, a):
return tnp.conj(a)

def arccos(self, a):
return tnp.arccos(a)

def repeat(self, a, repeats, axis=None):
return tnp.repeat(a, repeats, axis)

Expand Down
Loading
Loading