Accumulating Batches (not Gradients) via custom Loops to avoid CUDA OOM #15116
myscience asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
Hi everyone,
I am facing a CUDA Out-Of-Memory error when training my model on a custom dataset via DDP on 4 GPUs, and I would like to hear whether I am missing a simple solution here or where my mistake is.
The problem is that a single example in my dataset is quite big (a tensor of shape `(batch, 80, 105, 85)`) and the model itself has two sub-modules, one of which is a `BEiT` transformer (taken from HuggingFace) with ~85.7M parameters. The Lightning report estimates the full model at 185 MB. On a single GPU (which has 16 GB of memory) I can fit a batch size of 2, which is too small: I am using a contrastive-learning approach where each batch needs positive and negative examples, so I think a batch size of 128 would be more reasonable.

My idea for solving this was the following. Before the loss computation, the model computes vector representations of the data which are far smaller (tensors of shape `(batch, 700)`), so I could accumulate some batches, collect the vector representations until I reach something like `(128, 700)`, and only then compute the loss and update everything. The question is: does this make sense? And if so, how can I achieve this sort of behavior?
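To make the idea concrete, here is a minimal plain-PyTorch sketch of the behavior I am after (the toy encoder, the random data and the dummy loss are placeholders, not my actual model):

```python
import torch
from torch import nn


class ToyEncoder(nn.Module):
    """Placeholder for the sub-module that maps an example to a 700-dim latent."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(85, 700)

    def forward(self, x):                      # x: (b, 80, 105, 85)
        return self.proj(x.mean(dim=(1, 2)))   # pool to (b, 85), project to (b, 700)


encoder = ToyEncoder()
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)
loader = (torch.randn(2, 80, 105, 85) for _ in range(128))  # tiny batches of 2

latents = []
for small_batch in loader:
    latents.append(encoder(small_batch))             # (2, 700) each
    if sum(z.shape[0] for z in latents) >= 128:      # enough for one "virtual" batch
        z_all = torch.cat(latents, dim=0)            # (128, 700)
        loss = (z_all @ z_all.T).mean()              # stand-in for the contrastive loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        latents = []
```

In this toy version the accumulation happens before a single loss/backward call, and that is exactly what I would like to reproduce inside Lightning.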
As I understood it, the Lightning API easily offers gradient accumulation, but I fear it is not useful here: in gradient accumulation the loss is computed on the individual mini-batches separately and then the gradients are accumulated, which in my case would result in very poor individual gradients. After some investigation I found out about the Lightning `Loop API` and I thought I could use that to fit my needs. The idea was to subclass the `TrainingEpochLoop`, request multiple batches from the `data_fetcher` using a generator (so that only one or two examples are in memory at a time), and use the LightningModule hook `on_train_batch_start` to pre-process each batch, transform the `(1, 80, 105, 85)` tensor into the more manageable `(1, 700)` tensor, and then start accumulating those. What my code is doing at the moment looks something like the following.
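Schematically, the custom loop is something like the skeleton below (sketch only: the body of the `advance` override, which pulls batches through a generator, is not shown here because its exact signature and data-fetcher handling vary between Lightning 1.x releases):

```python
from pytorch_lightning.loops import TrainingEpochLoop


class MultiBatchEpochLoop(TrainingEpochLoop):
    """Epoch loop meant to gather `accumulate` small batches per optimization step.

    Sketch only: the real `advance` override requests batches from the data
    fetcher one at a time through a generator, so that only one or two raw
    (1, 80, 105, 85) examples sit in memory at any moment, while the model
    buffers their (1, 700) latents via `on_train_batch_start`.
    """

    def __init__(self, accumulate: int = 16, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.accumulate = accumulate
```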
In my `LightningModule` I have implemented the `on_train_batch_start` hook (note that the `example2latent` function is calling one submodule of my model), and in the main script I simply connect the custom loop to the trainer. Rough sketches of both follow.
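The hook, in sketch form (the `backbone` here is only a cheap placeholder for the BEiT branch so that the snippet is self-contained; `training_step`, `configure_optimizers` and the rest of the module are omitted):

```python
import pytorch_lightning as pl
from torch import nn


class ContrastiveModel(pl.LightningModule):
    def __init__(self, accumulate: int = 16):
        super().__init__()
        self.accumulate = accumulate
        # Placeholder for the real BEiT sub-module, only to keep the sketch self-contained.
        self.backbone = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(80, 700))
        self._latents = []  # buffer of (1, 700) latents

    def example2latent(self, batch):
        # Calls one sub-module of the model to shrink the raw
        # (1, 80, 105, 85) example down to a (1, 700) latent.
        return self.backbone(batch)

    def on_train_batch_start(self, batch, batch_idx):
        # Pre-process the incoming example and buffer its latent; the custom
        # loop is then responsible for running the actual contrastive step
        # only once `accumulate` latents have been collected.
        self._latents.append(self.example2latent(batch))
```

And connecting the loop in the main script. This assumes the Lightning 1.x Loop API, where `trainer.fit_loop.replace(...)` (or `connect(...)`) is the entry point for swapping in a custom epoch loop; `my_datamodule` is a placeholder for my actual data module:

```python
from pytorch_lightning import Trainer

model = ContrastiveModel(accumulate=16)
trainer = Trainer(accelerator="gpu", devices=4, strategy="ddp")

# Swap the default epoch loop for the accumulating one. As far as I understand
# the 1.x Loop docs, `replace` instantiates the new loop class and re-connects
# its sub-loops; alternatively an instance can be attached via `connect`.
trainer.fit_loop.replace(epoch_loop=MultiBatchEpochLoop)

trainer.fit(model, datamodule=my_datamodule)  # `my_datamodule`: placeholder
```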
The problem with all of this is that if I try to use `accumulate = 16`, for example (thus aiming for a final latent tensor of shape `(16, 700)`), I get the Out-Of-Memory error I mentioned at the beginning. How can that be? Is this whole logic wrong? Do you have a more general suggestion on how to tackle this problem? Thanks!
P.S. I also tried to turn my `torch.utils.data.Dataset` into a `torch.utils.data.IterableDataset` and play around with `prefetch_factor`, `num_workers` and so on, without luck.
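For reference, this is the kind of DataLoader configuration I experimented with (the values are just examples, and `my_iterable_dataset` is a placeholder for whatever IterableDataset wraps my examples):

```python
from torch.utils.data import DataLoader

# Example values only; `prefetch_factor` only applies when `num_workers` > 0.
train_loader = DataLoader(
    my_iterable_dataset,
    batch_size=1,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=True,
)
```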