@@ -282,7 +282,10 @@ def forward(
 model = DynamicModel()
 ep = export(model, (w, x, y, z))
 model(w, x, torch.randn(3, 4), torch.randn(12))
-ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
+try:
+    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Basic concepts: symbols and guards
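The ``try``/``except`` blocks added in this patch report failures with ``tb.print_exc()``, which assumes a traceback alias is in scope near the top of the tutorial file. A minimal sketch of the assumed import (not shown in this diff) is:

    import traceback as tb  # assumed alias; lets the tutorial print the caught export error and continue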
@@ -411,7 +414,10 @@ def forward(
 # static guard is emitted on a dynamically-marked dimension:
 
 dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
-export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+try:
+    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape
@@ -421,7 +427,10 @@ def forward(
421427dynamic_shapes ["w" ] = (Dim .AUTO , Dim .AUTO )
422428dynamic_shapes ["x" ] = (Dim .STATIC ,)
423429dynamic_shapes ["y" ] = (Dim .AUTO , Dim .DYNAMIC )
424- export (model , (w , x , y , z ), dynamic_shapes = dynamic_shapes )
430+ try :
431+ export (model , (w , x , y , z ), dynamic_shapes = dynamic_shapes )
432+ except Exception :
433+ tb .print_exc ()
425434
426435######################################################################
427436# Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer is because
@@ -439,7 +448,7 @@ def __init__(self):
 
     def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
-        torch._check(x.shape[0] >= 16)
+        torch._check(x.shape[0] >= 4)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
@@ -455,8 +464,10 @@ def forward(self, w, x, y, z):
     "y": (Dim.AUTO, Dim.AUTO),
     "z": (Dim.AUTO,),
 }
-ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
-print(ep)
+try:
+    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Each of these statements emits an additional guard, and the exported program shows the changes; ``s0`` is eliminated in favor of ``s2 + 2``,
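Because ``print(ep)`` is dropped above, the specialized symbols and their value ranges are no longer printed automatically. A minimal sketch for inspecting them, assuming the export above succeeded and ``ep`` is bound, is:

    print(ep)                    # the exported program, with symbolic shapes such as s2 + 2
    print(ep.range_constraints)  # mapping from each symbol to its allowed value range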
@@ -485,7 +496,10 @@ def forward(self, w, x, y, z):
         "input": (Dim.AUTO, Dim.STATIC),
     },
 )
-ep.module()(torch.randn(2, 4))
+try:
+    ep.module()(torch.randn(2, 4))
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Named Dims
@@ -539,14 +553,17 @@ def forward(self, x, y):
         return w + torch.ones(4)
 
 dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
-ep = export(
-    Foo(),
-    (torch.randn(6, 4), torch.randn(6, 4)),
-    dynamic_shapes={
-        "x": (dx, d1),
-        "y": (dy, d1),
-    },
-)
+try:
+    ep = export(
+        Foo(),
+        (torch.randn(6, 4), torch.randn(6, 4)),
+        dynamic_shapes={
+            "x": (dx, d1),
+            "y": (dy, d1),
+        },
+    )
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
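As a hypothetical illustration of that workflow (the exact suggested fixes depend on the error printed above; here we assume they are ``d1 = 4``, since ``+ torch.ones(4)`` specializes the last dimension, and ``dy = dx``, since ``x + y`` requires matching leading dimensions), the copy-pasted re-specification might look like:

    d1 = 4   # assumed suggested fix: last dimension specializes to 4
    dy = dx  # assumed suggested fix: x and y share one leading dimension
    ep = export(
        Foo(),
        (torch.randn(6, 4), torch.randn(6, 4)),
        dynamic_shapes={"x": (dx, d1), "y": (dy, d1)},
    )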
@@ -688,7 +705,10 @@ def forward(self, x, y):
     torch.tensor(32),
     torch.randn(60),
 )
-export(Foo(), inps)
+try:
+    export(Foo(), inps)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with
@@ -700,7 +720,7 @@ class Foo(torch.nn.Module):
     def forward(self, x, y):
         a = x.item()
         torch._check(a >= 0)
-        torch._check(a <= y.shape[0])
+        torch._check(a < y.shape[0])
         return y[a]
 
 inps = (
@@ -732,7 +752,10 @@ def forward(self, x, y):
     torch.tensor(32),
     torch.randn(60),
 )
-export(Foo(), inps, strict=False)
+try:
+    export(Foo(), inps, strict=False)
+except Exception:
+    tb.print_exc()
 
 ######################################################################
 # For these errors, some basic options you have are: