1010 with tc .runtime_backend (backend ) as K :
1111 L = 2
1212 inputs = K .cast (K .ones ([3 , 2 ]), tc .rdtypestr )
13- weights = K .cast (K .ones ([2 ]), tc .rdtypestr )
13+ weights = K .cast (K .ones ([3 , 2 ]), tc .rdtypestr )
1414
1515 def ansatz (thetas , alpha ):
1616 c = tc .Circuit (L )
@@ -27,24 +27,42 @@ def f(thetas, alpha):
2727 observables = K .stack ([K .real (c .expectation_ps (z = [i ])) for i in range (L )])
2828 return K .mean (observables )
2929
30- # f_vmap = K.vmap(f, vectorized_argnums=0 )
31-
32- print ("grad " , K .grad (f )(inputs [ 0 ] , weights ))
33- print ("vmap " , K .vmap (f )(inputs , weights ))
34- print ("vmap over grad " , K .vmap (K .grad (f ))(inputs , weights ))
30+ print ( "grad_0" , K . grad ( f )( inputs [ 0 ], weights [ 0 ]) )
31+ print ( "grad_1" , K . grad ( f , argnums = 1 )( inputs [ 0 ], weights [ 0 ]))
32+ print ("vmap_0 " , K .vmap (f )(inputs , weights [ 0 ] ))
33+ print ("vmap_1 " , K .vmap (f , vectorized_argnums = 1 )(inputs [ 0 ] , weights ))
34+ print ("vmap over grad_0 " , K .vmap (K .grad (f ))(inputs , weights [ 0 ] ))
3535 # wrong in tf due to https://github.com/google/TensorNetwork/issues/940
3636 # https://github.com/tensorflow/tensorflow/issues/52148
37- print ("vmap over jacfwd" , K .vmap (K .jacfwd (f ))(inputs , weights ))
38- print ("jacfwd over vmap" , K .jacfwd (K .vmap (f ))(inputs , weights ))
39- r = K .vmap (K .jacrev (f ))(inputs , weights )
40- print ("vmap over jacrev" , r )
37+ print ("vmap over grad_1" , K .vmap (K .grad (f , argnums = 1 ))(inputs , weights [0 ]))
38+ # wrong in tf
39+ print ("vmap over jacfwd_0" , K .vmap (K .jacfwd (f ))(inputs , weights [0 ]))
40+ print ("jacfwd_0 over vmap" , K .jacfwd (K .vmap (f ))(inputs , weights [0 ]))
41+ print ("vmap over jacfwd_1" , K .vmap (K .jacfwd (f , argnums = 1 ))(inputs , weights [0 ]))
42+ print ("jacfwd_1 over vmap" , K .jacfwd (K .vmap (f ), argnums = 1 )(inputs , weights [0 ]))
43+ r = K .vmap (K .jacrev (f ))(inputs , weights [0 ])
44+ print ("vmap over jacrev0" , r )
45+ # wrong in tf
46+ r = K .jacrev (K .vmap (f ))(inputs , weights [0 ])
47+ print ("jacrev0 over vmap" , r )
48+ r = K .vmap (K .jacrev (f , argnums = 1 ))(inputs , weights [0 ])
49+ print ("vmap over jacrev1" , r )
50+ # wrong in tf
51+ r = K .jacrev (K .vmap (f ), argnums = 1 )(inputs , weights [0 ])
52+ print ("jacrev1 over vmap" , r )
53+ r = K .vmap (K .jacrev (f , argnums = 1 ), vectorized_argnums = 1 )(inputs [0 ], weights )
54+ print ("vmap1 over jacrev1" , r )
55+ r = K .jacrev (K .vmap (f , vectorized_argnums = 1 ), argnums = 1 )(inputs [0 ], weights )
56+ print ("jacrev1 over vmap1" , r )
57+ r = K .vmap (K .hessian (f ))(inputs , weights [0 ])
58+ print ("vmap over hess0" , r )
4159 # wrong in tf
42- r = K .jacrev (K .vmap (f ))(inputs , weights )
43- print ("jacrev over vmap" , r )
44- r = K .vmap (K .hessian (f ))(inputs , weights )
45- print ("vmap over hess " , r )
60+ r = K .hessian (K .vmap (f ))(inputs , weights [ 0 ] )
61+ print ("hess0 over vmap" , r )
62+ r = K .vmap (K .hessian (f , argnums = 1 ))(inputs , weights [ 0 ] )
63+ print ("vmap over hess1 " , r )
4664 # wrong in tf
47- r = K .hessian (K .vmap (f ))(inputs , weights )
48- print ("hess over vmap" , r )
65+ r = K .hessian (K .vmap (f ), argnums = 1 )(inputs , weights [ 0 ] )
66+ print ("hess1 over vmap" , r )
4967
5068# lessons: never put vmap outside gradient function in tf
0 commit comments