Skip to content
8 changes: 8 additions & 0 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c

.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.

.. note:: An error will be raided if the loss matrix :math:`\mathbf{M}` contains NaNs.

Uses the algorithm proposed in :ref:`[1] <references-emd>`.

Parameters
Expand Down Expand Up @@ -324,6 +326,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c
# convert to numpy
M, a, b = nx.to_numpy(M, a, b)

if np.isnan(M).any():
raise ValueError('The loss matrix should not contain NaN values.')

# ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
Expand Down Expand Up @@ -502,6 +507,9 @@ def emd2(a, b, M, processes=1,
# convert to numpy
M, a, b = nx.to_numpy(M, a, b)

if np.isnan(M).any():
raise ValueError('The loss matrix should not contain NaN values.')

a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64, order='C')
Expand Down
15 changes: 15 additions & 0 deletions test/gromov/test_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,3 +910,18 @@ def test_fgw_barycenter(nx):

np.testing.assert_allclose(C, Cb, atol=1e-06)
np.testing.assert_allclose(X, Xb, atol=1e-06)


# Related to issue 469
def test_gromov2_nan_in_target_cost():
# GIVEN - a target cost matrix with a NaN value
source_cost = np.zeros((2, 2))
target_cost = np.ones((2, 2))
source_distribution = np.array([0.5, 0.5])
target_distribution = np.array([0.5, 0.5])

target_cost[0, 0] = np.nan

# WHEN - we call
with pytest.raises(ValueError, match='The loss matrix should not contain NaN values.'):
ot.gromov_wasserstein2(source_cost, target_cost, source_distribution, target_distribution)
5 changes: 5 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@ def test_emd_empty():
np.testing.assert_allclose(w, 0)


def test_emd_nan_in_loss_matrix():
with pytest.raises(ValueError, match='The loss matrix should not contain NaN values.'):
ot.emd([], [], [np.nan])


def test_emd2_multi():
n = 500 # nb bins

Expand Down