@@ -315,6 +315,8 @@ class BaseStringArray(ExtensionArray):
315315 Mixin class for StringArray, ArrowStringArray.
316316 """
317317
318+ dtype : StringDtype
319+
318320 @doc (ExtensionArray .tolist )
319321 def tolist (self ):
320322 if self .ndim > 1 :
@@ -328,6 +330,37 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
328330 raise ValueError
329331 return cls ._from_sequence (scalars , dtype = dtype )
330332
333+ def _str_map_str_or_object (
334+ self ,
335+ dtype ,
336+ na_value ,
337+ arr : np .ndarray ,
338+ f ,
339+ mask : npt .NDArray [np .bool_ ],
340+ convert : bool ,
341+ ):
342+ # _str_map helper for case where dtype is either string dtype or object
343+ if is_string_dtype (dtype ) and not is_object_dtype (dtype ):
344+ # i.e. StringDtype
345+ result = lib .map_infer_mask (
346+ arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
347+ )
348+ if self .dtype .storage == "pyarrow" :
349+ import pyarrow as pa
350+
351+ result = pa .array (
352+ result , mask = mask , type = pa .large_string (), from_pandas = True
353+ )
354+ # error: Too many arguments for "BaseStringArray"
355+ return type (self )(result ) # type: ignore[call-arg]
356+
357+ else :
358+ # This is when the result type is object. We reach this when
359+ # -> We know the result type is truly object (e.g. .encode returns bytes
360+ # or .findall returns a list).
361+ # -> We don't know the result type. E.g. `.get` can return anything.
362+ return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
363+
331364
332365# error: Definition of "_concat_same_type" in base class "NDArrayBacked" is
333366# incompatible with definition in base class "ExtensionArray"
@@ -682,9 +715,53 @@ def _cmp_method(self, other, op):
682715 # base class "NumpyExtensionArray" defined the type as "float")
683716 _str_na_value = libmissing .NA # type: ignore[assignment]
684717
718+ def _str_map_nan_semantics (
719+ self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
720+ ):
721+ if dtype is None :
722+ dtype = self .dtype
723+ if na_value is None :
724+ na_value = self .dtype .na_value
725+
726+ mask = isna (self )
727+ arr = np .asarray (self )
728+ convert = convert and not np .all (mask )
729+
730+ if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
731+ na_value_is_na = isna (na_value )
732+ if na_value_is_na :
733+ if is_integer_dtype (dtype ):
734+ na_value = 0
735+ else :
736+ na_value = True
737+
738+ result = lib .map_infer_mask (
739+ arr ,
740+ f ,
741+ mask .view ("uint8" ),
742+ convert = False ,
743+ na_value = na_value ,
744+ dtype = np .dtype (cast (type , dtype )),
745+ )
746+ if na_value_is_na and mask .any ():
747+ if is_integer_dtype (dtype ):
748+ result = result .astype ("float64" )
749+ else :
750+ result = result .astype ("object" )
751+ result [mask ] = np .nan
752+ return result
753+
754+ else :
755+ return self ._str_map_str_or_object (dtype , na_value , arr , f , mask , convert )
756+
685757 def _str_map (
686758 self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
687759 ):
760+ if self .dtype .na_value is np .nan :
761+ return self ._str_map_nan_semantics (
762+ f , na_value = na_value , dtype = dtype , convert = convert
763+ )
764+
688765 from pandas .arrays import BooleanArray
689766
690767 if dtype is None :
@@ -724,18 +801,8 @@ def _str_map(
724801
725802 return constructor (result , mask )
726803
727- elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
728- # i.e. StringDtype
729- result = lib .map_infer_mask (
730- arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
731- )
732- return StringArray (result )
733804 else :
734- # This is when the result type is object. We reach this when
735- # -> We know the result type is truly object (e.g. .encode returns bytes
736- # or .findall returns a list).
737- # -> We don't know the result type. E.g. `.get` can return anything.
738- return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
805+ return self ._str_map_str_or_object (dtype , na_value , arr , f , mask , convert )
739806
740807
741808class StringArrayNumpySemantics (StringArray ):
@@ -802,52 +869,3 @@ def value_counts(self, dropna: bool = True) -> Series:
802869 # ------------------------------------------------------------------------
803870 # String methods interface
804871 _str_na_value = np .nan
805-
806- def _str_map (
807- self , f , na_value = None , dtype : Dtype | None = None , convert : bool = True
808- ):
809- if dtype is None :
810- dtype = self .dtype
811- if na_value is None :
812- na_value = self .dtype .na_value
813-
814- mask = isna (self )
815- arr = np .asarray (self )
816- convert = convert and not np .all (mask )
817-
818- if is_integer_dtype (dtype ) or is_bool_dtype (dtype ):
819- na_value_is_na = isna (na_value )
820- if na_value_is_na :
821- if is_integer_dtype (dtype ):
822- na_value = 0
823- else :
824- na_value = True
825-
826- result = lib .map_infer_mask (
827- arr ,
828- f ,
829- mask .view ("uint8" ),
830- convert = False ,
831- na_value = na_value ,
832- dtype = np .dtype (cast (type , dtype )),
833- )
834- if na_value_is_na and mask .any ():
835- if is_integer_dtype (dtype ):
836- result = result .astype ("float64" )
837- else :
838- result = result .astype ("object" )
839- result [mask ] = np .nan
840- return result
841-
842- elif is_string_dtype (dtype ) and not is_object_dtype (dtype ):
843- # i.e. StringDtype
844- result = lib .map_infer_mask (
845- arr , f , mask .view ("uint8" ), convert = False , na_value = na_value
846- )
847- return type (self )(result )
848- else :
849- # This is when the result type is object. We reach this when
850- # -> We know the result type is truly object (e.g. .encode returns bytes
851- # or .findall returns a list).
852- # -> We don't know the result type. E.g. `.get` can return anything.
853- return lib .map_infer_mask (arr , f , mask .view ("uint8" ))
0 commit comments