Skip to content

Commit 9f43898

Browse files
clbonetrflamary
andauthored
[MRG] Slicing UOT (#765)
* 1st try potentials OT 1d * emd1d_dual ok without batch * batched emd1d_dual * 1d potentials with backprop, 1d uot 1st try * up tests 1d solvers * file sliced uot * clip max cdf in wasserstein_1d * Example UOT 1d * normalize weights * add suot * add code example (to test) * tests backend * up code example 1D UOT * Examples UOT 1D * fix output loss uot_1d * Example USOT vs SUOT * Center dual potentials * up tests * up citation * fix backend and skip tf in 1d_dual tests * lint * Default p=2 for UOT 1D * Test UOT1D, refactorize W2 on circle * Typo doc * Typo test sum * Skip test TF * update plot example * Remove icdf mode bc does not work well enough yet * First tests SUOT and USOT + some fix * Docs helper function UOT1D, version jax in backend * Improve doc * More test for SUOT and USOT * Test fix grad jax MacOS * Test fix grad jax MacOS * Test fix grad jax MacOS * Test fix grad jax MacOS * Test fix grad jax MacOS * Test fix grad jax MacOS * Test fix grad jax MacOS * Test fix grad jax MacOS * Test clip weights uot 1d for jax on mac * Test clip weights uot 1d for jax on mac * Fix loss example UOT1D, skip tests jax * Fix loss example UOT1D, skip tests jax * Update example SUOT vs USOT * stable gradient of jnp.abs in 0 * Animation SUOT * review * skip test np, tf, cp * Test grad nx.abs in 0 in jax * Test batched center_ot_dual and cleaning some get_backend * Test quantile_function * Doc power p USOT/SUOT * Fix doc SUOT/USOT * Thumbnail example SUOT * Thumbnail example SUOT --------- Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
1 parent a628dd6 commit 9f43898

22 files changed

Lines changed: 3423 additions & 1064 deletions

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ POT provides the following generic OT solvers:
5454
Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
5555
* [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
5656
* Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
57-
* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
57+
* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation [73] and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
5858
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations).
5959
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
60+
* [Sliced Unbalanced OT and Unbalanced Sliced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT.html) [82]
6061
* [Wasserstein distance on the
6162
circle](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_compute_wasserstein_circle.html)
6263
[44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
@@ -367,7 +368,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
367368

368369
[42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021.
369370

370-
[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
371+
[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
371372

372373
[44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
373374

@@ -449,5 +450,4 @@ Artificial Intelligence.
449450

450451
[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS).
451452

452-
453-
```
453+
[82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2024). [Slicing Unbalanced Optimal Transport](https://openreview.net/forum?id=AjJTg5M0r8). Transactions on Machine Learning Research.

RELEASES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
1212
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
1313
- Geomloss function now handles both scalar and slice indices for i and j (PR #785)
1414
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
15+
- Added UOT1D with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
16+
- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
1517

1618
#### Closed issues
1719

examples/index.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ OT in 1D and Sliced Wasserstein
7575
.. minigallery::
7676
../../examples/sliced-wasserstein/plot_*.py
7777
../../examples/plot_compute_emd.py
78+
../../examples/unbalanced-partial/plot_partial_1d.py
79+
../../examples/unbalanced-partial/plot_UOT_1D.py
80+
../../examples/unbalanced-partial/plot_UOT_sliced.py
7881

7982

8083
OT on Gaussian and Gaussian Mixture Models

examples/unbalanced-partial/plot_UOT_1D.py

Lines changed: 69 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
# Author: Hicham Janati <hicham.janati@inria.fr>
12+
# Clément Bonet <clement.bonet.mapp@polytechnique.edu>
1213
#
1314
# License: MIT License
1415

@@ -19,6 +20,7 @@
1920
import ot
2021
import ot.plot
2122
from ot.datasets import make_1D_gauss as gauss
23+
import torch
2224

2325
##############################################################################
2426
# Generate data
@@ -41,7 +43,6 @@
4143

4244
# loss matrix
4345
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
44-
M /= M.max()
4546

4647

4748
##############################################################################
@@ -62,29 +63,88 @@
6263

6364

6465
##############################################################################
65-
# Solve Unbalanced Sinkhorn
66-
# -------------------------
66+
# Solve Unbalanced OT with MM Unbalanced
67+
# -----------------------------------
6768

68-
# Sinkhorn
69+
# %% MM Unbalanced
6970

70-
epsilon = 0.1 # entropy parameter
7171
alpha = 1.0 # Unbalanced KL relaxation parameter
72-
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
72+
73+
Gs, log = ot.unbalanced.mm_unbalanced(a, b, M / M.max(), alpha, verbose=False, log=True)
7374

7475
pl.figure(3, figsize=(5, 5))
75-
ot.plot.plot1D_mat(a, b, Gs, "UOT matrix Sinkhorn")
76+
ot.plot.plot1D_mat(a, b, Gs, "UOT plan")
77+
pl.show()
7678

79+
pl.figure(4, figsize=(6.4, 3))
80+
pl.plot(x, a, "b", label="Source distribution")
81+
pl.plot(x, b, "r", label="Target distribution")
82+
pl.fill(x, Gs.sum(1), "b", alpha=0.5, label="Transported source")
83+
pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target")
84+
pl.legend(loc="upper right")
85+
pl.title("Distributions and transported mass for UOT")
7786
pl.show()
7887

88+
print("Mass of reweighted marginals:", Gs.sum())
89+
print("Unbalanced OT loss:", log["total_cost"] * M.max())
90+
91+
92+
##############################################################################
93+
# Solve 1D UOT with Frank-Wolfe
94+
# -----------------------------
95+
7996

80-
# %%
81-
# plot the transported mass
97+
# %% 1D UOT with FW
98+
99+
100+
alpha = M.max() # Unbalanced KL relaxation parameter
101+
102+
a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d(
103+
torch.tensor(x, dtype=torch.float64),
104+
torch.tensor(x, dtype=torch.float64),
105+
alpha,
106+
u_weights=torch.tensor(a, dtype=torch.float64),
107+
v_weights=torch.tensor(b, dtype=torch.float64),
108+
p=2,
109+
returnCost="total",
110+
)
111+
112+
pl.figure(4, figsize=(6.4, 3))
113+
pl.plot(x, a, "b", label="Source distribution")
114+
pl.plot(x, b, "r", label="Target distribution")
115+
pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source")
116+
pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target")
117+
pl.legend(loc="upper right")
118+
pl.title("Distributions and transported mass for UOT")
119+
pl.show()
120+
121+
print("Mass of reweighted marginals:", a_reweighted.sum().item())
122+
print("Unbalanced OT loss:", loss.item())
123+
124+
125+
##############################################################################
126+
# Solve Unbalanced Sinkhorn
82127
# -------------------------
83128

129+
# %% Sinkhorn UOT
130+
131+
# Sinkhorn
132+
133+
epsilon = 0.1 # entropy parameter
134+
alpha = 1.0 # Unbalanced KL relaxation parameter
135+
Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M / M.max(), epsilon, alpha, verbose=True)
136+
137+
pl.figure(3, figsize=(5, 5))
138+
ot.plot.plot1D_mat(a, b, Gs, "Entropic UOT plan")
139+
pl.show()
140+
84141
pl.figure(4, figsize=(6.4, 3))
85142
pl.plot(x, a, "b", label="Source distribution")
86143
pl.plot(x, b, "r", label="Target distribution")
87144
pl.fill(x, Gs.sum(1), "b", alpha=0.5, label="Transported source")
88145
pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target")
89146
pl.legend(loc="upper right")
90147
pl.title("Distributions and transported mass for UOT")
148+
pl.show()
149+
150+
print("Mass of reweighted marginals:", Gs.sum())

0 commit comments

Comments
 (0)