@@ -28,6 +28,7 @@ limitations under the License.
2828#include " tensorflow/core/lib/core/error_codes.pb.h"
2929#include " tensorflow/core/lib/core/status.h"
3030#include " tensorflow/core/lib/core/threadpool.h"
31+ #include " tensorflow/core/platform/mutex.h"
3132#include " tensorflow_quantum/core/ops/parse_context.h"
3233#include " tensorflow_quantum/core/src/adj_util.h"
3334#include " tensorflow_quantum/core/src/util_qsim.h"
@@ -111,12 +112,14 @@ class TfqInnerProductGradOp : public tensorflow::OpKernel {
111112 std::vector<std::vector<GradientOfGate>> gradient_gates (
112113 programs.size (), std::vector<GradientOfGate>({}));
113114
115+ Status parse_status = Status::OK ();
116+ auto p_lock = tensorflow::mutex ();
114117 auto construct_f = [&](int start, int end) {
115118 for (int i = start; i < end; i++) {
116- OP_REQUIRES_OK (
117- context, QsimCircuitFromProgram ( programs[i], maps[i], num_qubits[i],
118- &qsim_circuits [i],
119- &fused_circuits[i], &gate_meta[i]) );
119+ Status local = QsimCircuitFromProgram (
120+ programs[i], maps[i], num_qubits[i], &qsim_circuits [i],
121+ &fused_circuits [i], &gate_meta[i]);
122+ NESTED_FN_STATUS_SYNC (parse_status, local, p_lock );
120123
121124 CreateGradientCircuit (qsim_circuits[i], gate_meta[i],
122125 &partial_fused_circuits[i], &gradient_gates[i]);
@@ -126,6 +129,7 @@ class TfqInnerProductGradOp : public tensorflow::OpKernel {
126129 const int num_cycles = 1000 ;
127130 context->device ()->tensorflow_cpu_worker_threads ()->workers ->ParallelFor (
128131 output_dim_batch_size, num_cycles, construct_f);
132+ OP_REQUIRES_OK (context, parse_status);
129133
130134 // Construct qsim circuits for other_programs.
131135 std::vector<std::vector<QsimCircuit>> other_qsim_circuits (
@@ -143,16 +147,19 @@ class TfqInnerProductGradOp : public tensorflow::OpKernel {
143147 Status status = QsimCircuitFromProgram (
144148 other_programs[ii][jj], {}, num_qubits[ii],
145149 &other_qsim_circuits[ii][jj], &other_fused_circuits[ii][jj]);
146- OP_REQUIRES (context, status.ok (),
147- tensorflow::errors::InvalidArgument (absl::StrCat (
148- " Found symbols in other_programs." ,
149- " No symbols are allowed in these circuits." )));
150+ NESTED_FN_STATUS_SYNC (parse_status, status, p_lock);
150151 }
151152 };
152153
153154 context->device ()->tensorflow_cpu_worker_threads ()->workers ->ParallelFor (
154155 output_dim_batch_size * output_dim_internal_size, num_cycles,
155156 construct_f2);
157+ if (!parse_status.ok ()) {
158+ OP_REQUIRES_OK (context,
159+ tensorflow::errors::InvalidArgument (absl::StrCat (
160+ " Found symbols in other_programs." ,
161+ " No symbols are allowed in these circuits." )));
162+ }
156163
157164 int max_num_qubits = 0 ;
158165 for (const int num : num_qubits) {
0 commit comments