Skip to content

Commit 47ff6ed

Browse files
committed
replicate test
1 parent 2a3faca commit 47ff6ed

File tree

1 file changed

+39
-55
lines changed

1 file changed

+39
-55
lines changed

src/max_min_by.rs

Lines changed: 39 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -327,21 +327,6 @@ mod tests {
327327
mod max_by {
328328

329329
use super::*;
330-
async fn extract_string(df: prelude::DataFrame) -> error::Result<String> {
331-
let results = df.collect().await?;
332-
let col = results[0].column(0);
333-
let arr = col
334-
.as_any()
335-
.downcast_ref::<arrow::array::StringArray>()
336-
.unwrap();
337-
Ok(arr.value(0).to_string())
338-
}
339-
340-
fn ctx_max() -> error::Result<prelude::SessionContext> {
341-
let ctx = prelude::SessionContext::new();
342-
ctx.register_udaf(MaxByFunction::new().into());
343-
Ok(ctx)
344-
}
345330

346331
#[tokio::test]
347332
async fn test_max_by_string_int() -> error::Result<()> {
@@ -405,8 +390,7 @@ mod tests {
405390

406391
#[tokio::test]
407392
async fn test_max_by_ignores_nulls() -> error::Result<()> {
408-
let ctx = ctx_max()?;
409-
let sql = r#"
393+
let query = r#"
410394
SELECT max_by(v, k)
411395
FROM (
412396
VALUES
@@ -415,28 +399,27 @@ mod tests {
415399
('c', 2)
416400
) AS t(v, k)
417401
"#;
418-
let df = ctx.sql(sql).await?;
419-
let got = extract_string(df).await?;
420-
assert_eq!(got, "c", "max_by should ignore NULLs");
402+
let df =ctx()?.sql(&query).await?;
403+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
404+
assert_eq!(result, "c", "max_by should ignore NULLs");
421405
Ok(())
422406
}
423407

424408
#[tokio::test]
425409
async fn test_max_like_main_test() -> error::Result<()> {
426-
let ctx = ctx_max()?;
427-
let sql = r#"
410+
let query = r#"
428411
SELECT max_by(v, k)
429412
FROM (
430413
VALUES
431-
('a', 10),
432-
('b', 5),
433-
('c', 15),
434-
('d', 8)
414+
(1, 10),
415+
(2, 5),
416+
(3, 15),
417+
(4, 8)
435418
) AS t(v, k)
436419
"#;
437-
let df = ctx.sql(sql).await?;
438-
let got = extract_string(df).await?;
439-
assert_eq!(got, "c");
420+
let df = ctx()?.sql(&query).await?;
421+
let result = extract_single_value::<i64, arrow::array::Int64Array>(df).await?;
422+
assert_eq!(result, 3);
440423
Ok(())
441424
}
442425

@@ -452,21 +435,6 @@ mod tests {
452435
mod min_by {
453436

454437
use super::*;
455-
async fn extract_string(df: prelude::DataFrame) -> error::Result<String> {
456-
let results = df.collect().await?;
457-
let col = results[0].column(0);
458-
let arr = col
459-
.as_any()
460-
.downcast_ref::<arrow::array::StringArray>()
461-
.unwrap();
462-
Ok(arr.value(0).to_string())
463-
}
464-
465-
fn ctx_min() -> error::Result<prelude::SessionContext> {
466-
let ctx = prelude::SessionContext::new();
467-
ctx.register_udaf(MinByFunction::new().into());
468-
Ok(ctx)
469-
}
470438

471439
#[tokio::test]
472440
async fn test_min_by_string_int() -> error::Result<()> {
@@ -530,8 +498,7 @@ mod tests {
530498

531499
#[tokio::test]
532500
async fn test_min_by_ignores_nulls() -> error::Result<()> {
533-
let ctx = ctx_min()?;
534-
let sql = r#"
501+
let query = r#"
535502
SELECT min_by(v, k)
536503
FROM (
537504
VALUES
@@ -540,16 +507,15 @@ mod tests {
540507
('c', 2)
541508
) AS t(v, k)
542509
"#;
543-
let df = ctx.sql(sql).await?;
544-
let got = extract_string(df).await?;
545-
assert_eq!(got, "a", "max_by should ignore NULLs");
510+
let df = ctx()?.sql(&query).await?;
511+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
512+
assert_eq!(result, "a", "min_by should ignore NULLs");
546513
Ok(())
547514
}
548515

549516
#[tokio::test]
550-
async fn test_min_like_main_test() -> error::Result<()> {
551-
let ctx = ctx_min()?;
552-
let sql = r#"
517+
async fn test_min_like_main_test_str() -> error::Result<()> {
518+
let query = r#"
553519
SELECT min_by(v, k)
554520
FROM (
555521
VALUES
@@ -559,9 +525,27 @@ mod tests {
559525
('d', 8)
560526
) AS t(v, k)
561527
"#;
562-
let df = ctx.sql(sql).await?;
563-
let got = extract_string(df).await?;
564-
assert_eq!(got, "b");
528+
let df = ctx()?.sql(&query).await?;
529+
let result = extract_single_value::<String, arrow::array::StringArray>(df).await?;
530+
assert_eq!(result, "b");
531+
Ok(())
532+
}
533+
534+
#[tokio::test]
535+
async fn test_min_like_main_test_int() -> error::Result<()> {
536+
let query = r#"
537+
SELECT min_by(v, k)
538+
FROM (
539+
VALUES
540+
(1, 10),
541+
(2, 5),
542+
(3, 15),
543+
(4, 8)
544+
) AS t(v, k)
545+
"#;
546+
let df = ctx()?.sql(&query).await?;
547+
let result = extract_single_value::<i64, arrow::array::Int64Array>(df).await?;
548+
assert_eq!(result, 2);
565549
Ok(())
566550
}
567551

0 commit comments

Comments
 (0)