44from .base_emitter import BaseEmitter
55
66_types = {
7+ TensorProto .DOUBLE : "DOUBLE" ,
78 TensorProto .FLOAT : "FLOAT" ,
89 TensorProto .FLOAT16 : "FLOAT16" ,
910 TensorProto .INT64 : "INT64" ,
1011 TensorProto .INT32 : "INT32" ,
12+ TensorProto .INT16 : "INT16" ,
13+ TensorProto .UINT64 : "UINT64" ,
14+ TensorProto .UINT32 : "UINT32" ,
15+ TensorProto .UINT16 : "UINT16" ,
16+ TensorProto .STRING : "STRING" ,
17+ TensorProto .BOOL : "BOOL" ,
1118}
1219
1320
@@ -98,6 +105,7 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
98105 name = kwargs ["name" ]
99106 itype = kwargs .get ("elem_type" , 0 )
100107 shape = kwargs .get ("shape" , None )
108+ name = self ._clean_result_name (name )
101109 if itype == 0 :
102110 inp = name or "X"
103111 else :
@@ -135,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
135143
136144 def _emit_output (self , ** kwargs : Dict [str , Any ]) -> List [str ]:
137145 name = kwargs ["name" ]
146+ name = self ._clean_result_name (name )
138147 itype = kwargs .get ("elem_type" , 0 )
139148 shape = kwargs .get ("shape" , None )
140149 self .outputs .append (name )
@@ -158,16 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
158167 raise NotImplementedError ("Graph attribute not supported yet." )
159168 args .append (f"{ k } ={ vatt } " )
160169
161- outs = ", " .join (outputs )
162- inps = ", " .join (inputs )
170+ outs = ", " .join (map ( self . _clean_result_name , outputs ) )
171+ inps = ", " .join (map ( self . _clean_result_name , inputs ) )
163172 op_type = self ._emit_node_type (op_type , domain )
164173 sdomain = "" if not domain else f", domain={ domain !r} "
165174 if args :
166175 sargs = ", " .join (args )
167- row = f" { outs } = op.{ op_type } ({ inps } , { sargs } { sdomain } )"
176+ if inps :
177+ row = f" { outs } = op.{ op_type } ({ inps } , { sargs } { sdomain } )"
178+ else :
179+ row = f" { outs } = op.{ op_type } ({ sargs } { sdomain } )"
168180 else :
169181 row = f" { outs } = op.{ op_type } ({ inps } { sdomain } )"
170182 return [row ]
171183
184+ def _clean_result_name (self , name ):
185+ return name
186+
172187 def _emit_node_type (self , op_type , domain ):
173188 return op_type
0 commit comments