1616
1717from google .protobuf .message import Message
1818
19- from tensorflow .python .framework import tensor_util , op_def_registry , op_def_library , tensor_shape
19+ from tensorflow .python .framework import (
20+ tensor_util ,
21+ op_def_registry ,
22+ op_def_library ,
23+ tensor_shape ,
24+ ops ,
25+ )
2026from tensorflow .core .framework .op_def_pb2 import OpDef
2127from tensorflow .core .framework .node_def_pb2 import NodeDef
2228
23- # from tensorflow.core.framework.tensor_shape_pb2 import TensorShapeProto
24-
2529from tensorflow_probability import distributions as tfd
2630
2731
3034 MetaSymbolType ,
3135 MetaOp ,
3236 MetaVariable ,
37+ MetaReificationError ,
3338 meta_reify_iter ,
3439 _metatize ,
3540 metatize ,
@@ -60,52 +65,68 @@ class MetaOpDefLibrary(object):
6065 }
6166 opdef_signatures = {}
6267
68+ def __init__ (self ):
69+ #
70+ # We need this in order to construct "Const" tensors directly, since
71+ # the "value" attr in a meta `NodeDef` is just a NumPy array and not
72+ # the `TensorProto` expected by `raw_ops.Const`.
73+ #
74+ def mt_const (value , dtype , name = None ):
75+ return tf .raw_ops .Const (
76+ value = tensor_util .make_tensor_proto (value ), dtype = dtype , name = name
77+ )
78+
79+ opdef = op_def_registry .get ("Const" )
80+ self .opdef_signatures [opdef .name ] = self .make_opdef_sig (opdef , mt_const )
81+
6382 @classmethod
64- def apply_op (cls , * args , ** kwargs ):
65- return op_def_library .apply_op (* args , ** kwargs )
83+ def get_op_info (cls , opdef ):
84+ """Return the TF Python API function signature for a given `OpDef`.
85+
86+ Parameter
87+ ---------
88+ opdef: str or `OpDef` object (meta or base)
89+ """
90+ if isinstance (opdef , str ):
91+ opdef_name = opdef
92+ opdef = op_def_registry .get (opdef_name )
93+ else :
94+ opdef_name = opdef .name
95+
96+ opdef_sig = cls .opdef_signatures .get (opdef_name , None )
97+
98+ if opdef_sig is None and opdef is not None :
99+ opdef_func = getattr (tf .raw_ops , opdef .name , None )
100+ opdef_sig = cls .make_opdef_sig (opdef , opdef_func )
101+ cls .opdef_signatures [opdef .name ] = opdef_sig
102+
103+ return opdef_sig
66104
67105 @classmethod
68106 def make_opdef_sig (cls , opdef , opdef_py_func = None ):
69107 """Create a `Signature` object for an `OpDef`.
70108
71109 Annotations are include so that one can partially verify arguments.
72110 """
73- input_args = OrderedDict ([(a .name , a .type or a .type_attr ) for a in opdef .input_arg ])
74- attrs = OrderedDict ([(a .name , a ) for a in opdef .attr ])
75-
76- params = OrderedDict ()
77111 if opdef_py_func :
112+ #
78113 # We assume we're dealing with a function from `tf.raw_ops`.
79- # Those functions have only the necessary `input_arg`s and
80- # `attr` inputs as arguments.
114+ # Those functions have only the necessary `input_arg`s and `attr`
115+ # inputs as arguments.
116+ #
81117 opdef_func_sig = Signature .from_callable (opdef_py_func )
82118 params = opdef_func_sig .parameters
83119
84- # for name, param in opdef_func_sig.parameters.items():
85- # # We make positional parameters permissible (since the
86- # # functions in `tf.raw_ops` are keyword-only), and we use the
87- # # `tf.raw_ops` arguments to determine the *actual* required
88- # # arguments (because `OpDef`'s `input_arg`s and `attrs` aren't
89- # # exactly clear about that).
90- # if name in input_args:
91- # new_default = Parameter.empty
92- # new_annotation = input_args[name]
93- # else:
94- # new_default = None
95- # new_annotation = attrs.get(name, None)
96- # if new_annotation is not None:
97- # new_annotation = new_annotation.type
120+ else :
98121 #
99- # new_param = param.replace(
100- # kind=Parameter.POSITIONAL_OR_KEYWORD,
101- # default=new_default,
102- # annotation=new_annotation,
103- # )
104- # params[name] = new_param
122+ # We're crafting an `Operation` at a low-level via `apply_op`
123+ # (like the functions in `tf.raw_ops` do)
124+ #
125+ input_args = OrderedDict ([( a . name , a . type or a . type_attr ) for a in opdef . input_arg ])
126+ attrs = OrderedDict ([( a . name , a ) for a in opdef . attr ] )
127+ params = OrderedDict ()
105128
106- else :
107- # We're crafting the Operation at a low-level via `apply_op`.
108- opdef_py_func = partial (op_def_lib .apply_op , opdef .name )
129+ opdef_py_func = partial (op_def_library .apply_op , opdef .name )
109130
110131 for i_name , i_type in input_args .items ():
111132 p = Parameter (i_name , Parameter .POSITIONAL_OR_KEYWORD , annotation = i_type )
@@ -144,29 +165,6 @@ def make_opdef_sig(cls, opdef, opdef_py_func=None):
144165 )
145166 return opdef_sig , opdef_py_func
146167
147- @classmethod
148- def get_op_info (cls , opdef ):
149- """Return the TF Python API function signature for a given `OpDef`.
150-
151- Parameter
152- ---------
153- opdef: str or `OpDef` object (meta or base)
154- """
155- if isinstance (opdef , str ):
156- opdef_name = opdef
157- opdef = op_def_registry .get (opdef_name )
158- else :
159- opdef_name = opdef .name
160-
161- opdef_sig = cls .opdef_signatures .get (opdef_name , None )
162-
163- if opdef_sig is None and opdef is not None :
164- opdef_func = getattr (tf .raw_ops , opdef .name , None )
165- opdef_sig = cls .make_opdef_sig (opdef , opdef_func )
166- cls .opdef_signatures [opdef .name ] = cls .make_opdef_sig (opdef , opdef_func )
167-
168- return opdef_sig
169-
170168
171169op_def_lib = MetaOpDefLibrary ()
172170
@@ -183,7 +181,6 @@ def _metatize_tf_object(obj):
183181def load_dispatcher ():
184182 """Set/override dispatcher to default to TF objects."""
185183
186- from tensorflow .python .framework .ops import EagerTensor
187184 from tensorflow .python .ops .gen_linalg_ops import _SvdOutput
188185
189186 def _metatize_tf_svd (obj ):
@@ -200,7 +197,7 @@ def _metatize_tf_eager(obj):
200197 " (e.g. within `tensorflow.python.eager.context.graph_mode`)"
201198 )
202199
203- meta ._metatize .add ((EagerTensor ,), _metatize_tf_eager )
200+ meta ._metatize .add ((ops . EagerTensor ,), _metatize_tf_eager )
204201
205202 meta ._metatize .add ((object ,), _metatize_tf_object )
206203 meta ._metatize .add ((HashableNDArray ,), _metatize_tf_object )
@@ -599,12 +596,30 @@ def reify(self):
599596 )
600597
601598 if not (op_inputs_unreified or op_attrs_unreified or isvar (self .name )):
602-
603- apply_arguments = operator .input_args (* op_inputs , name = self .name , ** op_attrs )
604- tf_out = operator ._apply_func (** apply_arguments )
605- op_tf = tf_out .op
606-
607- # TODO: Update NodeDef attrs?
599+ #
600+ # An operation with this name might already exist in the graph
601+ #
602+ try :
603+ existing_op = ops .get_default_graph ().get_operation_by_name (self .name )
604+ except KeyError :
605+ #
606+ # There is no such `Operation`, so we attempt to create it
607+ #
608+ apply_arguments = operator .input_args (* op_inputs , name = self .name , ** op_attrs )
609+ tf_out = operator ._apply_func (** apply_arguments )
610+ op_tf = tf_out .op
611+ else :
612+ #
613+ # An `Operation` with this name exists, let's make sure it's
614+ # equivalent to this meta `Operation`
615+ #
616+ if self != mt (existing_op ):
617+ raise MetaReificationError (
618+ f"An Operation with the name { self .name } "
619+ " already exists in the graph and is not"
620+ " equal to this meta object."
621+ )
622+ op_tf = existing_op
608623
609624 assert op_tf is not None
610625 self ._obj = op_tf
@@ -1149,4 +1164,5 @@ def __getattr__(self, obj):
11491164
11501165mt = TFlowMetaAccessor ()
11511166
1167+
11521168load_dispatcher ()
0 commit comments