11"""Test cases for GroupBy.plot"""
22
3+ import matplotlib .pyplot as plt
34import numpy as np
45import pytest
56
@@ -156,10 +157,6 @@ def test_groupby_hist_series_with_legend_raises(self):
156157 def test_groupby_scatter_colors_differ (self ):
157158 # GH 59846 - Test that scatter plots use different colors for different groups
158159 # similar to how line plots do
159- from matplotlib .collections import PathCollection
160- import matplotlib .pyplot as plt
161-
162- # Create test data with distinct groups
163160 df = DataFrame (
164161 {
165162 "x" : [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ],
@@ -168,36 +165,20 @@ def test_groupby_scatter_colors_differ(self):
168165 }
169166 )
170167
171- # Set up a figure with both line and scatter plots
172168 fig , (ax1 , ax2 ) = plt .subplots (1 , 2 )
173-
174- # Plot line chart (known to use different colors for different groups)
175169 df .groupby ("group" ).plot (x = "x" , y = "y" , ax = ax1 , kind = "line" )
176-
177- # Plot scatter chart (should also use different colors for different groups)
178170 df .groupby ("group" ).plot (x = "x" , y = "y" , ax = ax2 , kind = "scatter" )
179171
180- # Get the colors used in the line plot and scatter plot
181172 line_colors = [line .get_color () for line in ax1 .get_lines ()]
173+ scatter_colors = [
174+ tuple (tuple (fc ) for fc in scatter .get_facecolor ())
175+ for scatter in ax2 .collections
176+ ]
182177
183- # Get scatter colors
184- scatter_colors = []
185- for collection in ax2 .collections :
186- if isinstance (collection , PathCollection ): # This is a scatter plot
187- # Get the face colors (might be array of RGBA values)
188- face_colors = collection .get_facecolor ()
189- # If multiple points with same color, we get the first one
190- if face_colors .ndim > 1 :
191- scatter_colors .append (tuple (face_colors [0 ]))
192- else :
193- scatter_colors .append (tuple (face_colors ))
194-
195- # Assert that we have the right number of colors (one per group)
196178 assert len (line_colors ) == 3
197179 assert len (scatter_colors ) == 3
198180
199- # Assert that the colors are all different
181+ assert len ( set ( line_colors )) == 3
200182 assert len (set (scatter_colors )) == 3
201- assert len (line_colors ) == 3
202183
203184 plt .close (fig )
0 commit comments