@@ -163,22 +163,6 @@ def forward(self, x):
163163except Exception :
164164 tb .print_exc ()
165165
166- ######################################################################
167- # - unsupported Python language features (e.g. throwing exceptions, match statements)
168-
169- class Bad4 (torch .nn .Module ):
170- def forward (self , x ):
171- try :
172- x = x + 1
173- raise RuntimeError ("bad" )
174- except :
175- x = x + 2
176- return x
177-
178- try :
179- export (Bad4 (), (torch .randn (3 , 3 ),))
180- except Exception :
181- tb .print_exc ()
182166
183167######################################################################
184168# Non-Strict Export
@@ -197,16 +181,6 @@ def forward(self, x):
197181# ``strict=False`` flag.
198182#
199183# Looking at some of the previous examples which resulted in graph breaks:
200- #
201- # - Accessing tensor data with ``.data`` now works correctly
202-
203- class Bad2 (torch .nn .Module ):
204- def forward (self , x ):
205- x .data [0 , 0 ] = 3
206- return x
207-
208- bad2_nonstrict = export (Bad2 (), (torch .randn (3 , 3 ),), strict = False )
209- print (bad2_nonstrict .module ()(torch .ones (3 , 3 )))
210184
211185######################################################################
212186# - Calling unsupported functions (such as many built-in functions) traces
@@ -223,22 +197,6 @@ def forward(self, x):
223197print (bad3_nonstrict )
224198print (bad3_nonstrict .module ()(torch .ones (3 , 3 )))
225199
226- ######################################################################
227- # - Unsupported Python language features (such as throwing exceptions, match
228- # statements) now also get traced through.
229-
230- class Bad4 (torch .nn .Module ):
231- def forward (self , x ):
232- try :
233- x = x + 1
234- raise RuntimeError ("bad" )
235- except :
236- x = x + 2
237- return x
238-
239- bad4_nonstrict = export (Bad4 (), (torch .randn (3 , 3 ),), strict = False )
240- print (bad4_nonstrict .module ()(torch .ones (3 , 3 )))
241-
242200
243201######################################################################
244202# However, there are still some features that require rewrites to the original
@@ -349,7 +307,7 @@ def forward(self, x, y):
349307# ``inp1`` has an unconstrained first dimension, but the size of the second
350308# dimension must be in the interval [4, 18].
351309
352- from torch .export import Dim
310+ from torch .export . dynamic_shapes import Dim
353311
354312inp1 = torch .randn (10 , 10 , 2 )
355313
@@ -358,7 +316,7 @@ def forward(self, x):
358316 x = x [:, 2 :]
359317 return torch .relu (x )
360318
361- inp1_dim0 = Dim ("inp1_dim0" )
319+ inp1_dim0 = Dim ("inp1_dim0" , max = 50 )
362320inp1_dim1 = Dim ("inp1_dim1" , min = 4 , max = 18 )
363321dynamic_shapes1 = {
364322 "x" : {0 : inp1_dim0 , 1 : inp1_dim1 },
@@ -479,9 +437,7 @@ def forward(self, z, y):
479437
480438class DynamicShapesExample3 (torch .nn .Module ):
481439 def forward (self , x , y ):
482- if x .shape [0 ] <= 16 :
483- return x @ y [:, :16 ]
484- return y
440+ return x @ y
485441
486442dynamic_shapes3 = {
487443 "x" : {i : Dim (f"inp4_dim{ i } " ) for i in range (inp4 .dim ())},
@@ -536,6 +492,28 @@ def suggested_fixes():
536492
537493print (exported_dynamic_shapes_example3 .range_constraints )
538494
495+ ######################################################################
496+ # In PyTorch v2.5, we also introduced an automatic way of determining dynamic
497+ # shapes. In the case where you don't know the dynamism of tensors, or the
498+ # relationship of dynamic shapes between input tensors, we can mark dimensions
499+ # with `Dim.AUTO`, and export will determine the dynamism the input dimensions.
500+ # Going back to the previous example, we can rewrite it as follows:
501+
502+ inp4 = torch .randn (8 , 16 )
503+ inp5 = torch .randn (16 , 32 )
504+
505+ class DynamicShapesExample3 (torch .nn .Module ):
506+ def forward (self , x , y ):
507+ return x @ y
508+
509+ dynamic_shapes3_2 = {
510+ "x" : {i : Dim .AUTO for i in range (inp4 .dim ())},
511+ "y" : {i : Dim .AUTO for i in range (inp5 .dim ())},
512+ }
513+
514+ exported_dynamic_shapes_example_3_2 = export (DynamicShapesExample3 (), (inp4 , inp5 ), dynamic_shapes = dynamic_shapes3_2 )
515+ print (exported_dynamic_shapes_example_3_2 )
516+
539517######################################################################
540518# Custom Ops
541519# ----------
@@ -548,7 +526,7 @@ def suggested_fixes():
548526# as with any other custom op
549527
550528@torch .library .custom_op ("my_custom_library::custom_op" , mutates_args = {})
551- def custom_op (input : torch .Tensor ) -> torch .Tensor :
529+ def custom_op (x : torch .Tensor ) -> torch .Tensor :
552530 print ("custom_op called!" )
553531 return torch .relu (x )
554532
0 commit comments