88import numpy as np
99import pytest
1010
11- from pandas ._config import using_string_dtype
12-
1311from pandas import (
1412 Categorical ,
1513 DataFrame ,
@@ -106,10 +104,9 @@ def _call_and_check(klass, msg, how, gb, groupby_func, args, warn_msg=""):
106104 gb .transform (groupby_func , * args )
107105
108106
109- @pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" , strict = False )
110107@pytest .mark .parametrize ("how" , ["method" , "agg" , "transform" ])
111108def test_groupby_raises_string (
112- how , by , groupby_series , groupby_func , df_with_string_col
109+ how , by , groupby_series , groupby_func , df_with_string_col , using_infer_string
113110):
114111 df = df_with_string_col
115112 args = get_groupby_method_args (groupby_func , df )
@@ -183,6 +180,44 @@ def test_groupby_raises_string(
183180 ),
184181 }[groupby_func ]
185182
183+ if using_infer_string :
184+ if klass is not None :
185+ if re .escape ("agg function failed" ) in msg :
186+ msg = msg .replace ("object" , "string" )
187+ elif groupby_func in [
188+ "cumsum" ,
189+ "cumprod" ,
190+ "cummin" ,
191+ "cummax" ,
192+ "std" ,
193+ "sem" ,
194+ "skew" ,
195+ ]:
196+ msg = msg .replace ("object" , "string" )
197+ elif groupby_func == "quantile" :
198+ msg = "No matching signature found"
199+ elif groupby_func == "corrwith" :
200+ msg = (
201+ "'ArrowStringArrayNumpySemantics' with dtype string does "
202+ "not support operation 'mean'"
203+ )
204+ else :
205+ import pyarrow as pa
206+
207+ klass = pa .lib .ArrowNotImplementedError
208+ if groupby_func == "pct_change" :
209+ msg = "Function 'divide' has no kernel matching input types"
210+ elif groupby_func == "diff" :
211+ msg = (
212+ "Function 'subtract_checked' has no kernel matching "
213+ "input types"
214+ )
215+ else :
216+ msg = (
217+ f"Function '{ groupby_func } ' has no kernel matching "
218+ "input types"
219+ )
220+
186221 if groupby_func == "fillna" :
187222 kind = "Series" if groupby_series else "DataFrame"
188223 warn_msg = f"{ kind } GroupBy.fillna is deprecated"
@@ -208,11 +243,15 @@ def func(x):
208243 getattr (gb , how )(func )
209244
210245
211- @pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" )
212246@pytest .mark .parametrize ("how" , ["agg" , "transform" ])
213247@pytest .mark .parametrize ("groupby_func_np" , [np .sum , np .mean ])
214248def test_groupby_raises_string_np (
215- how , by , groupby_series , groupby_func_np , df_with_string_col
249+ how ,
250+ by ,
251+ groupby_series ,
252+ groupby_func_np ,
253+ df_with_string_col ,
254+ using_infer_string ,
216255):
217256 # GH#50749
218257 df = df_with_string_col
@@ -228,6 +267,15 @@ def test_groupby_raises_string_np(
228267 "Could not convert string .* to numeric" ,
229268 ),
230269 }[groupby_func_np ]
270+
271+ if using_infer_string :
272+ # TODO: should ArrowStringArrayNumpySemantics support sum?
273+ klass = TypeError
274+ msg = (
275+ "'ArrowStringArrayNumpySemantics' with dtype string does not "
276+ f"support operation '{ groupby_func_np .__name__ } '"
277+ )
278+
231279 _call_and_check (klass , msg , how , gb , groupby_func_np , ())
232280
233281
0 commit comments