@@ -157,7 +157,7 @@ def __init__(
157157
158158 for idx , rule in enumerate (self .split_rules ):
159159 if rule is ContinuousSplitRule :
160- self .X [:, idx ] = jitter_duplicated (self .X [:, idx ], np .std (self .X [:, idx ]))
160+ self .X [:, idx ] = jitter_duplicated (self .X [:, idx ], np .nanstd (self .X [:, idx ]))
161161
162162 init_mean = self .bart .Y .mean ()
163163 self .num_observations = self .X .shape [0 ]
@@ -700,7 +700,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
700700 if are_whole_number (array ):
701701 seen = []
702702 for idx , num in enumerate (array ):
703- if num in seen :
703+ if num in seen and not np . isnan ( num ) :
704704 array [idx ] = num + np .random .normal (0 , std / 12 )
705705 else :
706706 seen .append (num )
@@ -711,8 +711,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
711711@njit
712712def are_whole_number (array : npt .NDArray [np .float_ ]) -> np .bool_ :
713713 """Check if all values in array are whole numbers"""
714- new_array = np .mod (array , 1 )
715- return np .all (new_array == 0 )
714+ return np .all (np .mod (array [~ np .isnan (array )], 1 ) == 0 )
716715
717716
718717def logp (point , out_vars , vars , shared ): # pylint: disable=redefined-builtin
0 commit comments