@@ -22,7 +22,11 @@ use crate::aggregate::stats::StatsType;
2222use crate :: aggregate:: stddev:: StddevAccumulator ;
2323use crate :: expressions:: format_state_name;
2424use crate :: { AggregateExpr , PhysicalExpr } ;
25- use arrow:: { array:: ArrayRef , datatypes:: DataType , datatypes:: Field } ;
25+ use arrow:: {
26+ array:: ArrayRef ,
27+ compute:: { and, filter, is_not_null} ,
28+ datatypes:: { DataType , Field } ,
29+ } ;
2630use datafusion_common:: Result ;
2731use datafusion_common:: ScalarValue ;
2832use datafusion_expr:: Accumulator ;
@@ -145,14 +149,39 @@ impl Accumulator for CorrelationAccumulator {
145149 }
146150
/// Feeds one batch of the two input columns into the correlation state.
///
/// NOTE(review): this span is a rendered diff hunk. A line whose number
/// prefix ends in `-` is removed by the patch, one ending in `+` is added,
/// and a line carrying two fused numbers is unchanged context.
///
/// After the patch, when either column reports nulls, rows in which
/// `values[0]` or `values[1]` is null are dropped from both columns. The
/// resulting pair is passed to `covar`, and each column on its own to
/// `stddev1` / `stddev2`. Because `values` is shadowed, the two stddev
/// accumulators also receive the filtered columns.
///
/// Errors from `is_not_null`, `and`, `filter`, and the child `update_batch`
/// calls are propagated with `?`.
///
/// Indexes `values[0]` and `values[1]` directly; assumes callers always pass
/// exactly two arrays (TODO confirm against the aggregate's planner).
147151 fn update_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
// Removed by the patch: the raw, unfiltered columns went straight to `covar`.
148- self . covar . update_batch ( values) ?;
152+ // TODO: null input skipping logic duplicated across Correlation
153+ // and its children accumulators.
154+ // This could be simplified by splitting up input filtering and
155+ // calculation logic in children accumulators, and calling only
156+ // calculation part from Correlation
// Only build a mask and filter when at least one column reports nulls.
157+ let values = if values[ 0 ] . null_count ( ) != 0 || values[ 1 ] . null_count ( ) != 0 {
// `mask` is the logical AND of the two "is not null" masks, i.e. it keeps
// only the rows where both columns hold a value.
158+ let mask = and ( & is_not_null ( & values[ 0 ] ) ?, & is_not_null ( & values[ 1 ] ) ?) ?;
// The same mask is applied to both columns so they stay row-aligned.
159+ let values1 = filter ( & values[ 0 ] , & mask) ?;
160+ let values2 = filter ( & values[ 1 ] , & mask) ?;
161+
162+ vec ! [ values1, values2]
163+ } else {
// No nulls reported: copy the slice into an owned Vec so both branches
// yield the same type (presumably this only clones the array handles,
// not the underlying data; confirm `ArrayRef` is reference-counted).
164+ values. to_vec ( )
165+ } ;
166+
// Added by the patch: `covar` now sees the null-free pair of columns.
167+ self . covar . update_batch ( & values) ?;
// One-element sub-slices: each stddev accumulator gets a single column.
149168 self . stddev1 . update_batch ( & values[ 0 ..1 ] ) ?;
150169 self . stddev2 . update_batch ( & values[ 1 ..2 ] ) ?;
151170 Ok ( ( ) )
152171 }
153172
154173 fn retract_batch ( & mut self , values : & [ ArrayRef ] ) -> Result < ( ) > {
155- self . covar . retract_batch ( values) ?;
174+ let values = if values[ 0 ] . null_count ( ) != 0 || values[ 1 ] . null_count ( ) != 0 {
175+ let mask = and ( & is_not_null ( & values[ 0 ] ) ?, & is_not_null ( & values[ 1 ] ) ?) ?;
176+ let values1 = filter ( & values[ 0 ] , & mask) ?;
177+ let values2 = filter ( & values[ 1 ] , & mask) ?;
178+
179+ vec ! [ values1, values2]
180+ } else {
181+ values. to_vec ( )
182+ } ;
183+
184+ self . covar . retract_batch ( & values) ?;
156185 self . stddev1 . retract_batch ( & values[ 0 ..1 ] ) ?;
157186 self . stddev2 . retract_batch ( & values[ 1 ..2 ] ) ?;
158187 Ok ( ( ) )
@@ -341,48 +370,44 @@ mod tests {
341370
/// Correlation over two Int32 columns that each contain a null in a
/// different row.
///
/// NOTE(review): this span is a rendered diff hunk. Lines whose number
/// prefix ends in `-` are the old test body (removed), lines ending in `+`
/// are the new body (added).
///
/// Old expectation: aggregating `a = [1, null, 3]`, `b = [4, 5, 6]` returned
/// an error (`assert!(actual.is_err())`).
///
/// New expectation: with `a = [1, null, 2, 9, 3]` and
/// `b = [4, 5, 5, null, 6]`, the rows where both sides are non-null are
/// (1, 4), (2, 5), (3, 6), which lie on a straight line, so the expected
/// result is `1_f64`.
342371 #[ test]
343372 fn correlation_i32_with_nulls_2 ( ) -> Result < ( ) > {
// Old fixture: a single null in `a`, none in `b`.
344- let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ Some ( 1 ) , None , Some ( 3 ) ] ) ) ;
345- let b: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ Some ( 4 ) , Some ( 5 ) , Some ( 6 ) ] ) ) ;
346-
347- let schema = Schema :: new ( vec ! [
348- Field :: new( "a" , DataType :: Int32 , true ) ,
349- Field :: new( "b" , DataType :: Int32 , true ) ,
350- ] ) ;
351- let batch = RecordBatch :: try_new ( Arc :: new ( schema. clone ( ) ) , vec ! [ a, b] ) ?;
352-
353- let agg = Arc :: new ( Correlation :: new (
354- col ( "a" , & schema) ?,
355- col ( "b" , & schema) ?,
356- "bla" . to_string ( ) ,
357- DataType :: Float64 ,
358- ) ) ;
359- let actual = aggregate ( & batch, agg) ;
// Old behavior under test: a null in the input made the aggregate fail.
360- assert ! ( actual. is_err( ) ) ;
// New fixture: `a` is null at index 1, `b` is null at index 3.
373+ let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [
374+ Some ( 1 ) ,
375+ None ,
376+ Some ( 2 ) ,
377+ Some ( 9 ) ,
378+ Some ( 3 ) ,
379+ ] ) ) ;
380+ let b: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [
381+ Some ( 4 ) ,
382+ Some ( 5 ) ,
383+ Some ( 5 ) ,
384+ None ,
385+ Some ( 6 ) ,
386+ ] ) ) ;
361387
362- Ok ( ( ) )
// `generic_test_op2!` is defined outside this view; presumably it runs the
// named aggregate over both columns and compares against the last argument
// (confirm against the macro definition).
388+ generic_test_op2 ! (
389+ a,
390+ b,
391+ DataType :: Int32 ,
392+ DataType :: Int32 ,
393+ Correlation ,
394+ ScalarValue :: from( 1_f64 )
395+ )
363396 }
364397
/// Correlation over two Int32 columns that contain only nulls.
///
/// NOTE(review): this span is a rendered diff hunk. Lines whose number
/// prefix ends in `-` are the old test body (removed), lines ending in `+`
/// are the new body (added).
///
/// Old expectation: the aggregate returned an error
/// (`assert!(actual.is_err())`).
/// New expectation: the result is a null Float64, `ScalarValue::Float64(None)`.
365398 #[ test]
366399 fn correlation_i32_all_nulls ( ) -> Result < ( ) > {
// Both columns are two rows long and entirely null (unchanged by the patch).
367400 let a: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ None , None ] ) ) ;
368401 let b: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ None , None ] ) ) ;
369402
370- let schema = Schema :: new ( vec ! [
371- Field :: new( "a" , DataType :: Int32 , true ) ,
372- Field :: new( "b" , DataType :: Int32 , true ) ,
373- ] ) ;
374- let batch = RecordBatch :: try_new ( Arc :: new ( schema. clone ( ) ) , vec ! [ a, b] ) ?;
375-
376- let agg = Arc :: new ( Correlation :: new (
377- col ( "a" , & schema) ?,
378- col ( "b" , & schema) ?,
379- "bla" . to_string ( ) ,
380- DataType :: Float64 ,
381- ) ) ;
382- let actual = aggregate ( & batch, agg) ;
// Old behavior under test: all-null input made the aggregate fail.
383- assert ! ( actual. is_err( ) ) ;
384-
385- Ok ( ( ) )
// `generic_test_op2!` is defined outside this view; presumably it runs the
// named aggregate over both columns and compares against the last argument
// (confirm against the macro definition).
403+ generic_test_op2 ! (
404+ a,
405+ b,
406+ DataType :: Int32 ,
407+ DataType :: Int32 ,
408+ Correlation ,
409+ ScalarValue :: Float64 ( None )
410+ )
386411 }
387412
388413 #[ test]
0 commit comments