Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions ot/lp/_network_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import warnings
from scipy.sparse import issparse as scipy_issparse

from ..utils import list_to_array, check_number_threads
from ..backend import get_backend
Expand Down Expand Up @@ -298,10 +299,20 @@ def emd(
a, b = list_to_array(a, b)
nx = get_backend(a, b)

# Check if M is sparse using backend's issparse method
is_sparse = nx.issparse(M)
# Check if M is sparse (either backend sparse or scipy.sparse)
is_sparse = nx.issparse(M) or scipy_issparse(M)

if is_sparse:
# Check if backend supports sparse matrices
backend_name = nx.__class__.__name__
if backend_name in ["JaxBackend", "TensorflowBackend"]:
raise NotImplementedError(
f"Sparse optimal transport is not supported for {backend_name}. "
"JAX does not have native sparse matrix support, and TensorFlow's "
"sparse implementation is incomplete. Please convert your sparse "
"matrix to dense format using M.toarray() or equivalent before calling emd()."
)

# Extract COO data using backend method - returns numpy arrays
edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M)

Expand Down Expand Up @@ -572,10 +583,22 @@ def emd2(
a, b = list_to_array(a, b)
nx = get_backend(a, b)

# Check if M is sparse using backend's issparse method
is_sparse = nx.issparse(M)
# Check if M is sparse (either backend sparse or scipy.sparse)
from scipy.sparse import issparse as scipy_issparse

is_sparse = nx.issparse(M) or scipy_issparse(M)

if is_sparse:
# Check if backend supports sparse matrices
backend_name = nx.__class__.__name__
if backend_name in ["JaxBackend", "TensorflowBackend"]:
raise NotImplementedError(
f"Sparse optimal transport is not supported for {backend_name}. "
"JAX does not have native sparse matrix support, and TensorFlow's "
"sparse implementation is incomplete. Please convert your sparse "
"matrix to dense format using M.toarray() or equivalent before calling emd2()."
)

# Extract COO data using backend method - returns numpy arrays
edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M)

Expand Down