1919from tf2onnx import utils
2020from tf2onnx .handler import tf_op
2121from tf2onnx .tf_loader import find_function
22+ from tf2onnx .graph_builder import GraphBuilder
2223
2324
2425logger = logging .getLogger (__name__ )
@@ -401,6 +402,7 @@ def version_7(cls, ctx, node, **kwargs):
401402 cond_input_to_state_var = {}
402403 scan_outputs = []
403404 input_idx_to_remove = []
405+ idx_to_ragged_writes = dict (body .ragged_variant_list_writes )
404406 # remove TensorListReserve
405407 for idx , name in enumerate (tf_while_inputs ):
406408 if idx == 1 :
@@ -416,9 +418,15 @@ def version_7(cls, ctx, node, **kwargs):
416418 # there is no equivalent step in onnx and we should remove it.
417419 output_shape = None
418420 output_dtype = n .get_attr_value ("element_dtype" )
421+ is_ragged = False
419422 if n .type == "TensorListReserve" and n .inputs [0 ].is_const () and not n .inputs [0 ].is_scalar ():
420423 output_shape = [- 1 ] + n .inputs [0 ].get_tensor_value (as_list = True )
421- scan_outputs .append ((idx , n , output_shape , output_dtype ))
424+ if idx in idx_to_ragged_writes :
425+ output_shape = None
426+ output_dtype = body .get_dtype (idx_to_ragged_writes [idx ].input [0 ])
427+ is_ragged = True
428+ loop_vars .append (name )
429+ scan_outputs .append ((idx , n , output_shape , output_dtype , is_ragged ))
422430 continue
423431
424432 # tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
@@ -437,8 +445,29 @@ def version_7(cls, ctx, node, **kwargs):
437445 del body .outputs [idx ]
438446
439447 scan_output_names = []
440- # remove tensor array that are passed in to the loop
441- for idx , n , output_shape , output_dtype in reversed (scan_outputs ):
448+ ragged_scan_output_names = []
449+ ragged_scan_output_to_len = {}
450+
451+ # remove tensor arrays that are passed in to the loop
452+ for idx , n , output_shape , output_dtype , is_ragged in reversed (scan_outputs ):
453+ if is_ragged :
454+ out = n .output [0 ]
455+ ctx .remove_node (n .name )
456+ seq_empty = ctx .make_node ("SequenceEmpty" , [], attr = {'dtype' : output_dtype }, name = n .name ,
457+ outputs = [out ], shapes = [None ], dtypes = [utils .SeqType (output_dtype )])
458+ ctx .replace_all_inputs (n .output [0 ], seq_empty .output [0 ])
459+ # Ragged tensors also must track the length of each row
460+ output_shapes .append ([- 1 ])
461+ output_dtypes .append (TensorProto .INT64 )
462+ output_shapes [idx ] = None
463+ output_dtypes [idx ] = utils .SeqType (output_dtype )
464+ body_ragged_name = utils .make_name ("ragged_scan_output" )
465+ external_ragged_name = utils .make_name ("ragged_output" )
466+ scan_output_names .append (body_ragged_name )
467+ output_names .append (external_ragged_name )
468+ ragged_scan_output_names .append (body_ragged_name )
469+ ragged_scan_output_to_len [output_names [idx ]] = external_ragged_name
470+ continue
442471 ctx .remove_node (n .name )
443472 # make the node output bad
444473 ctx .replace_all_inputs (n .output [0 ], "@@ALLOC" ) # ops=ctx.get_nodes()
@@ -475,11 +504,16 @@ def version_7(cls, ctx, node, **kwargs):
475504
476505 # shift output consumers
477506 for k , v in output_map .items ():
478- ctx .replace_all_inputs (k , v ) # ops=ctx.get_nodes()
507+ if k not in ragged_scan_output_to_len .values ():
508+ ctx .replace_all_inputs (k , v ) # ops=ctx.get_nodes()
509+
510+ ragged_scan_output_to_len = {output_map [k ]: output_map [v ] for k , v in ragged_scan_output_to_len .items ()}
479511
480512 wire_while_body (ctx , body , loop_node , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
481- output_dtypes , body_name , node .name , cond_graph , tf_while_inputs , scan_output_names )
513+ output_dtypes , body_name , node .name , cond_graph , tf_while_inputs , scan_output_names ,
514+ ragged_scan_output_names )
482515
516+ loop_node .ragged_scan_output_to_len = ragged_scan_output_to_len
483517 # if there was a tensorflow variant type, bind in a real type here
484518 # FIXME: I don't think this is needed anymore
485519 for i , n in enumerate (body .inputs ):
@@ -488,7 +522,8 @@ def version_7(cls, ctx, node, **kwargs):
488522
489523
490524def wire_while_body (parent_g , g , loop_node , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
491- output_dtypes , scope , parent , cond_graph , tf_while_inputs , scan_output_names ):
525+ output_dtypes , scope , parent , cond_graph , tf_while_inputs , scan_output_names ,
526+ ragged_scan_output_names ):
492527 """Wire subgraph graph into main."""
493528 remove_parents = []
494529 to_remove = []
@@ -519,8 +554,25 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
519554
520555 # this is a tensor array write - make it an identity
521556 scan_outputs = []
557+ ragged_scan_outputs_cnt = 0
558+ names_to_scan_outputs = {}
559+
522560 for node in g .get_nodes ():
523561 if node .type == "TensorListSetItem" :
562+ if node .inputs [2 ].type == "RaggedTensorToVariant" :
563+ node .type = "SequenceInsert"
564+ row_content = node .inputs [2 ].input [0 ]
565+ g .replace_inputs (node , [node .input [0 ], row_content ])
566+ g .set_shape (node .output [0 ], g .get_shape (node .input [1 ]))
567+ g .set_dtype (node .output [0 ], utils .SeqType (g .get_dtype (node .input [1 ])))
568+ dense_shape = g .make_node ("Shape" , [row_content ]).output [0 ]
569+ zero_const = g .make_const (utils .make_name ("zero_const" ), np .array (0 , np .int64 )).output [0 ]
570+ row_length = g .make_node ("Gather" , [dense_shape , zero_const ]).output [0 ]
571+ row_length_id = g .make_node ("Identity" , [row_length ])
572+ scan_outputs .append (row_length_id .output [0 ])
573+ names_to_scan_outputs [ragged_scan_output_names [ragged_scan_outputs_cnt ]] = row_length_id .output [0 ]
574+ ragged_scan_outputs_cnt += 1
575+ continue
524576 remove_parents .append (node .input [0 ])
525577 node .type = "Identity"
526578 g .set_shape (node .output [0 ], g .get_shape (node .input [2 ]))
@@ -531,8 +583,9 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
531583 if len (scan_outputs ) != len (scan_output_names ):
532584 raise ValueError ("While loop couldn't find scan output index for nodes" )
533585
534- names_to_scan_outputs = {}
535586 for output in scan_outputs :
587+ if output in names_to_scan_outputs .values ():
588+ continue
536589 last_output = output
537590 consumers = g .find_output_consumers (last_output )
538591 while consumers :
@@ -547,8 +600,9 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
547600
548601 # Reorder scan outputs
549602 scan_outputs = [names_to_scan_outputs [name ] for name in scan_output_names ]
603+
604+ # Use shapes from subgraph if loop node shapes for scan outputs are missing
550605 for i in range (- len (scan_output_names ), 0 ):
551- # Use shapes from subgraph if loop node shapes for scan outputs are missing
552606 if loop_node .output_shapes [i ] is None :
553607 shape = g .get_shape (scan_outputs [i ])
554608 if shape is not None :
@@ -580,6 +634,31 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
580634 if node .type in ["Identity" ]:
581635 g .set_dtype (o , node .inputs [0 ].output_dtypes [0 ])
582636
637+ for node in g .ragged_variant_list_reads :
638+ # Requires opset 11
639+ gather = node .inputs [0 ]
640+ inp = gather .inputs [0 ]
641+ while inp .type == "Identity" :
642+ inp = inp .inputs [0 ]
643+ err_msg1 = "Could not find corresponding RaggedTensorToVariant for node %s" % node .name
644+ err_msg2 = "Input to RaggedTensorToVariant for loop has batched_input=False for node %s" % inp .name
645+ err_msg3 = "RAGGED_RANK != 1 for RaggedTensorToVariant node %s" % node .name
646+ utils .make_sure (inp .type == "RaggedTensorToVariant" , err_msg1 )
647+ utils .make_sure (inp .get_attr_value ("batched_input" ), err_msg2 )
648+ utils .make_sure (inp .get_attr_value ("RAGGED_RANK" ) == 1 , err_msg3 )
649+ idx = gather .input [1 ]
650+ idx_unsq = GraphBuilder (g ).make_unsqueeze ({'data' : idx , 'axes' : [0 ]})
651+ np_dtype = utils .map_onnx_to_numpy_type (g .get_dtype (idx_unsq ))
652+ const_one = g .make_const (utils .make_name ("const_1" ), np .array (1 , np_dtype )).output [0 ]
653+ idx_plus_1 = g .make_node ("Add" , [idx_unsq , const_one ]).output [0 ]
654+ splits , values = inp .input
655+ start = g .make_node ("Gather" , [splits , idx_unsq ]).output [0 ]
656+ end = g .make_node ("Gather" , [splits , idx_plus_1 ]).output [0 ]
657+ np_dtype2 = utils .map_onnx_to_numpy_type (g .get_dtype (splits ))
658+ axes = g .make_const (utils .make_name ("const_zero" ), np .array ([0 ], np_dtype2 )).output [0 ]
659+ sliced_vals = g .make_node ("Slice" , [values , start , end , axes ]).output [0 ]
660+ g .replace_all_inputs (node .output [0 ], sliced_vals )
661+
583662 return g
584663
585664
0 commit comments