@@ -129,22 +129,20 @@ def create_file_and_save_alive_counts(self, base_dir: Text,
129129 global_step : int ) -> None :
130130 """Creates and updates files with alive counts.
131131
132- Creates the directory `{base_dir}/learned_structure/` and saves the current
133- alive counts to:
134- `{base_dir}/learned_structure/{ALIVE_FILENAME}_{global_step}`.
132+ Creates the directory `{base_dir}` and saves the current alive counts to:
133+ `{base_dir}/{ALIVE_FILENAME}_{global_step}`.
135134
136135 Args:
137136 base_dir: where to export the alive counts.
138137 global_step: current value of global step, used as a suffix in filename.
139138 """
140139 current_filename = '%s_%s' % (ALIVE_FILENAME , global_step )
141- directory = os .path .join (base_dir , 'learned_structure' )
142140 try :
143- tf .gfile .MakeDirs (directory )
141+ tf .gfile .MakeDirs (base_dir )
144142 except tf .errors .OpError :
145143 # Probably already exists. If not, we'll see the error in the next line.
146144 pass
147- with tf .gfile .Open (os .path .join (directory , current_filename ), 'w' ) as f :
145+ with tf .gfile .Open (os .path .join (base_dir , current_filename ), 'w' ) as f :
148146 self .save_alive_counts (f ) # pytype: disable=wrong-arg-types
149147
150148
@@ -196,3 +194,29 @@ def _compute_alive_counts(
196194
197195def format_structure (structure : Dict [Text , int ]) -> Text :
198196 return json .dumps (structure , indent = 2 , sort_keys = True , default = str )
197+
198+
199+ class StructureExporterHook (tf .train .SessionRunHook ):
200+ """Estimator hook for StructureExporter.
201+
202+ Usage:
203+ exporter = structure_exporter.StructureExporter(
204+ network_regularizer.op_regularizer_manager)
205+ structure_export_hook = structure_exporter.StructureExporterHook(
206+ exporter, '/path/to/cns')
207+ estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
208+ ...,
209+ training_hooks=[structure_export_hook])
210+ """
211+
212+ def __init__ (self , exporter : StructureExporter , export_dir : Text ):
213+ self ._export_dir = export_dir
214+ self ._exporter = exporter
215+
216+ def end (self , session : tf .Session ):
217+ global_step = session .run (tf .train .get_global_step ())
218+ tf .logging .info ('Exporting structure at step %d' , global_step )
219+ tensor_to_eval_dict = session .run (self ._exporter .tensors )
220+ self ._exporter .populate_tensor_values (session .run (tensor_to_eval_dict ))
221+ self ._exporter .create_file_and_save_alive_counts (self ._export_dir ,
222+ global_step )
0 commit comments