-
Notifications
You must be signed in to change notification settings - Fork 173
Single cell hd #169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Single cell hd #169
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,65 @@ | ||
| import os | ||
|
|
||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
| import scanpy as sc | ||
| import torch | ||
| import torchsde | ||
| from torchdyn.core import NeuralODE | ||
| from tqdm import tqdm | ||
|
|
||
| from torchcfm.conditional_flow_matching import * | ||
| from torchcfm.models import MLP | ||
| from torchcfm.utils import plot_trajectories, torch_wrapper | ||
| from sklearn.preprocessing import StandardScaler | ||
|
|
||
| from utils_single_cell import adata_dataset, split, combined_loader, train_dataloader, val_dataloader, test_dataloader | ||
|
|
||
| adata = sc.read_h5ad("./ebdata_v2.h5ad") | ||
| max_dim=1000 | ||
|
|
||
| data, labels, ulabels = adata_dataset("./ebdata_v2.h5ad") | ||
|
|
||
| if max_dim==1000: | ||
| sc.pp.highly_variable_genes(adata, n_top_genes=max_dim) | ||
| adata = adata.X[:, adata.var["highly_variable"]].toarray() | ||
|
|
||
| # Standardize coordinates | ||
| print(adata.shape) | ||
| scaler = StandardScaler() | ||
| scaler.fit(adata) | ||
| data = scaler.transform(adata) | ||
|
|
||
| dim = data.shape[-1] | ||
| print("data dim: ", dim) | ||
|
|
||
| timepoint_data = [ | ||
| adata[labels == lab].astype(np.float32) for lab in ulabels | ||
| ] | ||
|
|
||
| split_timepoint_data = split(timepoint_data) | ||
|
|
||
| print(f"Loaded ebdata with timepoints {ulabels} of sizes {[len(d) for d in timepoint_data]} with dim {dim}.") | ||
|
|
||
|
|
||
| #### TRAINING | ||
| train_dataloader = train_dataloader(split_timepoint_data) | ||
|
|
||
| use_cuda = torch.cuda.is_available() | ||
| device = torch.device("cuda" if use_cuda else "cpu") | ||
| batch_size = 128 | ||
| sigma = 0.1 | ||
| ot_cfm_model = MLP(dim=dim, time_varying=True, w=256).to(device) | ||
| ot_cfm_optimizer = torch.optim.Adam(ot_cfm_model.parameters(), 1e-4) | ||
| FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma) | ||
|
|
||
| n_epochs = 2000 | ||
| for _ in range(n_epochs): | ||
| for X in train_dataloader: | ||
| t_snapshot = np.random.randint(0,4) | ||
| t, xt, ut = FM.sample_location_and_conditional_flow(X[t_snapshot], X[t_snapshot+1]) | ||
|
Comment on lines
+59
to
+60
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (bug_risk): Hardcoded timepoint range may not generalize. Using a fixed range with 'np.random.randint(0,4)' risks IndexError if the number of timepoints is less than 5. Please calculate the valid range dynamically from the data. |
||
| ot_cfm_optimizer.zero_grad() | ||
| vt = ot_cfm_model(torch.cat([xt, t[:, None]], dim=-1)) | ||
| loss = torch.mean((vt - ut) ** 2) | ||
| loss.backward() | ||
| ot_cfm_optimizer.step() | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,49 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||
| import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||
| import scanpy as sc | ||||||||||||||||||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||||||||||||||||||
| from functools import partial | ||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.utils.data import random_split | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| from pytorch_lightning.trainer.supporters import CombinedLoader | ||||||||||||||||||||||||||||||||||||||||||||||||
| from torch.utils.data import DataLoader | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def adata_dataset(path, embed_name="X_pca", label_name="sample_labels", max_dim=100): | ||||||||||||||||||||||||||||||||||||||||||||||||
| adata = sc.read_h5ad(path) | ||||||||||||||||||||||||||||||||||||||||||||||||
| labels = adata.obs[label_name].astype("category") | ||||||||||||||||||||||||||||||||||||||||||||||||
| ulabels = labels.cat.categories | ||||||||||||||||||||||||||||||||||||||||||||||||
| return adata.obsm[embed_name][:, :max_dim], labels, ulabels | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def split(timepoint_data): | ||||||||||||||||||||||||||||||||||||||||||||||||
| """split requires self.hparams.train_val_test_split, timepoint_data, system, ulabels.""" | ||||||||||||||||||||||||||||||||||||||||||||||||
| train_val_test_split = [0.8, 0.1, 0.1] | ||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(train_val_test_split, int): | ||||||||||||||||||||||||||||||||||||||||||||||||
| split_timepoint_data = list(map(lambda x: (x, x, x), timepoint_data)) | ||||||||||||||||||||||||||||||||||||||||||||||||
| return split_timepoint_data | ||||||||||||||||||||||||||||||||||||||||||||||||
| splitter = partial( | ||||||||||||||||||||||||||||||||||||||||||||||||
| random_split, | ||||||||||||||||||||||||||||||||||||||||||||||||
| lengths=train_val_test_split, | ||||||||||||||||||||||||||||||||||||||||||||||||
| generator=torch.Generator().manual_seed(42), | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| split_timepoint_data = list(map(splitter, timepoint_data)) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+20
to
+27
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue (code-quality): Inline variable that is immediately returned [×2] ( |
||||||||||||||||||||||||||||||||||||||||||||||||
| return split_timepoint_data | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+16
to
+28
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: Unreachable code path for integer split ratios. Since 'train_val_test_split' is always a list, this type check is redundant and can be removed to simplify the code.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def combined_loader(split_timepoint_data, index, shuffle=False, load_full=False): | ||||||||||||||||||||||||||||||||||||||||||||||||
| tp_dataloaders = [ | ||||||||||||||||||||||||||||||||||||||||||||||||
| DataLoader( | ||||||||||||||||||||||||||||||||||||||||||||||||
| dataset=datasets[index], | ||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size=128, | ||||||||||||||||||||||||||||||||||||||||||||||||
| shuffle=shuffle, | ||||||||||||||||||||||||||||||||||||||||||||||||
| drop_last=True, | ||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||
| for datasets in split_timepoint_data | ||||||||||||||||||||||||||||||||||||||||||||||||
| ] | ||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+30
to
+39
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick: Unused 'load_full' parameter in 'combined_loader'. Consider removing the unused 'load_full' parameter if it serves no purpose. |
||||||||||||||||||||||||||||||||||||||||||||||||
| return CombinedLoader(tp_dataloaders, mode="min_size") | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def train_dataloader(split_timepoint_data): | ||||||||||||||||||||||||||||||||||||||||||||||||
| return combined_loader(split_timepoint_data, 0, shuffle=True) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def val_dataloader(split_timepoint_data): | ||||||||||||||||||||||||||||||||||||||||||||||||
| return combined_loader(split_timepoint_data, 1, shuffle=False, load_full=False) | ||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||
| def test_dataloader(split_timepoint_data): | ||||||||||||||||||||||||||||||||||||||||||||||||
| return combined_loader(split_timepoint_data, 2, shuffle=False, load_full=True) | ||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: Potential confusion between 'adata' as AnnData and as a numpy array.
'adata' is reassigned from an AnnData object to a numpy array, which could lead to confusion or errors. Use a different variable name for the numpy array to maintain clarity.