4848tf_metatize_cache = Cache (50 )
4949
5050
51+ class DefaultTensorName (str ):
52+ """A type used to indicate a default tensor name."""
53+
54+ pass
55+
56+
5157class MetaOpDefLibrary (object ):
5258 """A singleton-like object that holds correspondences between TF Python API functions and the `OpDef`s they construct.
5359
@@ -366,10 +372,16 @@ def _protobuf_convert(cls, k, v):
366372 raise TypeError (f"Could not convert { k } " )
367373
368374 def __init__ (self , op , name , attr , obj = None ):
375+ """Create a TF meta NodeDef.
376+
377+ XXX: Meta NodeDefs with `name == None` have a special meaning;
378+ their names are uniquely generated. We still consider them equal
379+ (when every other property is equal, of course).
380+ """
369381 super ().__init__ (obj = obj )
370382 self .op = metatize (op )
371383 assert name is not None
372- self .name = name if isvar (name ) else str ( name )
384+ self .name = name if isvar (name ) else name
373385
374386 if not isvar (attr ):
375387 opdef_sig , _ = op_def_lib .get_op_info (self .op )
@@ -600,6 +612,11 @@ def reify(self):
600612 # An operation with this name might already exist in the graph
601613 #
602614 try :
615+ # FIXME: Lame hack
616+ if isinstance (self .name , DefaultTensorName ):
617+ # Use a unique version of the default name.
618+ raise KeyError ()
619+
603620 existing_op = ops .get_default_graph ().get_operation_by_name (self .name )
604621 except KeyError :
605622 #
@@ -613,7 +630,15 @@ def reify(self):
613630 # An `Operation` with this name exists, let's make sure it's
614631 # equivalent to this meta `Operation`
615632 #
616- if self != mt (existing_op ):
633+ existing_op_mt = mt (existing_op )
634+
635+ # # Since we can't exactly reproduce all NodeDef.attr information
636+ # # (e.g. dtypes), we need to remove any unnecessary NodeDef.attr
637+ # # fields from comparisons with same-named nodes in the graph.
638+ # if op_attrs.keys() != node_attr.keys():
639+ # existing_op_mt.node_def.attr = node_attr
640+
641+ if self != existing_op_mt :
617642 raise MetaReificationError (
618643 f"An Operation with the name { self .name } "
619644 " already exists in the graph and is not"
@@ -725,40 +750,40 @@ def reify(self):
725750
726751 def __truediv__ (self , y ):
727752 # TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
728- return mt .realdiv (self , y , name = "truediv" )
753+ return mt .realdiv (self , y , name = DefaultTensorName ( "truediv" ) )
729754
730755 def __rtruediv__ (self , x ):
731756 # TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
732- return mt .realdiv (x , self , name = "truediv" )
757+ return mt .realdiv (x , self , name = DefaultTensorName ( "truediv" ) )
733758
734759 def __add__ (self , y ):
735760 # TODO: If `self.dtype == tf.dtypes.string`, use `mt.add`
736- return mt .addv2 (self , y , name = "add" )
761+ return mt .addv2 (self , y , name = DefaultTensorName ( "add" ) )
737762
738763 def __radd__ (self , x ):
739764 # TODO: If `x.dtype == tf.dtypes.string`, use `mt.add`
740- return mt .addv2 (x , self , name = "add" )
765+ return mt .addv2 (x , self , name = DefaultTensorName ( "add" ) )
741766
742767 def __sub__ (self , y ):
743- return mt .sub (self , y , name = "sub" )
768+ return mt .sub (self , y , name = DefaultTensorName ( "sub" ) )
744769
745770 def __rsub__ (self , x ):
746- return mt .sub (x , self , name = "sub" )
771+ return mt .sub (x , self , name = DefaultTensorName ( "sub" ) )
747772
748773 def __mul__ (self , y ):
749- return mt .mul (self , y , name = "mul" )
774+ return mt .mul (self , y , name = DefaultTensorName ( "mul" ) )
750775
751776 def __rmul__ (self , x ):
752- return mt .mul (x , self , name = "mul" )
777+ return mt .mul (x , self , name = DefaultTensorName ( "mul" ) )
753778
754779 def __abs__ (self ):
755- return mt .abs (self , name = "Abs" )
780+ return mt .abs (self , name = DefaultTensorName ( "Abs" ) )
756781
757782 def __pow__ (self , y ):
758- return mt .pow (self , y , name = "pow" )
783+ return mt .pow (self , y , name = DefaultTensorName ( "pow" ) )
759784
760785 def __neg__ (self ):
761- return mt .neg (self , name = "Neg" )
786+ return mt .neg (self , name = DefaultTensorName ( "Neg" ) )
762787
763788
764789class TFlowMetaTensorShape (TFlowMetaSymbol ):
@@ -987,48 +1012,22 @@ def __api_call__(self, *args, **kwargs):
9871012
9881013 if not op_args_unreified :
9891014
990- res_var = None
991- # name = op_args.get("name", None)
992- #
993- # if name is not None:
994- # #
995- # # An operation with this name might already exist in the graph
996- # #
9971015 #
998- # from tensorflow.python.framework import ops
1016+ # We create the `Operation` in the graph
9991017 #
1000- # try:
1001- # this_op = ops.get_default_graph().get_operation_by_name(name)
1002- # except KeyError:
1003- # pass
1004- # else:
1005- # # TODO: Make sure the existing `Operation` matches our arguments
1006- # assert this_op.type == self.op_def.obj.name
1007- #
1008- # this_op = mt(this_op)
1009- # op_inputs, op_node_def = self.op_args_to_operation_inputs(op_args)
1010- # assert op_inputs == this_op.inputs
1011- # assert op_node_def == this_op.node_def
1012- # res_var = this_op.default_output
1013-
1014- if res_var is None :
1015- #
1016- # We create the `Operation` in the graph
1017- #
1018-
1019- tf_out = self ._apply_func (** op_args )
1020-
1021- # Ensure that the original meta objects will be available
1022- # for use in the `metatize` that follows
1023- tf_metatize_cache .update (
1024- {
1025- k : v
1026- for k , v in zip (op_args .values (), apply_arguments .values ())
1027- if isinstance (k , tf .Tensor )
1028- }
1029- )
1018+ tf_out = self ._apply_func (** op_args )
1019+
1020+ # Ensure that the original meta objects will be available
1021+ # for use in the `metatize` that follows
1022+ tf_metatize_cache .update (
1023+ {
1024+ k : v
1025+ for k , v in zip (op_args .values (), apply_arguments .values ())
1026+ if isinstance (k , tf .Tensor )
1027+ }
1028+ )
10301029
1031- res_var = metatize (tf_out )
1030+ res_var = metatize (tf_out )
10321031
10331032 if "names" in meta ._lvar_defaults_enabled :
10341033 # This should also reset the NodeDef's `obj`
@@ -1073,7 +1072,8 @@ def op_args_to_operation_inputs(self, apply_arguments):
10731072 node_attr = var ()
10741073
10751074 if "names" not in meta ._lvar_defaults_enabled :
1076- op_name = apply_arguments .get ("name" , op_def_tf .name ) or op_def_tf .name
1075+ default_name = DefaultTensorName (op_def_tf .name )
1076+ op_name = apply_arguments .get ("name" , default_name ) or default_name
10771077 else :
10781078 op_name = var ()
10791079
0 commit comments