44from enum import IntEnum
55from functools import cached_property
66
7- import numpy as np
8- from numpy_groupies import aggregate_nb as aggregate
7+ from array_api_compat import is_numpy_array , is_numpy_namespace , is_torch_array
98
10- from ribs ._utils import readonly
9+ from ribs ._utils import arr_readonly , xp_namespace
1110from ribs .archives ._archive_data_frame import ArchiveDataFrame
1211
1312
@@ -36,7 +35,7 @@ def __next__(self):
3635
3736 Raises RuntimeError if the store was modified.
3837 """
39- if not np . all ( self .state == self .store ._props ["updates" ]) :
38+ if self .state != self .store ._props ["updates" ]:
4039 # This check should go before the StopIteration check because a call
4140 # to clear() would cause the len(self.store) to be 0 and thus
4241 # trigger StopIteration.
@@ -61,8 +60,8 @@ class ArrayStore:
6160 """Maintains a set of arrays that share a common dimension.
6261
6362 The ArrayStore consists of several *fields* of data that are manipulated
64- simultaneously via batch operations. Each field is a NumPy array with a
65- dimension of ``(capacity, ...)`` and can be of any type.
63+ simultaneously via batch operations. Each field is an array with a dimension
64+ of ``(capacity, ...)`` and can be of any type.
6665
6766 Since the arrays all share a common first dimension, they also share a
6867 common index. For instance, if we :meth:`retrieve` the data at indices ``[0,
@@ -77,6 +76,12 @@ class ArrayStore:
7776 The ArrayStore supports several further operations, such as an :meth:`add`
7877 method that inserts data into the ArrayStore.
7978
79+ By default, the arrays in the ArrayStore are NumPy arrays. However, through
80+ support for the `Python array API standard
81+ <https://data-apis.org/array-api/latest/>`_, it is possible to use arrays
82+ from other libraries like PyTorch by passing in arguments for ``xp`` and
83+ ``device``.
84+
8085 Args:
8186 field_desc (dict): Description of fields in the array store. The
8287 description is a dict mapping from a str to a tuple of ``(shape,
@@ -86,6 +91,10 @@ class ArrayStore:
8691 ``(capacity, 10)``. Note that field names must be valid Python
8792 identifiers.
8893 capacity (int): Total possible entries in the store.
94+ xp (array_namespace): Optional array namespace. Should be compatible
95+ with the array API standard, or supported by array-api-compat.
96+ Defaults to ``numpy``.
97+ device (device): Device for arrays.
8998
9099 Attributes:
91100 _props (dict): Properties that are common to every ArrayStore.
@@ -97,7 +106,7 @@ class ArrayStore:
97106 * "occupied_list": Array of size ``(capacity,)`` listing all
98107 occupied indices in the store. Only the first ``n_occupied``
99108 elements will be valid.
100- * "updates": Int array recording number of calls to functions that
109+ * "updates": Int list recording number of calls to functions that
101110 modified the store.
102111
103112 _fields (dict): Holds all the arrays with their data.
@@ -109,13 +118,22 @@ class ArrayStore:
109118 valid Python identifier.
110119 """
111120
112- def __init__ (self , field_desc , capacity ):
121+ def __init__ (self , field_desc , capacity , xp = None , device = None ):
122+ self ._xp = xp_namespace (xp )
123+ self ._device = device
124+
113125 self ._props = {
114- "capacity" : capacity ,
115- "occupied" : np .zeros (capacity , dtype = bool ),
116- "n_occupied" : 0 ,
117- "occupied_list" : np .empty (capacity , dtype = np .int32 ),
118- "updates" : np .array ([0 , 0 ]),
126+ "capacity" :
127+ capacity ,
128+ "occupied" :
129+ self ._xp .zeros (capacity , dtype = bool , device = self ._device ),
130+ "n_occupied" :
131+ 0 ,
132+ "occupied_list" :
133+ self ._xp .empty (capacity ,
134+ dtype = self ._xp .int32 ,
135+ device = self ._device ),
136+ "updates" : [0 , 0 ],
119137 }
120138
121139 self ._fields = {}
@@ -130,7 +148,9 @@ def __init__(self, field_desc, capacity):
130148 field_shape = (field_shape ,)
131149
132150 array_shape = (capacity ,) + tuple (field_shape )
133- self ._fields [name ] = np .empty (array_shape , dtype )
151+ self ._fields [name ] = self ._xp .empty (array_shape ,
152+ dtype = dtype ,
153+ device = self ._device )
134154
135155 def __len__ (self ):
136156 """Number of occupied indices in the store, i.e., number of indices that
@@ -163,15 +183,14 @@ def capacity(self):
163183
164184 @property
165185 def occupied (self ):
166- """numpy.ndarray : Boolean array of size ``(capacity,)`` indicating
167- whether each index has a data entry."""
168- return readonly (self ._props ["occupied" ]. view () )
186+ """array : Boolean array of size ``(capacity,)`` indicating whether each
187+ index has a data entry."""
188+ return arr_readonly (self ._props ["occupied" ])
169189
170190 @property
171191 def occupied_list (self ):
172- """numpy.ndarray: int32 array listing all occupied indices in the
173- store."""
174- return readonly (
192+ """array: int32 array listing all occupied indices in the store."""
193+ return arr_readonly (
175194 self ._props ["occupied_list" ][:self ._props ["n_occupied" ]])
176195
177196 @cached_property
@@ -211,10 +230,14 @@ def dtypes(self):
211230 "measures": np.float32,
212231 }
213232 """
214- # Calling `.type` retrieves the numpy scalar type, which is callable:
215- # - https://numpy.org/doc/stable/reference/arrays.scalars.html
216- # - https://numpy.org/doc/stable/reference/arrays.dtypes.html
217- return {name : arr .dtype .type for name , arr in self ._fields .items ()}
233+ if is_numpy_namespace (self ._xp ):
234+ # TODO (#577): In NumPy, we currently want the scalar type (i.e.,
235+ # arr.dtype.type rather than arr.dtype), which is callable.
236+ # Ultimately, this should be switched to just be the dtype to be
237+ # compatible across array libraries.
238+ return {name : arr .dtype .type for name , arr in self ._fields .items ()}
239+ else :
240+ return {name : arr .dtype for name , arr in self ._fields .items ()}
218241
219242 @cached_property
220243 def dtypes_with_index (self ):
@@ -230,7 +253,7 @@ def dtypes_with_index(self):
230253 "index": np.int32,
231254 }
232255 """
233- return self .dtypes | {"index" : np .int32 }
256+ return self .dtypes | {"index" : self . _xp .int32 }
234257
235258 @cached_property
236259 def field_list (self ):
@@ -261,15 +284,29 @@ def field_list_with_index(self):
261284 """
262285 return list (self ._fields ) + ["index" ]
263286
287+ @staticmethod
288+ def _convert_to_numpy (arr ):
289+ """If needed, converts the given array to a numpy array for the pandas
290+ return type in `retrieve`."""
291+ if is_numpy_array (arr ):
292+ return arr
293+ elif is_torch_array (arr ):
294+ return arr .cpu ().detach ().numpy ()
295+ else :
296+ raise NotImplementedError (
297+ "The pandas return type is currently only supported "
298+ "with numpy and torch arrays." )
299+
264300 def retrieve (self , indices , fields = None , return_type = "dict" ):
265301 """Collects data at the given indices.
266302
267303 Args:
268304 indices (array-like): List of indices at which to collect data.
269305 fields (str or array-like of str): List of fields to include. By
270306 default, all fields will be included, with an additional "index"
271- as the last field ("index" can also be placed anywhere in this
272- list). This can also be a single str indicating a field name.
307+ as the last field. The "index" field can also be added anywhere
308+ in this list of fields. This argument can also be a single str
309+ indicating a field name.
273310 return_type (str): Type of data to return. See the ``data`` returned
274311 below. Ignored if ``fields`` is a str.
275312
@@ -346,6 +383,10 @@ def retrieve(self, indices, fields=None, return_type="dict"):
346383 Like the other return types, the columns can be adjusted with
347384 the ``fields`` parameter.
348385
386+ .. note:: This return type will require copying all fields in
387+ the ArrayStore into NumPy arrays, if they are not already
388+ NumPy arrays.
389+
349390 All data returned by this method will be a copy, i.e., the data will
350391 not update as the store changes.
351392
@@ -354,8 +395,12 @@ def retrieve(self, indices, fields=None, return_type="dict"):
354395 ValueError: Invalid return_type provided.
355396 """
356397 single_field = isinstance (fields , str )
357- indices = np .asarray (indices , dtype = np .int32 )
358- occupied = self ._props ["occupied" ][indices ] # Induces copy.
398+ indices = self ._xp .asarray (indices ,
399+ dtype = self ._xp .int32 ,
400+ device = self ._device )
401+
402+ # Induces copy (in numpy, at least).
403+ occupied = self ._props ["occupied" ][indices ]
359404
360405 if single_field :
361406 data = None
@@ -374,10 +419,10 @@ def retrieve(self, indices, fields=None, return_type="dict"):
374419 for name in fields :
375420 # Collect array data.
376421 #
377- # Note that fancy indexing with indices already creates a copy, so
378- # only `indices` needs to be copied explicitly.
422+ # Note that fancy indexing with indices already creates a copy (in
423+ # numpy, at least), so only `indices` needs to be copied explicitly.
379424 if name == "index" :
380- arr = np . copy (indices )
425+ arr = self . _xp . asarray (indices , copy = True )
381426 elif name in self ._fields :
382427 arr = self ._fields [name ][indices ] # Induces copy.
383428 else :
@@ -391,6 +436,8 @@ def retrieve(self, indices, fields=None, return_type="dict"):
391436 elif return_type == "tuple" :
392437 data .append (arr )
393438 elif return_type == "pandas" :
439+ arr = self ._convert_to_numpy (arr )
440+
394441 if len (arr .shape ) == 1 : # Scalar entries.
395442 data [name ] = arr
396443 elif len (arr .shape ) == 2 : # 1D array entries.
@@ -405,6 +452,8 @@ def retrieve(self, indices, fields=None, return_type="dict"):
405452 if return_type == "tuple" :
406453 data = tuple (data )
407454 elif return_type == "pandas" :
455+ occupied = self ._convert_to_numpy (occupied )
456+
408457 # Data above are already copied, so no need to copy again.
409458 data = ArchiveDataFrame (data , copy = False )
410459
@@ -471,8 +520,16 @@ def add(self, indices, data):
471520 "This can also occur if the archive and result_archive have "
472521 "different extra_fields." )
473522
523+ # Determine the unique indices. These operations are preferred over
524+ # `xp.unique_values(indices)` because they operate in linear time, while
525+ # unique_values usually sorts the input.
526+ indices_occupied = self ._xp .zeros (self .capacity ,
527+ dtype = bool ,
528+ device = self ._device )
529+ indices_occupied [indices ] = True
530+ unique_indices = self ._xp .nonzero (indices_occupied )[0 ]
531+
474532 # Update occupancy data.
475- unique_indices = np .where (aggregate (indices , 1 , func = "len" ) != 0 )[0 ]
476533 cur_occupied = self ._props ["occupied" ][unique_indices ]
477534 new_indices = unique_indices [~ cur_occupied ]
478535 n_occupied = self ._props ["n_occupied" ]
@@ -483,16 +540,18 @@ def add(self, indices, data):
483540
484541 # Insert into the ArrayStore. Note that we do not assume indices are
485542 # unique. Hence, when updating occupancy data above, we computed the
486- # unique indices. In contrast, here we let NumPy 's default behavior
543+ # unique indices. In contrast, here we let the array 's default behavior
487544 # handle duplicate indices.
488545 for name , arr in self ._fields .items ():
489- arr [indices ] = data [name ]
546+ arr [indices ] = self ._xp .asarray (data [name ],
547+ dtype = arr .dtype ,
548+ device = self ._device )
490549
491550 def clear (self ):
492551 """Removes all entries from the store."""
493552 self ._props ["updates" ][Update .CLEAR ] += 1
494553 self ._props ["n_occupied" ] = 0 # Effectively clears occupied_list too.
495- self ._props ["occupied" ]. fill ( False )
554+ self ._props ["occupied" ][:] = False
496555
497556 def resize (self , capacity ):
498557 """Resizes the store to the given capacity.
@@ -512,14 +571,20 @@ def resize(self, capacity):
512571 self ._props ["capacity" ] = capacity
513572
514573 cur_occupied = self ._props ["occupied" ]
515- self ._props ["occupied" ] = np .zeros (capacity , dtype = bool )
574+ self ._props ["occupied" ] = self ._xp .zeros (capacity ,
575+ dtype = bool ,
576+ device = self ._device )
516577 self ._props ["occupied" ][:cur_capacity ] = cur_occupied
517578
518579 cur_occupied_list = self ._props ["occupied_list" ]
519- self ._props ["occupied_list" ] = np .empty (capacity , dtype = np .int32 )
580+ self ._props ["occupied_list" ] = self ._xp .empty (capacity ,
581+ dtype = self ._xp .int32 ,
582+ device = self ._device )
520583 self ._props ["occupied_list" ][:cur_capacity ] = cur_occupied_list
521584
522585 for name , cur_arr in self ._fields .items ():
523586 new_shape = (capacity ,) + cur_arr .shape [1 :]
524- self ._fields [name ] = np .empty (new_shape , cur_arr .dtype )
587+ self ._fields [name ] = self ._xp .empty (new_shape ,
588+ dtype = cur_arr .dtype ,
589+ device = self ._device )
525590 self ._fields [name ][:cur_capacity ] = cur_arr
0 commit comments