@@ -43,7 +43,7 @@ def fixed_hessian(point, model=None):
4343 return rval
4444
4545
46- def find_hessian (point , vars = None , model = None ):
46+ def find_hessian (point , vars = None , model = None , negate_output = True ):
4747 """
4848 Returns Hessian of logp at the point passed.
4949
@@ -55,11 +55,11 @@ def find_hessian(point, vars=None, model=None):
5555 Variables for which Hessian is to be calculated.
5656 """
5757 model = modelcontext (model )
58- H = model .compile_d2logp (vars )
58+ H = model .compile_d2logp (vars , negate_output = negate_output )
5959 return H (Point (point , filter_model_vars = True , model = model ))
6060
6161
62- def find_hessian_diag (point , vars = None , model = None ):
62+ def find_hessian_diag (point , vars = None , model = None , negate_output = True ):
6363 """
6464 Returns Hessian of logp at the point passed.
6565
@@ -71,14 +71,14 @@ def find_hessian_diag(point, vars=None, model=None):
7171 Variables for which Hessian is to be calculated.
7272 """
7373 model = modelcontext (model )
74- H = model .compile_fn (hessian_diag (model .logp (), vars ))
74+ H = model .compile_fn (hessian_diag (model .logp (), vars , negate_output = negate_output ))
7575 return H (Point (point , model = model ))
7676
7777
7878def guess_scaling (point , vars = None , model = None , scaling_bound = 1e-8 ):
7979 model = modelcontext (model )
8080 try :
81- h = find_hessian_diag (point , vars , model = model )
81+ h = - find_hessian_diag (point , vars , model = model , negate_output = False )
8282 except NotImplementedError :
8383 h = fixed_hessian (point , model = model )
8484 return adjust_scaling (h , scaling_bound )
0 commit comments