[Q] Can someone explain the logic behind tbptt_split_batch's splitting dimension? #10086
Unanswered
garrett361 asked this question in Lightning Trainer API: Trainer, LightningModule, LightningDataModule
I am confused about the code in `tbptt_split_batch` referenced at the bottom of this post, specifically why it is apparently assumed that the `x` in `for x in batch` have their time dimension at index 1.

I would have thought that `tbptt_split_batch` would be designed with `torch`'s `RNN`, `GRU`, or `LSTM` in mind (with `batch_first=True`), in which case I would expect `batch` to be of shape `(batch_size, sequence_length, input_size)`. But if this were the case, then `time_dims = [len(x[0]) for x in batch]` would mistakenly take `input_size` to be the size of the time dimension, and the rest of `tbptt_split_batch` would split along this dimension rather than the `sequence_length` dimension, no?
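To make the shape concern concrete, here is a quick check (assuming, as above, that `batch` is a single `(batch_size, sequence_length, input_size)` tensor; the numbers are just placeholders):

```python
import torch

b, t, d = 4, 100, 8  # batch_size, sequence_length, input_size
batch = torch.randn(b, t, d)

# Iterating over a (b, t, d) tensor yields (t, d) samples, so x[0] is a
# length-d row and len(x[0]) returns input_size, not sequence_length.
time_dims = [len(x[0]) for x in batch]
print(time_dims)  # [8, 8, 8, 8], i.e. d rather than t
```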
I came across the above when attempting to use `tbptt_split_batch` for a custom dataset. I needed to override it so that the splitting happens over the dimension I expected, as outlined above, and it seems to work correctly. I feel like I'm badly misunderstanding something.
https://github.com/PyTorchLightning/pytorch-lightning/blob/c9bc10ce8473a2249ffa4e00972c0c3c1d2641c4/pytorch_lightning/core/lightning.py#L1720-L1739
Edit, added for clarity: in the above code I assume that `batch` is a `(b, t, d)`-shaped input tensor, with these three numbers being the batch size, sequence length, and input dimension, respectively. Then `len(x[0])` would give `d`, rather than the expected `t`, for each `x`. I would have thought that the splitting should instead happen over dimension 1 of `batch`, which is what I have in my own `tbptt_split_batch` methods. I guess different assumptions are being made about the shape of the input tensor?