Skip to content

Commit 292eb95

Browse files
authored
Support using var/var_pop/stddev/stddev_pop in window expressions with custom frames (apache#4848)
* Wire up retract_batch for Stddev/StddevPop/Variance/VariancePop to * Add test for Stddev/StddevPop/Variance/VariancePop with window frame
1 parent 13fb42e commit 292eb95

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

datafusion/core/tests/sql/window.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,34 @@ async fn window_frame_rows_preceding() -> Result<()> {
524524
Ok(())
525525
}
526526

527+
#[tokio::test]
528+
async fn window_frame_rows_preceding_stddev_variance() -> Result<()> {
529+
let ctx = SessionContext::new();
530+
register_aggregate_csv(&ctx).await?;
531+
let sql = "SELECT \
532+
VAR(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
533+
VAR_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
534+
STDDEV(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
535+
STDDEV_POP(c4) OVER(ORDER BY c4 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)\
536+
FROM aggregate_test_100 \
537+
ORDER BY c9 \
538+
LIMIT 5";
539+
let actual = execute_to_batches(&ctx, sql).await;
540+
let expected = vec![
541+
"+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
542+
"| VARIANCE(aggregate_test_100.c4) | VARIANCEPOP(aggregate_test_100.c4) | STDDEV(aggregate_test_100.c4) | STDDEVPOP(aggregate_test_100.c4) |",
543+
"+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
544+
"| 46721.33333333174 | 31147.555555554496 | 216.15118166073427 | 176.4867007894773 |",
545+
"| 2639429.333333332 | 1759619.5555555548 | 1624.6320609089714 | 1326.5065229977404 |",
546+
"| 746202.3333333324 | 497468.2222222216 | 863.8300372951455 | 705.3142719541563 |",
547+
"| 768422.9999999981 | 512281.9999999988 | 876.5973990378925 | 715.7387791645767 |",
548+
"| 66526.3333333288 | 44350.88888888587 | 257.9269922542594 | 210.5965073045749 |",
549+
"+---------------------------------+------------------------------------+-------------------------------+----------------------------------+",
550+
];
551+
assert_batches_eq!(expected, &actual);
552+
Ok(())
553+
}
554+
527555
#[tokio::test]
528556
async fn window_frame_rows_preceding_with_partition_unique_order_by() -> Result<()> {
529557
let ctx = SessionContext::new();

datafusion/physical-expr/src/aggregate/stddev.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,10 @@ impl AggregateExpr for Stddev {
7373
Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
7474
}
7575

76+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
77+
Ok(Box::new(StddevAccumulator::try_new(StatsType::Sample)?))
78+
}
79+
7680
fn state_fields(&self) -> Result<Vec<Field>> {
7781
Ok(vec![
7882
Field::new(
@@ -128,6 +132,10 @@ impl AggregateExpr for StddevPop {
128132
Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
129133
}
130134

135+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
136+
Ok(Box::new(StddevAccumulator::try_new(StatsType::Population)?))
137+
}
138+
131139
fn state_fields(&self) -> Result<Vec<Field>> {
132140
Ok(vec![
133141
Field::new(
@@ -184,6 +192,10 @@ impl Accumulator for StddevAccumulator {
184192
self.variance.update_batch(values)
185193
}
186194

195+
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
196+
self.variance.retract_batch(values)
197+
}
198+
187199
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
188200
self.variance.merge_batch(states)
189201
}

datafusion/physical-expr/src/aggregate/variance.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ impl AggregateExpr for Variance {
7979
Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
8080
}
8181

82+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
83+
Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
84+
}
85+
8286
fn state_fields(&self) -> Result<Vec<Field>> {
8387
Ok(vec![
8488
Field::new(
@@ -136,6 +140,12 @@ impl AggregateExpr for VariancePop {
136140
)?))
137141
}
138142

143+
fn create_sliding_accumulator(&self) -> Result<Box<dyn Accumulator>> {
144+
Ok(Box::new(VarianceAccumulator::try_new(
145+
StatsType::Population,
146+
)?))
147+
}
148+
139149
fn state_fields(&self) -> Result<Vec<Field>> {
140150
Ok(vec![
141151
Field::new(

0 commit comments

Comments
 (0)