@@ -51,15 +51,15 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
5151 def is_block_size (self ) -> bool :
5252 return False
5353
54- def get_minimum (self ) -> int :
54+ def dim (self ) -> int :
5555 """
56- Return the minimum allowed value for this fragment.
56+ Returns the dimension of the output of encode
5757 """
5858 raise NotImplementedError
5959
60- def encode_scalar (self , value : object ) -> float :
60+ def encode (self , value : object ) -> list [ float ] :
6161 """
62- Encode a configuration value into a float for ML models.
62+ Encode a configuration value into a list of floats for ML models.
6363
6464 This is used by surrogate-assisted algorithms to convert configurations
6565 into numerical vectors for prediction models.
@@ -68,14 +68,15 @@ def encode_scalar(self, value: object) -> float:
6868 value: The configuration value to encode.
6969
7070 Returns:
71- A float representing the encoded value.
71+ A list of floats representing the encoded value.
7272 """
73- # Default: convert to float if possible
74- if not isinstance (value , (int , float , bool )):
75- raise TypeError (
76- f"Cannot encode { type (value ).__name__ } value { value !r} for ML"
77- )
78- return float (value )
73+ raise NotImplementedError
74+
75+ def get_minimum (self ) -> int :
76+ """
77+ Return the minimum allowed value for this fragment.
78+ """
79+ raise NotImplementedError
7980
8081
8182@dataclasses .dataclass
@@ -106,6 +107,17 @@ def pattern_neighbors(self, current: object) -> list[object]:
106107 neighbors .append (swapped )
107108 return neighbors
108109
110+ def dim (self ) -> int :
111+ return self .length
112+
113+ def encode (self , value : object ) -> list [float ]:
114+ assert isinstance (value , list )
115+ encoded = []
116+ for val in value :
117+ assert isinstance (val , int )
118+ encoded .append (float (val ))
119+ return value
120+
109121
110122@dataclasses .dataclass
111123class BaseIntegerFragment (ConfigSpecFragment ):
@@ -129,6 +141,9 @@ def clamp(self, val: int) -> int:
129141 def get_minimum (self ) -> int :
130142 return self .low
131143
144+ def dim (self ) -> int :
145+ return 1
146+
132147 def pattern_neighbors (self , current : object ) -> list [object ]:
133148 if type (current ) is not int : # bool is not allowed
134149 raise TypeError (f"Expected int, got { type (current ).__name__ } " )
@@ -141,13 +156,9 @@ def pattern_neighbors(self, current: object) -> list[object]:
141156 neighbors .append (upper )
142157 return neighbors
143158
144- def encode_scalar (self , value : object ) -> float :
145- """Encode integer values directly as floats."""
146- if not isinstance (value , (int , float )):
147- raise TypeError (
148- f"Expected int/float for BaseIntegerFragment, got { type (value ).__name__ } : { value !r} "
149- )
150- return float (value )
159+ def encode (self , value : object ) -> list [float ]:
160+ assert isinstance (value , int )
161+ return [float (value )]
151162
152163
153164class PowerOfTwoFragment (BaseIntegerFragment ):
@@ -180,7 +191,7 @@ def differential_mutation(self, a: object, b: object, c: object) -> int:
180191 return self .clamp (ai * 2 )
181192 return ai
182193
183- def encode_scalar (self , value : object ) -> float :
194+ def encode (self , value : object ) -> list [ float ] :
184195 """Encode power-of-2 values using log2 transformation."""
185196 import math
186197
@@ -192,7 +203,7 @@ def encode_scalar(self, value: object) -> float:
192203 raise ValueError (
193204 f"Expected positive value for PowerOfTwoFragment, got { value } "
194205 )
195- return math .log2 (float (value ))
206+ return [ math .log2 (float (value ))]
196207
197208
198209class IntegerFragment (BaseIntegerFragment ):
@@ -235,7 +246,10 @@ def differential_mutation(self, a: object, b: object, c: object) -> object:
235246 choices .remove (a )
236247 return random .choice (choices )
237248
238- def encode_scalar (self , value : object ) -> float :
249+ def dim (self ) -> int :
250+ return len (self .choices )
251+
252+ def encode (self , value : object ) -> list [float ]:
239253 """Encode enum values as their index."""
240254 try :
241255 choice_idx = self .choices .index (value )
@@ -244,7 +258,7 @@ def encode_scalar(self, value: object) -> float:
244258 f"Invalid enum value { value !r} for EnumFragment. "
245259 f"Valid choices: { self .choices } "
246260 ) from None
247- return float ( choice_idx )
261+ return [ 1.0 if i == choice_idx else 0.0 for i in range ( len ( self . choices ))]
248262
249263
250264class BooleanFragment (ConfigSpecFragment ):
@@ -265,6 +279,14 @@ def differential_mutation(self, a: object, b: object, c: object) -> bool:
265279 return a
266280 return not a
267281
282+ def dim (self ) -> int :
283+ return 1
284+
285+ def encode (self , value : object ) -> list [float ]:
286+ """Encode enum values as their index."""
287+ assert isinstance (value , bool )
288+ return [1.0 ] if value else [0.0 ]
289+
268290
269291class BlockSizeFragment (PowerOfTwoFragment ):
270292 def category (self ) -> Category :
@@ -320,3 +342,13 @@ def differential_mutation(self, a: object, b: object, c: object) -> list[object]
320342 self .inner .differential_mutation (a [i ], b [i ], c [i ])
321343 for i in range (self .length )
322344 ]
345+
346+ def dim (self ) -> int :
347+ return self .length * self .inner .dim ()
348+
349+ def encode (self , value : object ) -> list [float ]:
350+ assert isinstance (value , list )
351+ encoded = []
352+ for v in value :
353+ encoded .extend (self .inner .encode (v ))
354+ return encoded
0 commit comments