1 change: 0 additions & 1 deletion README.md
@@ -20,7 +20,6 @@
<a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a>
[![Downloads](https://static.pepy.tech/badge/torchcfm)](https://pepy.tech/project/torchcfm)
[![Downloads](https://static.pepy.tech/badge/torchcfm/month)](https://pepy.tech/project/torchcfm)

</div>

## Description
65 changes: 65 additions & 0 deletions examples/single_cell/train_single_cell_high_dimension.py
@@ -0,0 +1,65 @@
import os

import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import torch
import torchsde
from torchdyn.core import NeuralODE
from tqdm import tqdm

from torchcfm.conditional_flow_matching import *
from torchcfm.models import MLP
from torchcfm.utils import plot_trajectories, torch_wrapper
from sklearn.preprocessing import StandardScaler

from utils_single_cell import adata_dataset, split, combined_loader, train_dataloader, val_dataloader, test_dataloader

adata = sc.read_h5ad("./ebdata_v2.h5ad")
max_dim=1000

data, labels, ulabels = adata_dataset("./ebdata_v2.h5ad")

if max_dim==1000:
    sc.pp.highly_variable_genes(adata, n_top_genes=max_dim)
    adata = adata.X[:, adata.var["highly_variable"]].toarray()
Comment on lines +23 to +25

suggestion: Potential confusion between 'adata' as AnnData and as a numpy array.

'adata' is reassigned from an AnnData object to a numpy array, which could lead to confusion or errors. Use a different variable name for the numpy array to maintain clarity.
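A minimal sketch of what the suggested rename could look like; the name hvg_counts is illustrative and not part of the PR, and the surrounding logic is otherwise the diff's own:

if max_dim == 1000:
    sc.pp.highly_variable_genes(adata, n_top_genes=max_dim)
    # keep adata as an AnnData object; give the dense highly-variable-gene matrix its own name
    hvg_counts = adata.X[:, adata.var["highly_variable"]].toarray()

# standardize the renamed array; downstream code (timepoint_data, etc.) would index hvg_counts instead of adata
scaler = StandardScaler()
scaler.fit(hvg_counts)
data = scaler.transform(hvg_counts)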


# Standardize coordinates
print(adata.shape)
scaler = StandardScaler()
scaler.fit(adata)
data = scaler.transform(adata)

dim = data.shape[-1]
print("data dim: ", dim)

timepoint_data = [
    adata[labels == lab].astype(np.float32) for lab in ulabels
]

split_timepoint_data = split(timepoint_data)

print(f"Loaded ebdata with timepoints {ulabels} of sizes {[len(d) for d in timepoint_data]} with dim {dim}.")


#### TRAINING
train_dataloader = train_dataloader(split_timepoint_data)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
batch_size = 128
sigma = 0.1
ot_cfm_model = MLP(dim=dim, time_varying=True, w=256).to(device)
ot_cfm_optimizer = torch.optim.Adam(ot_cfm_model.parameters(), 1e-4)
FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)

n_epochs = 2000
for _ in range(n_epochs):
    for X in train_dataloader:
        t_snapshot = np.random.randint(0,4)
        t, xt, ut = FM.sample_location_and_conditional_flow(X[t_snapshot], X[t_snapshot+1])
Comment on lines +59 to +60

issue (bug_risk): Hardcoded timepoint range may not generalize.

Using a fixed range with 'np.random.randint(0,4)' risks IndexError if the number of timepoints is less than 5. Please calculate the valid range dynamically from the data.
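A sketch of the dynamic range this comment asks for, assuming consecutive (t, t+1) pairs are intended; X is the list of per-timepoint batches yielded by the combined loader, so its length gives the number of timepoints:

n_timepoints = len(X)  # one batch per timepoint dataloader
t_snapshot = np.random.randint(0, n_timepoints - 1)  # valid start index for a (t, t+1) pair
t, xt, ut = FM.sample_location_and_conditional_flow(X[t_snapshot], X[t_snapshot + 1])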

        ot_cfm_optimizer.zero_grad()
        vt = ot_cfm_model(torch.cat([xt, t[:, None]], dim=-1))
        loss = torch.mean((vt - ut) ** 2)
        loss.backward()
        ot_cfm_optimizer.step()
49 changes: 49 additions & 0 deletions examples/single_cell/utils_single_cell.py
@@ -0,0 +1,49 @@
import numpy as np
import scanpy as sc
import torch
from functools import partial
from torch.utils.data import random_split

from pytorch_lightning.trainer.supporters import CombinedLoader
from torch.utils.data import DataLoader

def adata_dataset(path, embed_name="X_pca", label_name="sample_labels", max_dim=100):
    adata = sc.read_h5ad(path)
    labels = adata.obs[label_name].astype("category")
    ulabels = labels.cat.categories
    return adata.obsm[embed_name][:, :max_dim], labels, ulabels

def split(timepoint_data):
"""split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
train_val_test_split = [0.8, 0.1, 0.1]
if isinstance(train_val_test_split, int):
split_timepoint_data = list(map(lambda x: (x, x, x), timepoint_data))
return split_timepoint_data
splitter = partial(
random_split,
lengths=train_val_test_split,
generator=torch.Generator().manual_seed(42),
)
split_timepoint_data = list(map(splitter, timepoint_data))
Comment on lines +20 to +27

issue (code-quality): Inline variable that is immediately returned [×2] (inline-immediately-returned-variable)

    return split_timepoint_data
Comment on lines +16 to +28

suggestion: Unreachable code path for integer split ratios.

Since 'train_val_test_split' is always a list, this type check is redundant and can be removed to simplify the code.

Suggested change:

def split(timepoint_data):
    """split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels."""
    train_val_test_split = [0.8, 0.1, 0.1]
    splitter = partial(
        random_split,
        lengths=train_val_test_split,
        generator=torch.Generator().manual_seed(42),
    )
    split_timepoint_data = list(map(splitter, timepoint_data))
    return split_timepoint_data
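A sketch that also folds in the inline-return comment above, assuming the fixed 0.8/0.1/0.1 split stays hardcoded:

def split(timepoint_data):
    """Split each timepoint's data into train/val/test with a fixed seed."""
    splitter = partial(
        random_split,
        lengths=[0.8, 0.1, 0.1],
        generator=torch.Generator().manual_seed(42),
    )
    return list(map(splitter, timepoint_data))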


def combined_loader(split_timepoint_data, index, shuffle=False, load_full=False):
    tp_dataloaders = [
        DataLoader(
            dataset=datasets[index],
            batch_size=128,
            shuffle=shuffle,
            drop_last=True,
        )
        for datasets in split_timepoint_data
    ]
Comment on lines +30 to +39

nitpick: Unused 'load_full' parameter in 'combined_loader'.

Consider removing the unused 'load_full' parameter if it serves no purpose.
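A sketch with the unused parameter removed; the val/test helpers at the bottom of the file would drop their load_full arguments accordingly:

def combined_loader(split_timepoint_data, index, shuffle=False):
    # one DataLoader per timepoint, all drawing from the same split index (0=train, 1=val, 2=test)
    tp_dataloaders = [
        DataLoader(
            dataset=datasets[index],
            batch_size=128,
            shuffle=shuffle,
            drop_last=True,
        )
        for datasets in split_timepoint_data
    ]
    return CombinedLoader(tp_dataloaders, mode="min_size")

def val_dataloader(split_timepoint_data):
    return combined_loader(split_timepoint_data, 1, shuffle=False)

def test_dataloader(split_timepoint_data):
    return combined_loader(split_timepoint_data, 2, shuffle=False)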

    return CombinedLoader(tp_dataloaders, mode="min_size")

def train_dataloader(split_timepoint_data):
    return combined_loader(split_timepoint_data, 0, shuffle=True)

def val_dataloader(split_timepoint_data):
    return combined_loader(split_timepoint_data, 1, shuffle=False, load_full=False)

def test_dataloader(split_timepoint_data):
    return combined_loader(split_timepoint_data, 2, shuffle=False, load_full=True)