Skip to content

Commit 6deb5ad

Browse files
authored
Merge pull request #519 from tensorflow/fix_nested_requires
Improved nested OP_REQUIRES_OK macros.
2 parents d2dcb96 + 8579abb commit 6deb5ad

12 files changed

+125
-53
lines changed

tensorflow_quantum/core/ops/math_ops/tfq_inner_product.cc

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "tensorflow/core/lib/core/error_codes.pb.h"
2929
#include "tensorflow/core/lib/core/status.h"
3030
#include "tensorflow/core/lib/core/threadpool.h"
31+
#include "tensorflow/core/platform/mutex.h"
3132
#include "tensorflow_quantum/core/ops/parse_context.h"
3233
#include "tensorflow_quantum/core/src/util_qsim.h"
3334

@@ -86,17 +87,21 @@ class TfqInnerProductOp : public tensorflow::OpKernel {
8687
std::vector<QsimFusedCircuit> fused_circuits(programs.size(),
8788
QsimFusedCircuit({}));
8889

90+
Status parse_status = Status::OK();
91+
auto p_lock = tensorflow::mutex();
8992
auto construct_f = [&](int start, int end) {
9093
for (int i = start; i < end; i++) {
91-
OP_REQUIRES_OK(context, QsimCircuitFromProgram(
92-
programs[i], maps[i], num_qubits[i],
93-
&qsim_circuits[i], &fused_circuits[i]));
94+
Status local =
95+
QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i],
96+
&qsim_circuits[i], &fused_circuits[i]);
97+
NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
9498
}
9599
};
96100

97101
const int num_cycles = 1000;
98102
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
99103
output_dim_batch_size, num_cycles, construct_f);
104+
OP_REQUIRES_OK(context, parse_status);
100105

101106
// Construct qsim circuits for other_programs.
102107
std::vector<std::vector<QsimCircuit>> other_qsim_circuits(
@@ -114,16 +119,19 @@ class TfqInnerProductOp : public tensorflow::OpKernel {
114119
Status status = QsimCircuitFromProgram(
115120
other_programs[ii][jj], {}, num_qubits[ii],
116121
&other_qsim_circuits[ii][jj], &other_fused_circuits[ii][jj]);
117-
OP_REQUIRES(context, status.ok(),
118-
tensorflow::errors::InvalidArgument(absl::StrCat(
119-
"Found symbols in other_programs.",
120-
"No symbols are allowed in these circuits.")));
122+
NESTED_FN_STATUS_SYNC(parse_status, status, p_lock);
121123
}
122124
};
123125

124126
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
125127
output_dim_batch_size * output_dim_internal_size, num_cycles,
126128
construct_f2);
129+
if (!parse_status.ok()) {
130+
OP_REQUIRES_OK(context,
131+
tensorflow::errors::InvalidArgument(absl::StrCat(
132+
"Found symbols in other_programs.",
133+
"No symbols are allowed in these circuits.")));
134+
}
127135

128136
int max_num_qubits = 0;
129137
for (const int num : num_qubits) {

tensorflow_quantum/core/ops/math_ops/tfq_inner_product_grad.cc

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "tensorflow/core/lib/core/error_codes.pb.h"
2929
#include "tensorflow/core/lib/core/status.h"
3030
#include "tensorflow/core/lib/core/threadpool.h"
31+
#include "tensorflow/core/platform/mutex.h"
3132
#include "tensorflow_quantum/core/ops/parse_context.h"
3233
#include "tensorflow_quantum/core/src/adj_util.h"
3334
#include "tensorflow_quantum/core/src/util_qsim.h"
@@ -111,12 +112,14 @@ class TfqInnerProductGradOp : public tensorflow::OpKernel {
111112
std::vector<std::vector<GradientOfGate>> gradient_gates(
112113
programs.size(), std::vector<GradientOfGate>({}));
113114

115+
Status parse_status = Status::OK();
116+
auto p_lock = tensorflow::mutex();
114117
auto construct_f = [&](int start, int end) {
115118
for (int i = start; i < end; i++) {
116-
OP_REQUIRES_OK(
117-
context, QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i],
118-
&qsim_circuits[i],
119-
&fused_circuits[i], &gate_meta[i]));
119+
Status local = QsimCircuitFromProgram(
120+
programs[i], maps[i], num_qubits[i], &qsim_circuits[i],
121+
&fused_circuits[i], &gate_meta[i]);
122+
NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
120123

121124
CreateGradientCircuit(qsim_circuits[i], gate_meta[i],
122125
&partial_fused_circuits[i], &gradient_gates[i]);
@@ -126,6 +129,7 @@ class TfqInnerProductGradOp : public tensorflow::OpKernel {
126129
const int num_cycles = 1000;
127130
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
128131
output_dim_batch_size, num_cycles, construct_f);
132+
OP_REQUIRES_OK(context, parse_status);
129133

130134
// Construct qsim circuits for other_programs.
131135
std::vector<std::vector<QsimCircuit>> other_qsim_circuits(
@@ -143,16 +147,19 @@ class TfqInnerProductGradOp : public tensorflow::OpKernel {
143147
Status status = QsimCircuitFromProgram(
144148
other_programs[ii][jj], {}, num_qubits[ii],
145149
&other_qsim_circuits[ii][jj], &other_fused_circuits[ii][jj]);
146-
OP_REQUIRES(context, status.ok(),
147-
tensorflow::errors::InvalidArgument(absl::StrCat(
148-
"Found symbols in other_programs.",
149-
"No symbols are allowed in these circuits.")));
150+
NESTED_FN_STATUS_SYNC(parse_status, status, p_lock);
150151
}
151152
};
152153

153154
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
154155
output_dim_batch_size * output_dim_internal_size, num_cycles,
155156
construct_f2);
157+
if (!parse_status.ok()) {
158+
OP_REQUIRES_OK(context,
159+
tensorflow::errors::InvalidArgument(absl::StrCat(
160+
"Found symbols in other_programs.",
161+
"No symbols are allowed in these circuits.")));
162+
}
156163

157164
int max_num_qubits = 0;
158165
for (const int num : num_qubits) {

tensorflow_quantum/core/ops/noise/tfq_noisy_expectation.cc

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -107,17 +107,20 @@ class TfqNoisyExpectationOp : public tensorflow::OpKernel {
107107
std::vector<NoisyQsimCircuit> qsim_circuits(programs.size(),
108108
NoisyQsimCircuit());
109109

110+
Status parse_status = Status::OK();
111+
auto p_lock = tensorflow::mutex();
110112
auto construct_f = [&](int start, int end) {
111113
for (int i = start; i < end; i++) {
112-
OP_REQUIRES_OK(context, NoisyQsimCircuitFromProgram(
113-
programs[i], maps[i], num_qubits[i], false,
114-
&qsim_circuits[i]));
114+
Status local = NoisyQsimCircuitFromProgram(
115+
programs[i], maps[i], num_qubits[i], false, &qsim_circuits[i]);
116+
NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
115117
}
116118
};
117119

118120
const int num_cycles = 1000;
119121
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
120122
programs.size(), num_cycles, construct_f);
123+
OP_REQUIRES_OK(context, parse_status);
121124

122125
int max_num_qubits = 0;
123126
for (const int num : num_qubits) {
@@ -262,6 +265,9 @@ class TfqNoisyExpectationOp : public tensorflow::OpKernel {
262265
BalanceTrajectory(num_samples, num_threads, &rep_offsets);
263266

264267
output_tensor->setZero();
268+
269+
Status compute_status = Status::OK();
270+
auto c_lock = tensorflow::mutex();
265271
auto DoWork = [&](int start, int end) {
266272
// Begin simulation.
267273
const auto tfq_for = qsim::SequentialFor(1);
@@ -315,9 +321,11 @@ class TfqNoisyExpectationOp : public tensorflow::OpKernel {
315321
continue;
316322
}
317323
float exp_v = 0.0;
318-
OP_REQUIRES_OK(context,
319-
ComputeExpectationQsim(pauli_sums[i][j], sim, ss, sv,
320-
scratch, &exp_v));
324+
NESTED_FN_STATUS_SYNC(
325+
compute_status,
326+
ComputeExpectationQsim(pauli_sums[i][j], sim, ss, sv, scratch,
327+
&exp_v),
328+
c_lock);
321329
rolling_sums[j] += static_cast<double>(exp_v);
322330
run_samples[j]++;
323331
}
@@ -351,6 +359,7 @@ class TfqNoisyExpectationOp : public tensorflow::OpKernel {
351359
absl::nullopt, 1);
352360
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
353361
num_threads, scheduling_params, DoWork);
362+
OP_REQUIRES_OK(context, compute_status);
354363
}
355364
};
356365

tensorflow_quantum/core/ops/noise/tfq_noisy_samples.cc

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,7 @@ class TfqNoisySamplesOp : public tensorflow::OpKernel {
8585
for (int i = start; i < end; i++) {
8686
auto r = NoisyQsimCircuitFromProgram(
8787
programs[i], maps[i], num_qubits[i], true, &qsim_circuits[i]);
88-
if (!r.ok()) {
89-
p_lock.lock();
90-
parse_status = r;
91-
p_lock.unlock();
92-
return;
93-
}
88+
NESTED_FN_STATUS_SYNC(parse_status, r, p_lock);
9489
}
9590
};
9691

tensorflow_quantum/core/ops/parse_context.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,15 @@ limitations under the License.
1616
#ifndef TFQ_CORE_OPS_PARSE_CONTEXT
1717
#define TFQ_CORE_OPS_PARSE_CONTEXT
1818

19+
// Syncs a threads work status with some global status.
20+
#define NESTED_FN_STATUS_SYNC(global_status, local_status, global_lock) \
21+
if (TF_PREDICT_FALSE(!local_status.ok())) { \
22+
global_lock.lock(); \
23+
global_status = local_status; \
24+
global_lock.unlock(); \
25+
return; \
26+
}
27+
1928
#include <string>
2029
#include <vector>
2130

tensorflow_quantum/core/ops/tfq_adj_grad_op.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "tensorflow/core/lib/core/error_codes.pb.h"
2929
#include "tensorflow/core/lib/core/status.h"
3030
#include "tensorflow/core/lib/core/threadpool.h"
31+
#include "tensorflow/core/platform/mutex.h"
3132
#include "tensorflow_quantum/core/ops/parse_context.h"
3233
#include "tensorflow_quantum/core/proto/pauli_sum.pb.h"
3334
#include "tensorflow_quantum/core/src/adj_util.h"
@@ -98,12 +99,14 @@ class TfqAdjointGradientOp : public tensorflow::OpKernel {
9899
std::vector<std::vector<GradientOfGate>> gradient_gates(
99100
programs.size(), std::vector<GradientOfGate>({}));
100101

102+
Status parse_status = Status::OK();
103+
auto p_lock = tensorflow::mutex();
101104
auto construct_f = [&](int start, int end) {
102105
for (int i = start; i < end; i++) {
103-
OP_REQUIRES_OK(
104-
context, QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i],
105-
&qsim_circuits[i], &full_fuse[i],
106-
&gate_meta[i]));
106+
Status local = QsimCircuitFromProgram(programs[i], maps[i],
107+
num_qubits[i], &qsim_circuits[i],
108+
&full_fuse[i], &gate_meta[i]);
109+
NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
107110
CreateGradientCircuit(qsim_circuits[i], gate_meta[i],
108111
&partial_fused_circuits[i], &gradient_gates[i]);
109112
}
@@ -112,6 +115,7 @@ class TfqAdjointGradientOp : public tensorflow::OpKernel {
112115
const int num_cycles = 1000;
113116
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
114117
programs.size(), num_cycles, construct_f);
118+
OP_REQUIRES_OK(context, parse_status);
115119

116120
// Get downstream gradients.
117121
std::vector<std::vector<float>> downstream_grads;

tensorflow_quantum/core/ops/tfq_calculate_unitary_op.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ limitations under the License.
2626
#include "tensorflow/core/lib/core/error_codes.pb.h"
2727
#include "tensorflow/core/lib/core/status.h"
2828
#include "tensorflow/core/lib/core/threadpool.h"
29+
#include "tensorflow/core/platform/mutex.h"
2930
#include "tensorflow_quantum/core/ops/parse_context.h"
3031
#include "tensorflow_quantum/core/src/circuit_parser_qsim.h"
3132
#include "tensorflow_quantum/core/src/util_qsim.h"
@@ -67,17 +68,21 @@ class TfqCalculateUnitaryOp : public tensorflow::OpKernel {
6768
std::vector<std::vector<qsim::GateFused<QsimGate>>> fused_circuits(
6869
programs.size(), std::vector<qsim::GateFused<QsimGate>>({}));
6970

71+
Status parse_status = Status::OK();
72+
auto p_lock = tensorflow::mutex();
7073
auto construct_f = [&](int start, int end) {
7174
for (int i = start; i < end; i++) {
72-
OP_REQUIRES_OK(context, QsimCircuitFromProgram(
73-
programs[i], maps[i], num_qubits[i],
74-
&qsim_circuits[i], &fused_circuits[i]));
75+
Status local =
76+
QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i],
77+
&qsim_circuits[i], &fused_circuits[i]);
78+
NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
7579
}
7680
};
7781

7882
const int num_cycles = 1000;
7983
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
8084
programs.size(), num_cycles, construct_f);
85+
OP_REQUIRES_OK(context, parse_status);
8186

8287
// Find largest circuit for tensor size padding and allocate
8388
// the output tensor.

tensorflow_quantum/core/ops/tfq_resolve_parameters_op.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ limitations under the License.
2222
#include "tensorflow/core/lib/core/error_codes.pb.h"
2323
#include "tensorflow/core/lib/core/status.h"
2424
#include "tensorflow/core/lib/core/threadpool.h"
25+
#include "tensorflow/core/platform/mutex.h"
2526
#include "tensorflow_quantum/core/ops/parse_context.h"
2627
#include "tensorflow_quantum/core/ops/tfq_simulate_utils.h"
2728
#include "tensorflow_quantum/core/src/program_resolution.h"
@@ -58,11 +59,14 @@ class TfqResolveParametersOp : public tensorflow::OpKernel {
5859
"Number of circuits and values do not match. Got ", programs.size(),
5960
" circuits and ", maps.size(), " values.")));
6061

62+
Status parse_status = Status::OK();
63+
auto p_lock = tensorflow::mutex();
6164
auto DoWork = [&](int start, int end) {
6265
std::string temp;
6366
for (int i = start; i < end; i++) {
6467
Program program = programs[i];
65-
OP_REQUIRES_OK(context, ResolveSymbols(maps[i], &program, false));
68+
Status local = ResolveSymbols(maps[i], &program, false);
69+
NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
6670
program.SerializeToString(&temp);
6771
output_tensor(i) = temp;
6872
}
@@ -71,6 +75,7 @@ class TfqResolveParametersOp : public tensorflow::OpKernel {
7175
const int num_cycles = 1000;
7276
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
7377
programs.size(), num_cycles, DoWork);
78+
OP_REQUIRES_OK(context, parse_status);
7479
}
7580
};
7681

tensorflow_quantum/core/ops/tfq_simulate_expectation_op.cc

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
#include "tensorflow/core/lib/core/error_codes.pb.h"
2929
#include "tensorflow/core/lib/core/status.h"
3030
#include "tensorflow/core/lib/core/threadpool.h"
31+
#include "tensorflow/core/platform/mutex.h"
3132
#include "tensorflow_quantum/core/ops/parse_context.h"
3233
#include "tensorflow_quantum/core/proto/pauli_sum.pb.h"
3334
#include "tensorflow_quantum/core/src/util_qsim.h"
@@ -85,17 +86,21 @@ class TfqSimulateExpectationOp : public tensorflow::OpKernel {
8586
std::vector<std::vector<qsim::GateFused<QsimGate>>> fused_circuits(
8687
programs.size(), std::vector<qsim::GateFused<QsimGate>>({}));
8788

89+
Status parse_status = Status::OK();
90+
auto p_lock = tensorflow::mutex();
8891
auto construct_f = [&](int start, int end) {
8992
for (int i = start; i < end; i++) {
90-
OP_REQUIRES_OK(context, QsimCircuitFromProgram(
91-
programs[i], maps[i], num_qubits[i],
92-
&qsim_circuits[i], &fused_circuits[i]));
93+
Status local =
94+
QsimCircuitFromProgram(programs[i], maps[i], num_qubits[i],
95+
&qsim_circuits[i], &fused_circuits[i]);
96+
NESTED_FN_STATUS_SYNC(parse_status, local, p_lock);
9397
}
9498
};
9599

96100
const int num_cycles = 1000;
97101
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
98102
programs.size(), num_cycles, construct_f);
103+
OP_REQUIRES_OK(context, parse_status);
99104

100105
int max_num_qubits = 0;
101106
for (const int num : num_qubits) {
@@ -181,6 +186,8 @@ class TfqSimulateExpectationOp : public tensorflow::OpKernel {
181186

182187
const int output_dim_op_size = output_tensor->dimension(1);
183188

189+
Status compute_status = Status::OK();
190+
auto c_lock = tensorflow::mutex();
184191
auto DoWork = [&](int start, int end) {
185192
int old_batch_index = -2;
186193
int cur_batch_index = -1;
@@ -220,9 +227,11 @@ class TfqSimulateExpectationOp : public tensorflow::OpKernel {
220227
}
221228

222229
float exp_v = 0.0;
223-
OP_REQUIRES_OK(context, ComputeExpectationQsim(
224-
pauli_sums[cur_batch_index][cur_op_index],
225-
sim, ss, sv, scratch, &exp_v));
230+
NESTED_FN_STATUS_SYNC(
231+
compute_status,
232+
ComputeExpectationQsim(pauli_sums[cur_batch_index][cur_op_index],
233+
sim, ss, sv, scratch, &exp_v),
234+
c_lock);
226235
(*output_tensor)(cur_batch_index, cur_op_index) = exp_v;
227236
old_batch_index = cur_batch_index;
228237
}
@@ -232,6 +241,7 @@ class TfqSimulateExpectationOp : public tensorflow::OpKernel {
232241
200 * (int64_t(1) << static_cast<int64_t>(max_num_qubits));
233242
context->device()->tensorflow_cpu_worker_threads()->workers->ParallelFor(
234243
fused_circuits.size() * output_dim_op_size, num_cycles, DoWork);
244+
OP_REQUIRES_OK(context, compute_status);
235245
}
236246
};
237247

0 commit comments

Comments
 (0)