Skip to content

Commit 07187a0

Browse files
authored
Use arro3 for dictionary encoding in apply_categorical_cmap (#601)
1 parent 213f26b commit 07187a0

File tree

4 files changed

+91
-80
lines changed

4 files changed

+91
-80
lines changed

lonboard/colormap.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple, Union
44

55
import numpy as np
6+
from arro3.compute import dictionary_encode
7+
from arro3.core import (
8+
Array,
9+
ChunkedArray,
10+
DataType,
11+
dictionary_dictionary,
12+
dictionary_indices,
13+
)
14+
from arro3.core.types import ArrowArrayExportable, ArrowStreamExportable
615

716
if TYPE_CHECKING:
817
import matplotlib as mpl
@@ -130,7 +139,14 @@ def apply_continuous_cmap(
130139

131140

132141
def apply_categorical_cmap(
133-
values: Union[NDArray, pd.Series, pa.Array, pa.ChunkedArray],
142+
values: Union[
143+
NDArray,
144+
pd.Series,
145+
pa.Array,
146+
pa.ChunkedArray,
147+
ArrowArrayExportable,
148+
ArrowStreamExportable,
149+
],
134150
cmap: DiscreteColormap,
135151
*,
136152
alpha: Optional[int] = None,
@@ -163,30 +179,27 @@ def apply_categorical_cmap(
163179
dimension will have a length of either `3` if `alpha` is `None`, or `4` if
164180
each color has an alpha value.
165181
"""
182+
if isinstance(values, np.ndarray):
183+
values = Array.from_numpy(values)
184+
166185
try:
167-
import pyarrow as pa
168-
import pyarrow.compute as pc
169-
except ImportError as e:
170-
raise ImportError(
171-
"pyarrow required for apply_categorical_cmap.\n"
172-
"Run `pip install pyarrow`."
173-
) from e
174-
175-
# Import from PyCapsule interface
176-
if hasattr(values, "__arrow_c_array__"):
177-
values = pa.array(values)
178-
elif hasattr(values, "__arrow_c_stream__"):
179-
values = pa.chunked_array(values)
180-
181-
# Construct from non-arrow data
182-
if not isinstance(values, (pa.Array, pa.ChunkedArray)):
183-
values = pa.array(values)
184-
185-
if not pa.types.is_dictionary(values.type):
186-
values = pc.dictionary_encode(values)
186+
import pandas as pd
187+
188+
if isinstance(values, pd.Series):
189+
values = Array.from_numpy(values)
190+
except ImportError:
191+
pass
192+
193+
values = ChunkedArray(values)
194+
195+
if not DataType.is_dictionary(values.type):
196+
values = ChunkedArray(dictionary_encode(values))
197+
198+
dictionary = ChunkedArray(dictionary_dictionary(values))
199+
indices = ChunkedArray(dictionary_indices(values))
187200

188201
# Build lookup table
189-
lut = np.zeros((len(values.dictionary), 4), dtype=np.uint8)
202+
lut = np.zeros((len(dictionary), 4), dtype=np.uint8)
190203
if alpha is not None:
191204
assert isinstance(alpha, int), "alpha must be an integer"
192205
assert 0 <= alpha <= 255, "alpha must be between 0-255 (inclusive)."
@@ -195,7 +208,7 @@ def apply_categorical_cmap(
195208
else:
196209
lut[:, 3] = 255
197210

198-
for i, key in enumerate(values.dictionary):
211+
for i, key in enumerate(dictionary):
199212
color = cmap[key.as_py()]
200213
if len(color) == 3:
201214
lut[i, :3] = color
@@ -206,7 +219,7 @@ def apply_categorical_cmap(
206219
"Expected color to be 3 or 4 values representing RGB or RGBA."
207220
)
208221

209-
colors = lut[values.indices]
222+
colors = lut[indices]
210223

211224
# If the alpha values are all 255, don't serialize
212225
if (colors[:, 3] == 255).all():

poetry.lock

Lines changed: 46 additions & 46 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ classifiers = [
3232
[tool.poetry.dependencies]
3333
python = "^3.8"
3434
anywidget = "^0.9.0"
35-
arro3-core = "^0.3.0-beta.1"
36-
arro3-io = "^0.3.0-beta.1"
37-
arro3-compute = "^0.3.0-beta.1"
35+
arro3-core = "^0.3.0-beta.2"
36+
arro3-io = "^0.3.0-beta.2"
37+
arro3-compute = "^0.3.0-beta.2"
3838
ipywidgets = ">=7.6.0"
3939
numpy = ">=1.14"
4040
# The same version pin as geopandas

tests/test_colormap.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
1-
import pytest
1+
from arro3.core import Array, DataType
22

33
from lonboard.colormap import apply_categorical_cmap
44

55

66
def test_discrete_cmap():
7-
pd = pytest.importorskip("pandas")
8-
9-
values = ["red", "green", "blue", "blue", "red"]
10-
df = pd.DataFrame({"val": values})
7+
str_values = ["red", "green", "blue", "blue", "red"]
8+
values = Array(str_values, type=DataType.string())
119
cmap = {
1210
"red": [255, 0, 0],
1311
"green": [0, 255, 0],
1412
"blue": [0, 0, 255],
1513
}
16-
colors = apply_categorical_cmap(df["val"], cmap)
14+
colors = apply_categorical_cmap(values, cmap)
1715

18-
for i, val in enumerate(values):
16+
for i, val in enumerate(str_values):
1917
assert list(colors[i]) == cmap[val]

0 commit comments

Comments
 (0)