@@ -1818,6 +1818,45 @@ def version_9(cls, ctx, node, **kwargs):
18181818 ctx .remove_node (node .name )
18191819
18201820
1821+ @tf_op (["DynamicStitch" , "ParallelDynamicStitch" ])
1822+ class DynamicStitch :
1823+ @classmethod
1824+ def version_10 (cls , ctx , node , ** kwargs ):
1825+ num_partitions = len (node .input ) // 2
1826+ index_inputs = node .input [:num_partitions ]
1827+ data_inputs = node .input [num_partitions :]
1828+ index_shapes = [ctx .get_shape (inp ) for inp in index_inputs ]
1829+ data_shapes = [ctx .get_shape (inp ) for inp in data_inputs ]
1830+ utils .make_sure (all (s is not None and len (s ) == 1 for s in index_shapes ),
1831+ "DynamicPartition only implemented for index tensors of rank 1" )
1832+ utils .make_sure (all (s is not None and len (s ) == 1 for s in data_shapes ),
1833+ "DynamicPartition only implemented for data tensors of rank 1" )
1834+ dtype = ctx .get_dtype (node .output [0 ])
1835+ concat_indices = ctx .make_node ("Concat" , index_inputs , attr = {'axis' : 0 })
1836+ concat_indices_int64 = ctx .make_node ("Cast" , [concat_indices .output [0 ]], attr = {"to" : TensorProto .INT64 })
1837+
1838+ concat_data = ctx .make_node ("Concat" , data_inputs , attr = {'axis' : 0 })
1839+
1840+ data_shape = ctx .make_node ("Shape" , [concat_data .output [0 ]])
1841+ expanded_indices = ctx .make_node ("Expand" , [concat_indices_int64 .output [0 ], data_shape .output [0 ]])
1842+
1843+ max_index = ctx .make_node ("ReduceMax" , [concat_indices_int64 .output [0 ]], attr = {'axes' : [0 ], 'keepdims' : 1 })
1844+ const_one = ctx .make_const (utils .make_name ('const_one' ), np .array ([1 ], np .int64 ))
1845+ target_length = ctx .make_node ("Add" , [max_index .output [0 ], const_one .output [0 ]])
1846+
1847+ zero_tensor = helper .make_tensor ("value" , dtype , dims = [1 ], vals = [0 ])
1848+ zeros_of_shape = ctx .make_node ("ConstantOfShape" , [target_length .output [0 ]], attr = {"value" : zero_tensor })
1849+
1850+ name = node .name
1851+ outputs = node .output
1852+ ctx .remove_node (node .name )
1853+ ctx .make_node ("ScatterElements" ,
1854+ [zeros_of_shape .output [0 ], expanded_indices .output [0 ], concat_data .output [0 ]],
1855+ name = name ,
1856+ outputs = outputs ,
1857+ attr = {'axis' : 0 })
1858+
1859+
18211860@tf_op ("MatrixDiagPart" )
18221861class MatrixDiagPart :
18231862 @classmethod
0 commit comments