Skip to content

Commit 668fa62

Browse files
committed
[Fix] Mutable prange and smaller issues in corelators
1 parent be22624 commit 668fa62

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

pyerrors/correlators.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Corr:
4242

4343
__slots__ = ["content", "N", "T", "tag", "prange"]
4444

45-
def __init__(self, data_input, padding=[0, 0], prange=None):
45+
def __init__(self, data_input, padding=None, prange=None):
4646
""" Initialize a Corr object.
4747
4848
Parameters
@@ -58,6 +58,8 @@ def __init__(self, data_input, padding=[0, 0], prange=None):
5858
region identified for this correlator.
5959
"""
6060

61+
if padding is None:
62+
padding = [0, 0]
6163
if isinstance(data_input, np.ndarray):
6264
if data_input.ndim == 1:
6365
data_input = list(data_input)
@@ -105,7 +107,7 @@ def __init__(self, data_input, padding=[0, 0], prange=None):
105107
self.N = noNull[0].shape[0]
106108
if self.N > 1 and noNull[0].shape[0] != noNull[0].shape[1]:
107109
raise ValueError("Smearing matrices are not NxN.")
108-
if (not all([item.shape == noNull[0].shape for item in noNull])):
110+
if not all([item.shape == noNull[0].shape for item in noNull]):
109111
raise ValueError("Items in data_input are not of identical shape." + str(noNull))
110112
else:
111113
raise TypeError("'data_input' contains item of wrong type.")
@@ -236,7 +238,7 @@ def symmetric(self):
236238
newcontent.append(None)
237239
else:
238240
newcontent.append(0.5 * (self.content[t] + self.content[self.T - t]))
239-
if (all([x is None for x in newcontent])):
241+
if all([x is None for x in newcontent]):
240242
raise ValueError("Corr could not be symmetrized: No redundant values")
241243
return Corr(newcontent, prange=self.prange)
242244

@@ -300,7 +302,7 @@ def matrix_symmetric(self):
300302
return 0.5 * (Corr(transposed) + self)
301303

302304
def GEVP(self, t0, ts=None, sort="Eigenvalue", vector_obs=False, **kwargs):
303-
r'''Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors.
305+
r"""Solve the generalized eigenvalue problem on the correlator matrix and returns the corresponding eigenvectors.
304306
305307
The eigenvectors are sorted according to the descending eigenvalues, the zeroth eigenvector(s) correspond to the
306308
largest eigenvalue(s). The eigenvector(s) for the individual states can be accessed via slicing
@@ -333,12 +335,12 @@ def GEVP(self, t0, ts=None, sort="Eigenvalue", vector_obs=False, **kwargs):
333335
Method used to solve the GEVP.
334336
- "eigh": Use scipy.linalg.eigh to solve the GEVP. (default for vector_obs=False)
335337
- "cholesky": Use manually implemented solution via the Cholesky decomposition. Automatically chosen if vector_obs==True.
336-
'''
338+
"""
337339

338340
if self.N == 1:
339341
raise ValueError("GEVP methods only works on correlator matrices and not single correlators.")
340342
if ts is not None:
341-
if (ts <= t0):
343+
if ts <= t0:
342344
raise ValueError("ts has to be larger than t0.")
343345

344346
if "sorted_list" in kwargs:
@@ -786,7 +788,7 @@ def root_function(x, d):
786788
raise ValueError('Unknown variant.')
787789

788790
def fit(self, function, fitrange=None, silent=False, **kwargs):
789-
r'''Fits function to the data
791+
r"""Fits function to the data
790792
791793
Parameters
792794
----------
@@ -799,7 +801,7 @@ def fit(self, function, fitrange=None, silent=False, **kwargs):
799801
If not specified, self.prange or all timeslices are used.
800802
silent : bool
801803
Decides whether output is printed to the standard output.
802-
'''
804+
"""
803805
if self.N != 1:
804806
raise ValueError("Correlator must be projected before fitting")
805807

@@ -878,6 +880,8 @@ def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=No
878880
comp : Corr or list of Corr
879881
Correlator or list of correlators which are plotted for comparison.
880882
The tags of these correlators are used as labels if available.
883+
y_range : list
884+
list of two values, determining the range of the y-axis e.g. [0, 12].
881885
logscale : bool
882886
Sets y-axis to logscale.
883887
plateau : Obs
@@ -1093,7 +1097,7 @@ def __eq__(self, y):
10931097

10941098
def __add__(self, y):
10951099
if isinstance(y, Corr):
1096-
if ((self.N != y.N) or (self.T != y.T)):
1100+
if (self.N != y.N) or (self.T != y.T):
10971101
raise ValueError("Addition of Corrs with different shape")
10981102
newcontent = []
10991103
for t in range(self.T):
@@ -1338,21 +1342,21 @@ def __rtruediv__(self, y):
13381342

13391343
@property
13401344
def real(self):
1341-
def return_real(obs_OR_cobs):
1342-
if isinstance(obs_OR_cobs.flatten()[0], CObs):
1343-
return np.vectorize(lambda x: x.real)(obs_OR_cobs)
1345+
def return_real(obs_or_cobs):
1346+
if isinstance(obs_or_cobs.flatten()[0], CObs):
1347+
return np.vectorize(lambda x: x.real)(obs_or_cobs)
13441348
else:
1345-
return obs_OR_cobs
1349+
return obs_or_cobs
13461350

13471351
return self._apply_func_to_corr(return_real)
13481352

13491353
@property
13501354
def imag(self):
1351-
def return_imag(obs_OR_cobs):
1352-
if isinstance(obs_OR_cobs.flatten()[0], CObs):
1353-
return np.vectorize(lambda x: x.imag)(obs_OR_cobs)
1355+
def return_imag(obs_or_cobs):
1356+
if isinstance(obs_or_cobs.flatten()[0], CObs):
1357+
return np.vectorize(lambda x: x.imag)(obs_or_cobs)
13541358
else:
1355-
return obs_OR_cobs * 0 # So it stays the right type
1359+
return obs_or_cobs * 0 # So it stays the right type
13561360

13571361
return self._apply_func_to_corr(return_imag)
13581362

@@ -1396,7 +1400,7 @@ def prune(self, Ntrunc, tproj=3, t0proj=2, basematrix=None):
13961400
if basematrix is None:
13971401
basematrix = self
13981402
if Ntrunc >= basematrix.N:
1399-
raise ValueError('Cannot truncate using Ntrunc <= %d' % (basematrix.N))
1403+
raise ValueError('Cannot truncate using Ntrunc <= %d' % basematrix.N)
14001404
if basematrix.N != self.N:
14011405
raise ValueError('basematrix and targetmatrix have to be of the same size.')
14021406

@@ -1495,7 +1499,7 @@ def eigv(x, **kwargs):
14951499
def matmul(*operands):
14961500
return np.linalg.multi_dot(operands)
14971501
N = Gt.shape[0]
1498-
output = [[] for j in range(N)]
1502+
output = [[] for _ in range(N)]
14991503
if chol_inv is None:
15001504
chol = cholesky(G0) # This will automatically report if the matrix is not pos-def
15011505
chol_inv = inv(chol)

0 commit comments

Comments
 (0)