33from typing import TYPE_CHECKING , Any , Dict , Optional , Sequence , Tuple , Union
44
55import numpy as np
6+ from arro3 .compute import dictionary_encode
7+ from arro3 .core import (
8+ Array ,
9+ ChunkedArray ,
10+ DataType ,
11+ dictionary_dictionary ,
12+ dictionary_indices ,
13+ )
14+ from arro3 .core .types import ArrowArrayExportable , ArrowStreamExportable
615
716if TYPE_CHECKING :
817 import matplotlib as mpl
@@ -130,7 +139,14 @@ def apply_continuous_cmap(
130139
131140
132141def apply_categorical_cmap (
133- values : Union [NDArray , pd .Series , pa .Array , pa .ChunkedArray ],
142+ values : Union [
143+ NDArray ,
144+ pd .Series ,
145+ pa .Array ,
146+ pa .ChunkedArray ,
147+ ArrowArrayExportable ,
148+ ArrowStreamExportable ,
149+ ],
134150 cmap : DiscreteColormap ,
135151 * ,
136152 alpha : Optional [int ] = None ,
@@ -163,30 +179,27 @@ def apply_categorical_cmap(
163179 dimension will have a length of either `3` if `alpha` is `None`, or `4` if
164180 each color has an alpha value.
165181 """
182+ if isinstance (values , np .ndarray ):
183+ values = Array .from_numpy (values )
184+
166185 try :
167- import pyarrow as pa
168- import pyarrow .compute as pc
169- except ImportError as e :
170- raise ImportError (
171- "pyarrow required for apply_categorical_cmap.\n "
172- "Run `pip install pyarrow`."
173- ) from e
174-
175- # Import from PyCapsule interface
176- if hasattr (values , "__arrow_c_array__" ):
177- values = pa .array (values )
178- elif hasattr (values , "__arrow_c_stream__" ):
179- values = pa .chunked_array (values )
180-
181- # Construct from non-arrow data
182- if not isinstance (values , (pa .Array , pa .ChunkedArray )):
183- values = pa .array (values )
184-
185- if not pa .types .is_dictionary (values .type ):
186- values = pc .dictionary_encode (values )
186+ import pandas as pd
187+
188+ if isinstance (values , pd .Series ):
189+ values = Array .from_numpy (values )
190+ except ImportError :
191+ pass
192+
193+ values = ChunkedArray (values )
194+
195+ if not DataType .is_dictionary (values .type ):
196+ values = ChunkedArray (dictionary_encode (values ))
197+
198+ dictionary = ChunkedArray (dictionary_dictionary (values ))
199+ indices = ChunkedArray (dictionary_indices (values ))
187200
188201 # Build lookup table
189- lut = np .zeros ((len (values . dictionary ), 4 ), dtype = np .uint8 )
202+ lut = np .zeros ((len (dictionary ), 4 ), dtype = np .uint8 )
190203 if alpha is not None :
191204 assert isinstance (alpha , int ), "alpha must be an integer"
192205 assert 0 <= alpha <= 255 , "alpha must be between 0-255 (inclusive)."
@@ -195,7 +208,7 @@ def apply_categorical_cmap(
195208 else :
196209 lut [:, 3 ] = 255
197210
198- for i , key in enumerate (values . dictionary ):
211+ for i , key in enumerate (dictionary ):
199212 color = cmap [key .as_py ()]
200213 if len (color ) == 3 :
201214 lut [i , :3 ] = color
@@ -206,7 +219,7 @@ def apply_categorical_cmap(
206219 "Expected color to be 3 or 4 values representing RGB or RGBA."
207220 )
208221
209- colors = lut [values . indices ]
222+ colors = lut [indices ]
210223
211224 # If the alpha values are all 255, don't serialize
212225 if (colors [:, 3 ] == 255 ).all ():
0 commit comments