File tree Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -163,20 +163,20 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
163163
164164
165165@pytest .mark .parametrize ('bsize' , [5 , 10 ])
166- def test_batch_3d_squeeze_batch_dim ( sample_ds_3d , bsize ):
166+ def test_batch_1d_squeeze_batch_dim ( sample_ds_1d , bsize ):
167167 xbsize = 20
168168 bg = BatchGenerator (
169- sample_ds_3d ,
170- input_dims = {'time' : 1 , 'y' : bsize , ' x' : xbsize },
169+ sample_ds_1d ,
170+ input_dims = {'x' : xbsize },
171171 squeeze_batch_dim = False ,
172172 )
173173 for ds_batch in bg :
174- assert ds_batch ['x ' ].shape == [1 , bsize , xbsize ]
174+ assert list ( ds_batch ['foo ' ].shape ) == [1 , xbsize ]
175175
176176 bg2 = BatchGenerator (
177- sample_ds_3d ,
178- input_dims = {'time' : 1 , 'y' : bsize , ' x' : xbsize },
177+ sample_ds_1d ,
178+ input_dims = {'x' : xbsize },
179179 squeeze_batch_dim = True ,
180180 )
181- for ds_batch in bg :
182- assert ds_batch ['x ' ].shape == [bsize , xbsize ]
181+ for ds_batch in bg2 :
182+ assert list ( ds_batch ['foo ' ].shape ) == [xbsize ]
You can’t perform that action at this time.
0 commit comments