diff --git a/xdggs/plotting.py b/xdggs/plotting.py index 870c914..1f11352 100644 --- a/xdggs/plotting.py +++ b/xdggs/plotting.py @@ -7,7 +7,16 @@ import ipywidgets import numpy as np import xarray as xr -from lonboard import BaseLayer, Map +from lonboard import BaseLayer +from lonboard import Map as LonboardMap + + +@dataclass +class Container: + obj: xr.DataArray + colorize_kwargs: dict[str, Any] + layer: BaseLayer + dimension_sliders: list[ipywidgets.IntSlider] def on_slider_change(change, container): @@ -16,89 +25,55 @@ def on_slider_change(change, container): indexers = { slider.description: slider.value - for slider in container.dimension_sliders.children + for slider in container.dimension_sliders if slider.description != dim } | {dim: change["new"]} - new_slice = container.obj.isel(indexers) + new_slice = container.obj.isel(indexers) colors = colorize(new_slice.variable, **container.colorize_kwargs) - layer = container.map.layers[0] + layer = container.layer layer.get_fill_color = colors -@dataclass -class MapContainer: - """container for the map, any control widgets and the data object""" - - dimension_sliders: ipywidgets.VBox - map: Map - obj: xr.DataArray - - colorize_kwargs: dict[str, Any] +def render_map( + map_: Map, dimension_sliders: list[ipywidgets.IntSlider] +) -> Map | MapWithSliders: + if not dimension_sliders: + return map_ - def render(self): - # add any additional control widgets here - control_box = ipywidgets.HBox([self.dimension_sliders]) + slider_box = ipywidgets.VBox(dimension_sliders) + control_box = ipywidgets.HBox([slider_box]) - return MapWithSliders( - [self.map, control_box], layout=ipywidgets.Layout(width="100%") - ) + return MapWithSliders([map_, control_box], layout=ipywidgets.Layout(width="100%")) -def extract_maps(obj: MapGrid | MapWithSliders | Map): - if isinstance(obj, Map): - return obj +def extract_maps(obj: MapGrid | MapWithSliders | Map | LonboardMap): + if isinstance(obj, (Map, LonboardMap)): + return [obj] return getattr(obj, "maps", (obj.map,)) -class MapGrid(ipywidgets.GridBox): - def __init__( - self, - maps: MapWithSliders | Map = None, - n_columns: int = 2, - synchronize: bool = False, - ): - self.n_columns = n_columns - self.synchronize = synchronize +class Map(LonboardMap): + def __or__(self, other: Map | MapWithSliders): + if isinstance(other, MapGrid): + return NotImplemented - column_width = 100 // n_columns - layout = ipywidgets.Layout( - width="100%", grid_template_columns=f"repeat({n_columns}, {column_width}%)" - ) + return MapGrid([self, other]) - if maps is None: - maps = [] - - if synchronize and maps: - all_maps = [getattr(m, "map", m) for m in maps] - - first = all_maps[0] - for second in all_maps[1:]: - ipywidgets.jslink((first, "view_state"), (second, "view_state")) - - super().__init__(maps, layout=layout) - - def _replace_maps(self, maps): - return type(self)(maps, n_columns=self.n_columns, synchronize=self.synchronize) - - def add_map(self, map_: MapWithSliders | Map): - return self._replace_maps(self.maps + (map_,)) + def __and__(self, other): + if isinstance(other, (MapWithSliders, MapGrid)): + return NotImplemented - @property - def maps(self): - return self.children - - def __or__(self, other: MapGrid | MapWithSliders | Map): - other_maps = extract_maps(other) - - return self._replace_maps(self.maps + other_maps) + if isinstance(other, BaseLayer): + other_layers = [other] + else: + other_layers = list(other.layers) - def __ror__(self, other: MapWithSliders | Map): - other_maps = extract_maps(other) + layers = list(self.layers) + list(other_layers) - return self._replace_maps(self.maps + other_maps) + return type(self)(layers) class MapWithSliders(ipywidgets.VBox): @@ -122,6 +97,11 @@ def __or__(self, other: MapWithSliders | Map): return MapGrid([self, other], synchronize=True) + def __ror__(self, other: Map): + [other_map] = extract_maps(other) + + return MapGrid([other_map, self], synchronize=True) + def _merge(self, layers, sliders): all_layers = list(self.map.layers) + list(layers) new_map = Map(all_layers) @@ -142,6 +122,9 @@ def add_layer(self, layer: BaseLayer): self.map.add_layer(layer) def __and__(self, other: MapWithSliders | Map | BaseLayer): + if isinstance(other, MapGrid): + return NotImplemented + if isinstance(other, BaseLayer): layers = [other] sliders = [] @@ -151,6 +134,63 @@ def __and__(self, other: MapWithSliders | Map | BaseLayer): return self._merge(layers, sliders) + def __rand__(self, other: Map | BaseLayer): + return self & other + + +class MapGrid(ipywidgets.GridBox): + def __init__( + self, + maps: MapWithSliders | Map = None, + n_columns: int = 2, + synchronize: bool = False, + ): + self.n_columns = n_columns + self.synchronize = synchronize + + column_width = 100 // n_columns + layout = ipywidgets.Layout( + width="100%", grid_template_columns=f"repeat({n_columns}, {column_width}%)" + ) + + if maps is None: + maps = [] + + super().__init__(maps, layout=layout) + + if synchronize and maps: + self.synchronize_maps() + + def _replace_maps(self, maps): + return type(self)(maps, n_columns=self.n_columns, synchronize=self.synchronize) + + def add_map(self, map_: MapWithSliders | Map): + return self._replace_maps(self.maps + (map_,)) + + @property + def maps(self): + return self.children + + def synchronize_maps(self): + if not self.maps: + raise ValueError("no maps to synchronize found") + + all_maps = [getattr(m, "map", m) for m in self.maps] + + first = all_maps[0] + for second in all_maps[1:]: + ipywidgets.jslink((first, "view_state"), (second, "view_state")) + + def __or__(self, other: MapGrid | MapWithSliders | Map): + other_maps = extract_maps(other) + + return self._replace_maps(self.maps + other_maps) + + def __ror__(self, other: MapWithSliders | Map): + other_maps = extract_maps(other) + + return self._replace_maps(self.maps + other_maps) + def create_arrow_table(polygons, arr, coords=None): from arro3.core import Array, ChunkedArray, Schema, Table @@ -205,7 +245,6 @@ def explore( alpha=None, coords=None, ): - import lonboard from lonboard import SolidPolygonLayer from matplotlib import colormaps @@ -227,29 +266,29 @@ def explore( table = create_arrow_table(polygons, initial_arr, coords=coords) layer = SolidPolygonLayer(table=table, filled=True, get_fill_color=colors) - map_ = lonboard.Map(layer) + map_ = LonboardMap(layer) - if not initial_indexers: - # 1D data - return map_ + sliders = [ + ipywidgets.IntSlider(min=0, max=arr.sizes[dim] - 1, description=dim) + for dim in arr.dims + if dim != cell_dim + ] - sliders = ipywidgets.VBox( - [ - ipywidgets.IntSlider(min=0, max=arr.sizes[dim] - 1, description=dim) - for dim in arr.dims - if dim != cell_dim - ] - ) + map_object = render_map(map_, sliders) - container = MapContainer( - sliders, - map_, + container = Container( arr, - colorize_kwargs={"alpha": alpha, "center": center, "colormap": colormap}, + colorize_kwargs={ + "alpha": alpha, + "center": center, + "colormap": colormap, + }, + layer=layer, + dimension_sliders=sliders, ) # connect slider with map - for slider in sliders.children: + for slider in sliders: slider.observe(partial(on_slider_change, container=container), names="value") - return container.render() + return map_object