Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
fd7e2b4
Sienna & Thibaut works
thibaut-germain Feb 9, 2026
9b49c76
add sgot file
thibaut-germain Feb 9, 2026
5e775a5
Merge branch 'master' into sgot
rflamary Feb 11, 2026
ba33032
first draft of sgot.py
osheasienna Feb 15, 2026
2eaf221
Merge branch 'master' into sgot
rflamary Feb 16, 2026
1c73452
Merge branch 'master' into sgot
rflamary Feb 17, 2026
3f10111
rewrite backend and refactor OT metric
osheasienna Feb 20, 2026
d5fef5b
refactor sgot, and implement tests for sgot and backend
osheasienna Feb 22, 2026
c61c398
Merge branch 'master' into sgot
rflamary Feb 23, 2026
630e359
fix astype in backend
osheasienna Feb 23, 2026
6b4d17b
Merge branch 'master' into sgot
rflamary Feb 23, 2026
055c5e4
Merge branch 'master' into sgot
rflamary Feb 24, 2026
57c7a0c
edits as per PR #792
osheasienna Feb 25, 2026
a0e74be
cost & metric fixed in test_sgot
osheasienna Feb 27, 2026
e994303
correct issues on test_sgot
osheasienna Feb 27, 2026
e9f8be7
fixing test failures
osheasienna Feb 27, 2026
37aa721
sgot documentation
osheasienna Mar 14, 2026
c0c7acf
Merge branch 'master' into sgot
rflamary Mar 19, 2026
7641f7f
fix backend error and example notebook
osheasienna Mar 31, 2026
b8a6cbe
removed notebook, added python version
osheasienna Mar 31, 2026
0b90ab7
fix SGOT example and relax SGOT test tolerances
osheasienna Mar 31, 2026
9cffb45
gaussian metric on graphs
osheasienna Apr 8, 2026
dae7b75
updated rotation paragraph across graphs
osheasienna Apr 8, 2026
579cce4
equation colours
osheasienna Apr 8, 2026
2882f32
fixed sgot distance vs frequency (omega)
osheasienna Apr 8, 2026
9b3cdc7
fix RELEASES
osheasienna Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Releases


## Upcomming 0.9.7.post1

#### New features
The next release will add cost functions between linear operators following [A Spectral-Grassmann Wasserstein metric for operator representations of dynamical systems](https://arxiv.org/pdf/2509.24920).
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this text to the new feature of 0.9.7.dev0 this is what we are working on. Also add a line in the Itemize with the PR number




## 0.9.7.dev0

This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.
Expand Down
152 changes: 152 additions & 0 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,55 @@ def clip(self, a, a_min, a_max):
"""
raise NotImplementedError()

def real(self, a):
"""
Return the real part of the tensor element-wise.

This function follows the api from :any:`numpy.real`

See: https://numpy.org/doc/stable/reference/generated/numpy.real.html
"""
raise NotImplementedError()

def imag(self, a):
"""
Return the imaginary part of the tensor element-wise.

This function follows the api from :any:`numpy.imag`

See: https://numpy.org/doc/stable/reference/generated/numpy.imag.html
"""
raise NotImplementedError()

def conj(self, a):
"""
Return the complex conjugate, element-wise.

This function follows the api from :any:`numpy.conj`

See: https://numpy.org/doc/stable/reference/generated/numpy.conj.html
"""
raise NotImplementedError()

def arccos(self, a):
"""
Trigonometric inverse cosine, element-wise.

This function follows the api from :any:`numpy.arccos`

See: https://numpy.org/doc/stable/reference/generated/numpy.arccos.html
"""
raise NotImplementedError()

def astype(self, a, dtype):
"""
Cast tensor to a given dtype.

dtype can be a string (e.g. "complex128", "float64") or backend-specific
dtype. Backend converts to the corresponding type.
"""
raise NotImplementedError()

def repeat(self, a, repeats, axis=None):
r"""
Repeats elements of a tensor.
Expand Down Expand Up @@ -1294,6 +1343,23 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
return np.clip(a, a_min, a_max)

def real(self, a):
return np.real(a)

def imag(self, a):
return np.imag(a)

def conj(self, a):
return np.conj(a)

def arccos(self, a):
return np.arccos(a)

def astype(self, a, dtype):
if isinstance(dtype, str):
dtype = getattr(np, dtype, None) or np.dtype(dtype)
return np.asarray(a, dtype=dtype)

def repeat(self, a, repeats, axis=None):
return np.repeat(a, repeats, axis)

Expand Down Expand Up @@ -1711,6 +1777,23 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
return jnp.clip(a, a_min, a_max)

def real(self, a):
return jnp.real(a)

def imag(self, a):
return jnp.imag(a)

def conj(self, a):
return jnp.conj(a)

def arccos(self, a):
return jnp.arccos(a)

def astype(self, a, dtype):
if isinstance(dtype, str):
dtype = getattr(jnp, dtype, None) or jnp.dtype(dtype)
return jnp.asarray(a, dtype=dtype)

def repeat(self, a, repeats, axis=None):
return jnp.repeat(a, repeats, axis)

Expand Down Expand Up @@ -2208,6 +2291,41 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
return torch.clamp(a, a_min, a_max)

def real(self, a):
return torch.real(a)

def imag(self, a):
return torch.imag(a)

def conj(self, a):
return torch.conj(a)

def arccos(self, a):
return torch.acos(a)

def astype(self, a, dtype):
if isinstance(dtype, str):
# Map common numpy-style string dtypes to torch dtypes explicitly.
# This makes backend.astype robust across torch versions and aliases.
mapping = {
"float32": torch.float32,
"float64": torch.float64,
"float": torch.float32,
"double": torch.float64,
"complex64": getattr(torch, "complex64", None),
"complex128": getattr(torch, "complex128", None),
}
torch_dtype = mapping.get(dtype)
if torch_dtype is None:
# Fallback: try direct attribute lookup (e.g. torch.float16)
torch_dtype = getattr(torch, dtype, None)
if torch_dtype is None:
raise ValueError(
f"Unsupported dtype for TorchBackend.astype: {dtype!r}"
)
dtype = torch_dtype
return a.to(dtype=dtype)

def repeat(self, a, repeats, axis=None):
return torch.repeat_interleave(a, repeats, dim=axis)

Expand Down Expand Up @@ -2709,6 +2827,23 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
return cp.clip(a, a_min, a_max)

def real(self, a):
return cp.real(a)

def imag(self, a):
return cp.imag(a)

def conj(self, a):
return cp.conj(a)

def arccos(self, a):
return cp.arccos(a)

def astype(self, a, dtype):
if isinstance(dtype, str):
dtype = getattr(cp, dtype, None) or cp.dtype(dtype)
return cp.asarray(a, dtype=dtype)

def repeat(self, a, repeats, axis=None):
return cp.repeat(a, repeats, axis)

Expand Down Expand Up @@ -3143,6 +3278,23 @@ def outer(self, a, b):
def clip(self, a, a_min, a_max):
return tnp.clip(a, a_min, a_max)

def real(self, a):
return tnp.real(a)

def imag(self, a):
return tnp.imag(a)

def conj(self, a):
return tnp.conj(a)

def arccos(self, a):
return tnp.arccos(a)

def astype(self, a, dtype):
if isinstance(dtype, str):
dtype = getattr(tnp, dtype, None) or tnp.dtype(dtype)
return tnp.array(a, dtype=dtype)

def repeat(self, a, repeats, axis=None):
return tnp.repeat(a, repeats, axis)

Expand Down
Loading
Loading