Bug description
I'm working on a sequence modelling task, and my dataset consists of many samples of varying length. To avoid excessive EOS padding, I implemented a sampler that groups samples of similar length into the same batch. However, I found that when I resume from a checkpoint, the loss increases dramatically. (I have checked that this is not caused by precision or save/load-related problems.) What's more, the error only occurs in multi-GPU experiments.
As for implementation details, I noticed that Lightning Fabric re-instantiates the sampler. To correctly pass the parameters my sampler needs, I forward them through packed kwargs and make the sampler inherit from torch.utils.data.BatchSampler. Please see my code below.
I don't know whether my code is wrong or there is a hidden bug in Fabric; I would appreciate any help.
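For context, here is a minimal sketch of how I wire the sampler (the SmartSampler defined below) into the DataLoader; model and train_dataset are placeholders, and the Trainer call stands in for my actual launch code:
from torch.utils.data import DataLoader, RandomSampler
import lightning as L

# plain base sampler; on multi-GPU Lightning swaps in a distributed sampler and
# re-instantiates the batch sampler around it
base_sampler = RandomSampler(train_dataset)
batch_sampler = SmartSampler(
    base_sampler,
    batch_size=8,
    drop_last=False,
    dataset=train_dataset,  # extra parameters forwarded through **kwargs
    length_bins=20,
    shuffle=True,
    seed=42,
    min_bin_size=64,
)
train_loader = DataLoader(train_dataset, batch_sampler=batch_sampler, num_workers=4)

trainer = L.Trainer(accelerator="gpu", devices=4, strategy="ddp")
trainer.fit(model, train_dataloaders=train_loader)
# resuming is where the loss jump shows up:
# trainer.fit(model, train_dataloaders=train_loader, ckpt_path="path/to/last.ckpt")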
What version are you seeing the problem on?
v2.3
Reproduced in studio
No response
How to reproduce the bug
import random
from collections import defaultdict
from typing import Dict, List

import torch
from torch.utils.data import BatchSampler, Sampler
class SmartSampler(BatchSampler):
"""
Smart sampler that groups sequences by length to minimize padding.
This sampler reduces the amount of EOS tokens needed per batch by grouping
sequences of similar lengths together, significantly improving training efficiency.
"""
def __init__(self, sampler: Sampler, batch_size: int = 8, drop_last: bool = False, **kwargs):
"""
Args:
sampler: Base sampler to wrap (can be DistributedSamplerWrapper in multi-GPU)
batch_size: Number of samples per batch
drop_last: Whether to drop the last incomplete batch
**kwargs: Additional parameters including dataset, length_bins, shuffle, etc.
"""
super().__init__(sampler, batch_size, drop_last)
# try to get params from kwargs or sampler
if 'dataset' in kwargs:
self.dataset = kwargs['dataset']
print(f" - dataset from kwargs: {type(self.dataset)}")
else:
if hasattr(sampler, 'dataset'):
self.dataset = sampler.dataset
print(f" - dataset from sampler.dataset: {type(self.dataset)}")
elif hasattr(sampler, 'sampler') and hasattr(sampler.sampler, 'dataset'):
# DistributedSamplerWrapper
self.dataset = sampler.sampler.dataset
self.batch_size = batch_size
# try to access params from kwargs
self.length_bins = kwargs.get('length_bins', 20)
self.shuffle = kwargs.get('shuffle', True)
self.seed = kwargs.get('seed', 42)
self.min_bin_size = kwargs.get('min_bin_size', 64)
# try to access kwargs or sampler
if 'num_replicas' in kwargs:
self.num_replicas = kwargs['num_replicas']
elif hasattr(sampler, 'num_replicas'):
self.num_replicas = sampler.num_replicas
elif hasattr(sampler, 'sampler') and hasattr(sampler.sampler, 'num_replicas'):
self.num_replicas = sampler.sampler.num_replicas
else:
self.num_replicas = 1
if 'rank' in kwargs:
self.rank = kwargs['rank']
elif hasattr(sampler, 'rank'):
self.rank = sampler.rank
elif hasattr(sampler, 'sampler') and hasattr(sampler.sampler, 'rank'):
self.rank = sampler.sampler.rank
else:
self.rank = 0
def _group_by_length(self) -> Dict[str, List[int]]:
"""Group dataset indices by sequence length using pre-computed lengths."""
length_groups = defaultdict(list)
print(" Grouping sequences by length...")
print(f" - min_bin_size: {self.min_bin_size}")
print(f" - sequence_lengths range: {min(self.dataset.sequence_lengths)} - {max(self.dataset.sequence_lengths)}")
total_samples = len(self.dataset)
samples_per_replica = total_samples // self.num_replicas
start_idx = self.rank * samples_per_replica
end_idx = start_idx + samples_per_replica if self.rank < self.num_replicas - 1 else total_samples
print(f" - Process {self.rank}/{self.num_replicas}: processing samples {start_idx}-{end_idx-1}")
for idx in range(start_idx, end_idx):
seq_len = self.dataset.sequence_lengths[idx]
bin_idx = seq_len // self.min_bin_size
bin_start = bin_idx * self.min_bin_size
bin_end = (bin_idx + 1) * self.min_bin_size - 1
length_range = f"{bin_start}-{bin_end}"
length_groups[length_range].append(idx)
return dict(length_groups)
def _create_batches(self) -> List[List[int]]:
"""Create batches within each length group."""
all_batches = []
for length_range, indices in self.length_groups.items():
if len(indices) == 0:
continue
if self.shuffle:
random.seed(self.seed)
random.shuffle(indices)
if len(indices) < self.batch_size:
copies_needed = (self.batch_size + len(indices) - 1) // len(indices)
extended_indices = []
for _ in range(copies_needed):
extended_indices.extend(indices)
extended_indices = extended_indices[:self.batch_size]
for i in range(0, len(extended_indices), self.batch_size):
batch = extended_indices[i:i + self.batch_size]
if len(batch) == self.batch_size:
all_batches.append(batch)
else:
for i in range(0, len(indices), self.batch_size):
batch = indices[i:i + self.batch_size]
if len(batch) == self.batch_size:
all_batches.append(batch)
elif len(batch) > 0:
if len(batch) < self.batch_size:
needed = self.batch_size - len(batch)
for j in range(needed):
batch.append(batch[j % len(batch)])
all_batches.append(batch)
if self.shuffle:
random.shuffle(all_batches)
return all_batches
def __iter__(self):
"""Iterate over batches using length-based grouping."""
# Get indices from the underlying sampler (already distributed by Lightning)
indices = list(self.sampler)
# Group indices by length
length_groups = defaultdict(list)
for idx in indices:
seq_len = self.dataset.sequence_lengths[idx]
bin_idx = seq_len // self.min_bin_size
bin_start = bin_idx * self.min_bin_size
bin_end = (bin_idx + 1) * self.min_bin_size - 1
length_range = f"{bin_start}-{bin_end}"
length_groups[length_range].append(idx)
# Create batches within each length group
batch_count = 0
for length_range, group_indices in length_groups.items():
if len(group_indices) == 0:
continue
if self.shuffle:
random.seed(self.seed)
random.shuffle(group_indices)
# Create batches from this length group
for i in range(0, len(group_indices), self.batch_size):
batch = group_indices[i:i + self.batch_size]
if len(batch) == self.batch_size or not self.drop_last:
batch_count += 1
yield batch
def __len__(self):
"""Number of batches."""
# Estimate number of batches based on dataset size and batch_size
total_samples = len(self.dataset)
print(f"🔍 Number of total samples in smart sampler in one device: {total_samples}")
        return (total_samples + self.batch_size - 1) // self.batch_size
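The sampler above relies on the dataset exposing a precomputed sequence_lengths list. My real dataset is large, but a minimal stand-in along these lines (hypothetical, just to make the snippet above runnable) is enough to exercise the sampler:
import torch
from torch.utils.data import Dataset

class ToyVarLenDataset(Dataset):
    """Dummy variable-length dataset; only `sequence_lengths` is used by SmartSampler."""
    def __init__(self, num_samples: int = 10000, max_len: int = 512):
        g = torch.Generator().manual_seed(0)
        self.sequence_lengths = torch.randint(8, max_len, (num_samples,), generator=g).tolist()

    def __len__(self):
        return len(self.sequence_lengths)

    def __getitem__(self, idx):
        # return dummy token ids of the recorded length
        return torch.randint(0, 100, (self.sequence_lengths[idx],))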
Error messages and logs
Here are my training loss logs:
Environment
Current environment
- PyTorch Lightning Version: 2.3.0
- PyTorch Version: 2.4
- Python version: 3.10
- OS: Linux (Debian)
- CUDA version: 11.8
- GPU models and configuration: A800 x 4
- How you installed Lightning: pip
More info
No response