1+ import warnings
2+
13import numpy as np
24from .util import ensure_rng , NotUniqueError
5+ from .util import Colours
36
47
58def _hashable (x ):
@@ -22,7 +25,9 @@ class TargetSpace(object):
2225 >>> y = space.register_point(x)
2326 >>> assert self.max_point()['max_val'] == y
2427 """
25- def __init__ (self , target_func , pbounds , constraint = None , random_state = None ):
28+
29+ def __init__ (self , target_func , pbounds , constraint = None , random_state = None ,
30+ allow_duplicate_points = False ):
2631 """
2732 Parameters
2833 ----------
@@ -35,8 +40,16 @@ def __init__(self, target_func, pbounds, constraint=None, random_state=None):
3540
3641 random_state : int, RandomState, or None
3742 optionally specify a seed for a random number generator
43+
44+ allow_duplicate_points: bool, optional (default=False)
45+ If True, the optimizer will allow duplicate points to be registered.
46+ This behavior may be desired in high noise situations where repeatedly probing
47+ the same point will give different answers. In other situations, the acquisition
48+ may occasionaly generate a duplicate point.
3849 """
3950 self .random_state = ensure_rng (random_state )
51+ self ._allow_duplicate_points = allow_duplicate_points
52+ self .n_duplicate_points = 0
4053
4154 # The function to be optimized
4255 self .target_func = target_func
@@ -56,7 +69,6 @@ def __init__(self, target_func, pbounds, constraint=None, random_state=None):
5669 # keep track of unique points we have seen so far
5770 self ._cache = {}
5871
59-
6072 self ._constraint = constraint
6173
6274 if constraint is not None :
@@ -96,7 +108,7 @@ def keys(self):
96108 @property
97109 def bounds (self ):
98110 return self ._bounds
99-
111+
100112 @property
101113 def constraint (self ):
102114 return self ._constraint
@@ -176,8 +188,13 @@ def register(self, params, target, constraint_value=None):
176188 """
177189 x = self ._as_array (params )
178190 if x in self :
179- raise NotUniqueError ('Data point {} is not unique' .format (x ))
180-
191+ if self ._allow_duplicate_points :
192+ self .n_duplicate_points = self .n_duplicate_points + 1
193+ print (f'{ Colours .RED } Data point { x } is not unique. { self .n_duplicate_points } duplicates registered.'
194+ f' Continuing ...{ Colours .END } ' )
195+ else :
196+ raise NotUniqueError (f'Data point { x } is not unique. You can set "allow_duplicate_points=True" to '
197+ f'avoid this error' )
181198
182199 self ._params = np .concatenate ([self ._params , x .reshape (1 , - 1 )])
183200 self ._target = np .concatenate ([self ._target , [target ]])
@@ -188,12 +205,12 @@ def register(self, params, target, constraint_value=None):
188205 else :
189206 if constraint_value is None :
190207 msg = ("When registering a point to a constrained TargetSpace" +
191- " a constraint value needs to be present." )
208+ " a constraint value needs to be present." )
192209 raise ValueError (msg )
193210 # Insert data into unique dictionary
194211 self ._cache [_hashable (x .ravel ())] = (target , constraint_value )
195212 self ._constraint_values = np .concatenate ([self ._constraint_values ,
196- [constraint_value ]])
213+ [constraint_value ]])
197214
198215 def probe (self , params ):
199216 """
@@ -215,21 +232,16 @@ def probe(self, params):
215232 target function value.
216233 """
217234 x = self ._as_array (params )
235+ params = dict (zip (self ._keys , x ))
236+ target = self .target_func (** params )
218237
219- try :
220- return self ._cache [_hashable (x )]
221- except KeyError :
222- params = dict (zip (self ._keys , x ))
223- target = self .target_func (** params )
224-
225- if self ._constraint is None :
226- self .register (x , target )
227- return target
228- else :
229- constraint_value = self ._constraint .eval (** params )
230- self .register (x , target , constraint_value )
231- return target , constraint_value
232-
238+ if self ._constraint is None :
239+ self .register (x , target )
240+ return target
241+ else :
242+ constraint_value = self ._constraint .eval (** params )
243+ self .register (x , target , constraint_value )
244+ return target , constraint_value
233245
234246 def random_sample (self ):
235247 """
@@ -317,7 +329,7 @@ def res(self):
317329 self ._constraint_values ,
318330 params ,
319331 self ._constraint .allowed (self ._constraint_values )
320- )
332+ )
321333 ]
322334
323335 def set_bounds (self , new_bounds ):
0 commit comments