@@ -282,7 +282,10 @@ def forward(
282282model = DynamicModel ()
283283ep = export (model , (w , x , y , z ))
284284model (w , x , torch .randn (3 , 4 ), torch .randn (12 ))
285- ep .module ()(w , x , torch .randn (3 , 4 ), torch .randn (12 ))
285+ try :
286+ ep .module ()(w , x , torch .randn (3 , 4 ), torch .randn (12 ))
287+ except Exception :
288+ tb .print_exc ()
286289
287290######################################################################
288291# Basic concepts: symbols and guards
@@ -411,7 +414,10 @@ def forward(
411414# static guard is emitted on a dynamically-marked dimension:
412415
413416dynamic_shapes ["w" ] = (Dim .AUTO , Dim .DYNAMIC )
414- export (model , (w , x , y , z ), dynamic_shapes = dynamic_shapes )
417+ try :
418+ export (model , (w , x , y , z ), dynamic_shapes = dynamic_shapes )
419+ except Exception :
420+ tb .print_exc ()
415421
416422######################################################################
417423# Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape
@@ -421,7 +427,10 @@ def forward(
421427dynamic_shapes ["w" ] = (Dim .AUTO , Dim .AUTO )
422428dynamic_shapes ["x" ] = (Dim .STATIC ,)
423429dynamic_shapes ["y" ] = (Dim .AUTO , Dim .DYNAMIC )
424- export (model , (w , x , y , z ), dynamic_shapes = dynamic_shapes )
430+ try :
431+ export (model , (w , x , y , z ), dynamic_shapes = dynamic_shapes )
432+ except Exception :
433+ tb .print_exc ()
425434
426435######################################################################
427436# Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer is because
@@ -439,7 +448,7 @@ def __init__(self):
439448
440449 def forward (self , w , x , y , z ):
441450 assert w .shape [0 ] <= 512
442- torch ._check (x .shape [0 ] >= 16 )
451+ torch ._check (x .shape [0 ] >= 4 )
443452 if w .shape [0 ] == x .shape [0 ] + 2 :
444453 x0 = x + y
445454 x1 = self .l (w )
@@ -455,7 +464,10 @@ def forward(self, w, x, y, z):
455464 "y" : (Dim .AUTO , Dim .AUTO ),
456465 "z" : (Dim .AUTO ,),
457466}
458- ep = export (DynamicModel (), (w , x , y , z ), dynamic_shapes = dynamic_shapes )
467+ try :
468+ ep = export (DynamicModel (), (w , x , y , z ), dynamic_shapes = dynamic_shapes )
469+ except Exception :
470+ tb .print_exc ()
459471print (ep )
460472
461473######################################################################
@@ -485,7 +497,10 @@ def forward(self, w, x, y, z):
485497 "input" : (Dim .AUTO , Dim .STATIC ),
486498 },
487499)
488- ep .module ()(torch .randn (2 , 4 ))
500+ try :
501+ ep .module ()(torch .randn (2 , 4 ))
502+ except Exception :
503+ tb .print_exc ()
489504
490505######################################################################
491506# Named Dims
@@ -539,14 +554,17 @@ def forward(self, x, y):
539554 return w + torch .ones (4 )
540555
541556dx , dy , d1 = torch .export .dims ("dx" , "dy" , "d1" )
542- ep = export (
543- Foo (),
544- (torch .randn (6 , 4 ), torch .randn (6 , 4 )),
545- dynamic_shapes = {
546- "x" : (dx , d1 ),
547- "y" : (dy , d1 ),
548- },
549- )
557+ try :
558+ ep = export (
559+ Foo (),
560+ (torch .randn (6 , 4 ), torch .randn (6 , 4 )),
561+ dynamic_shapes = {
562+ "x" : (dx , d1 ),
563+ "y" : (dy , d1 ),
564+ },
565+ )
566+ except Exception :
567+ tb .print_exc ()
550568
551569######################################################################
552570# The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards.
@@ -688,7 +706,10 @@ def forward(self, x, y):
688706 torch .tensor (32 ),
689707 torch .randn (60 ),
690708)
691- export (Foo (), inps )
709+ try :
710+ export (Foo (), inps )
711+ except Exception :
712+ tb .print_exc ()
692713
693714######################################################################
694715# Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with
@@ -700,7 +721,7 @@ class Foo(torch.nn.Module):
700721 def forward (self , x , y ):
701722 a = x .item ()
702723 torch ._check (a >= 0 )
703- torch ._check (a <= y .shape [0 ])
724+ torch ._check (a < y .shape [0 ])
704725 return y [a ]
705726
706727inps = (
@@ -732,7 +753,10 @@ def forward(self, x, y):
732753 torch .tensor (32 ),
733754 torch .randn (60 ),
734755)
735- export (Foo (), inps , strict = False )
756+ try :
757+ export (Foo (), inps , strict = False )
758+ except Exception :
759+ tb .print_exc ()
736760
737761######################################################################
738762# For these errors, some basic options you have are:
0 commit comments