Training loss increases drastically when resuming from checkpoints after adopting a self-defined sampler in multi-GPU experiments [Fabric] #21312

@Moondok

Bug description

I'm working on a sequence modelling task, and my dataset consists of samples with widely varying lengths. To avoid excessive EOS padding, I implemented a sampler that allocates samples of similar lengths to the same batch. Nevertheless, I found that when I resume from checkpoints, the loss increases dramatically (I have checked that this is not caused by precision or save/load-related problems). What's more, this error only occurs in multi-GPU experiments.
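
To make the padding motivation concrete, here is a tiny made-up illustration (the lengths and batch splits below are invented for the example, not taken from my dataset): each sequence is padded up to the longest one in its batch, so mixing short and long sequences wastes far more padding tokens than grouping similar lengths.

# Tiny, invented example of padding waste (numbers are made up for illustration).
lengths_mixed   = [[12, 980], [15, 960]]    # short and long sequences mixed per batch
lengths_grouped = [[12, 15], [960, 980]]    # similar lengths batched together

def padding_tokens(batch_lengths):
    # every sequence in a batch is padded up to the longest sequence in that batch
    longest = max(batch_lengths)
    return sum(longest - length for length in batch_lengths)

print(sum(padding_tokens(b) for b in lengths_mixed))    # 1913 padding tokens
print(sum(padding_tokens(b) for b in lengths_grouped))  # 23 padding tokens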

As for implementation details, I noticed that Lightning Fabric re-instantiates the sampler. To make sure the sampler still receives the parameters it needs, I pass them through packed kwargs and make my sampler inherit from torch.utils.data.BatchSampler. Please see my code under "How to reproduce the bug" below.
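
To illustrate what re-instantiation means here, a rough generic sketch of the pattern (simplified, and not the actual Fabric source): the batch sampler gets rebuilt from its class around a new base sampler, and in the simple case only sampler, batch_size, and drop_last are forwarded, which is why SmartSampler tries to recover everything else from kwargs or from the wrapped sampler.

# Rough, generic sketch of the re-instantiation pattern (simplified; NOT the
# actual Fabric source): a batch sampler is rebuilt from its class around a
# new base sampler, passing only sampler, batch_size and drop_last.
from torch.utils.data import BatchSampler, SequentialSampler

original = BatchSampler(SequentialSampler(range(100)), batch_size=8, drop_last=False)
new_base = SequentialSampler(range(50))   # stand-in for a distributed sampler

rebuilt = type(original)(
    new_base,
    batch_size=original.batch_size,
    drop_last=original.drop_last,
)
print(type(rebuilt).__name__, rebuilt.batch_size)   # BatchSampler 8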

I do not know whether my code is wrong or whether there is a hidden bug in Fabric; I would appreciate any help.

What version are you seeing the problem on?

v2.3

Reproduced in studio

No response

How to reproduce the bug

import random
from collections import defaultdict
from typing import Dict, List

import torch
from torch.utils.data import BatchSampler, Sampler


class SmartSampler(BatchSampler):
    """
    Smart sampler that groups sequences by length to minimize padding.
    
    This sampler reduces the amount of EOS tokens needed per batch by grouping
    sequences of similar lengths together, significantly improving training efficiency.
    """
    
    def __init__(self, sampler: Sampler, batch_size: int = 8, drop_last: bool = False, **kwargs):
        """
        Args:
            sampler: Base sampler to wrap (can be DistributedSamplerWrapper in multi-GPU)
            batch_size: Number of samples per batch
            drop_last: Whether to drop the last incomplete batch
            **kwargs: Additional parameters including dataset, length_bins, shuffle, etc.
        """
        super().__init__(sampler, batch_size, drop_last)
        
        # try to get params from kwargs or sampler
        if 'dataset' in kwargs:
            self.dataset = kwargs['dataset']
            print(f"   - dataset from kwargs: {type(self.dataset)}")
        else:
            if hasattr(sampler, 'dataset'):
                self.dataset = sampler.dataset
                print(f"   - dataset from sampler.dataset: {type(self.dataset)}")
            elif hasattr(sampler, 'sampler') and hasattr(sampler.sampler, 'dataset'):
                # DistributedSamplerWrapper 
                self.dataset = sampler.sampler.dataset
        
        self.batch_size = batch_size
        
        # try to access params from kwargs
        self.length_bins = kwargs.get('length_bins', 20)
        self.shuffle = kwargs.get('shuffle', True)
        self.seed = kwargs.get('seed', 42)
        self.min_bin_size = kwargs.get('min_bin_size', 64)
        
        # try to access kwargs or sampler
        if 'num_replicas' in kwargs:
            self.num_replicas = kwargs['num_replicas']
        elif hasattr(sampler, 'num_replicas'):
            self.num_replicas = sampler.num_replicas
        elif hasattr(sampler, 'sampler') and hasattr(sampler.sampler, 'num_replicas'):
            self.num_replicas = sampler.sampler.num_replicas
        else:
            self.num_replicas = 1
            
        if 'rank' in kwargs:
            self.rank = kwargs['rank']
        elif hasattr(sampler, 'rank'):
            self.rank = sampler.rank
        elif hasattr(sampler, 'sampler') and hasattr(sampler.sampler, 'rank'):
            self.rank = sampler.sampler.rank
        else:
            self.rank = 0
      
    def _group_by_length(self) -> Dict[str, List[int]]:
        """Group dataset indices by sequence length using pre-computed lengths."""
        length_groups = defaultdict(list)
        
        print(" Grouping sequences by length...")
        print(f"   - min_bin_size: {self.min_bin_size}")
        print(f"   - sequence_lengths range: {min(self.dataset.sequence_lengths)} - {max(self.dataset.sequence_lengths)}")
        
        total_samples = len(self.dataset)
        samples_per_replica = total_samples // self.num_replicas
        start_idx = self.rank * samples_per_replica
        end_idx = start_idx + samples_per_replica if self.rank < self.num_replicas - 1 else total_samples
        
        print(f"   - Process {self.rank}/{self.num_replicas}: processing samples {start_idx}-{end_idx-1}")
        
        for idx in range(start_idx, end_idx):
            seq_len = self.dataset.sequence_lengths[idx]
            bin_idx = seq_len // self.min_bin_size
            bin_start = bin_idx * self.min_bin_size
            bin_end = (bin_idx + 1) * self.min_bin_size - 1
            length_range = f"{bin_start}-{bin_end}"
            length_groups[length_range].append(idx)
     
        
        return dict(length_groups)
    
    def _create_batches(self) -> List[List[int]]:
        """Create batches within each length group."""
        all_batches = []
        
        for length_range, indices in self.length_groups.items():
            if len(indices) == 0:
                continue
                
            if self.shuffle:
                random.seed(self.seed)
                random.shuffle(indices)
            
            
            if len(indices) < self.batch_size:
               
                copies_needed = (self.batch_size + len(indices) - 1) // len(indices)  
                extended_indices = []
                
                for _ in range(copies_needed):
                    extended_indices.extend(indices)
                
                extended_indices = extended_indices[:self.batch_size]
                
                for i in range(0, len(extended_indices), self.batch_size):
                    batch = extended_indices[i:i + self.batch_size]
                    if len(batch) == self.batch_size:
                        all_batches.append(batch)
            else:
                for i in range(0, len(indices), self.batch_size):
                    batch = indices[i:i + self.batch_size]
                    if len(batch) == self.batch_size:
                        all_batches.append(batch)
                    elif len(batch) > 0:
                        if len(batch) < self.batch_size:
                            
                            needed = self.batch_size - len(batch)
                            for j in range(needed):
                                batch.append(batch[j % len(batch)])
                        
                        all_batches.append(batch)
        
        if self.shuffle:
            random.shuffle(all_batches)
        
        return all_batches
    
    def __iter__(self):
        """Iterate over batches using length-based grouping."""
        # Get indices from the underlying sampler (already distributed by Lightning)
        indices = list(self.sampler)
        
        # Group indices by length
        length_groups = defaultdict(list)
        for idx in indices:
            seq_len = self.dataset.sequence_lengths[idx]
            bin_idx = seq_len // self.min_bin_size
            bin_start = bin_idx * self.min_bin_size
            bin_end = (bin_idx + 1) * self.min_bin_size - 1
            length_range = f"{bin_start}-{bin_end}"
            length_groups[length_range].append(idx)
       
        
        # Create batches within each length group
        batch_count = 0
        for length_range, group_indices in length_groups.items():
            if len(group_indices) == 0:
                continue
                
            if self.shuffle:
                random.seed(self.seed)
                random.shuffle(group_indices)
            
            # Create batches from this length group
            for i in range(0, len(group_indices), self.batch_size):
                batch = group_indices[i:i + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    batch_count += 1
                    yield batch
    
    def __len__(self):
        """Number of batches."""
        # Estimate number of batches based on dataset size and batch_size
        total_samples = len(self.dataset)

        print(f"🔍 Number of total samples in smart sampler in one device: {total_samples}")
        return (total_samples + self.batch_size - 1) // self.batch_size
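
For completeness, a minimal sketch of how the sampler would be wired into a Fabric run (this wiring code is illustrative and was not part of the original script; the toy dataset, device count, and hyper-parameters are placeholders):

import torch
import lightning as L
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.nn.utils.rnn import pad_sequence

class ToyVarLenDataset(Dataset):
    """Toy dataset exposing the `sequence_lengths` attribute that SmartSampler reads."""
    def __init__(self, n: int = 1000, max_len: int = 512):
        g = torch.Generator().manual_seed(0)
        self.sequence_lengths = torch.randint(1, max_len, (n,), generator=g).tolist()
    def __len__(self):
        return len(self.sequence_lengths)
    def __getitem__(self, idx):
        return torch.zeros(self.sequence_lengths[idx], dtype=torch.long)

fabric = L.Fabric(accelerator="cuda", devices=4, strategy="ddp")
fabric.launch()

dataset = ToyVarLenDataset()
batch_sampler = SmartSampler(
    RandomSampler(dataset),
    batch_size=8,
    drop_last=False,
    dataset=dataset,
    min_bin_size=64,
    shuffle=True,
    seed=42,
)
loader = DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    collate_fn=lambda items: pad_sequence(items, batch_first=True),
)
# Fabric rebuilds the batch sampler here so that the underlying sampler
# becomes a distributed one.
loader = fabric.setup_dataloaders(loader)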

Error messages and logs

Here are my training loss logs:

[Image: training loss curve]

Environment

Current environment

- PyTorch Lightning Version: 2.3.0
- PyTorch Version: 2.4
- Python version: 3.10
- OS: Linux (Debian)
- CUDA version: 11.8
- GPU models and configuration: 4x A800
- How you installed Lightning: pip

More info

No response

cc @ethanwharris @lantiga @justusschock


Labels

accelerator: cuda, bug, checkpointing, fabric, ver: 2.3.x
