@@ -327,21 +327,6 @@ mod tests {
327327 mod max_by {
328328
329329 use super :: * ;
330- async fn extract_string ( df : prelude:: DataFrame ) -> error:: Result < String > {
331- let results = df. collect ( ) . await ?;
332- let col = results[ 0 ] . column ( 0 ) ;
333- let arr = col
334- . as_any ( )
335- . downcast_ref :: < arrow:: array:: StringArray > ( )
336- . unwrap ( ) ;
337- Ok ( arr. value ( 0 ) . to_string ( ) )
338- }
339-
340- fn ctx_max ( ) -> error:: Result < prelude:: SessionContext > {
341- let ctx = prelude:: SessionContext :: new ( ) ;
342- ctx. register_udaf ( MaxByFunction :: new ( ) . into ( ) ) ;
343- Ok ( ctx)
344- }
345330
346331 #[ tokio:: test]
347332 async fn test_max_by_string_int ( ) -> error:: Result < ( ) > {
@@ -405,8 +390,7 @@ mod tests {
405390
406391 #[ tokio:: test]
407392 async fn test_max_by_ignores_nulls ( ) -> error:: Result < ( ) > {
408- let ctx = ctx_max ( ) ?;
409- let sql = r#"
393+ let query = r#"
410394 SELECT max_by(v, k)
411395 FROM (
412396 VALUES
@@ -415,28 +399,27 @@ mod tests {
415399 ('c', 2)
416400 ) AS t(v, k)
417401 "# ;
418- let df = ctx. sql ( sql ) . await ?;
419- let got = extract_string ( df) . await ?;
420- assert_eq ! ( got , "c" , "max_by should ignore NULLs" ) ;
402+ let df =ctx ( ) ? . sql ( & query ) . await ?;
403+ let result = extract_single_value :: < String , arrow :: array :: StringArray > ( df) . await ?;
404+ assert_eq ! ( result , "c" , "max_by should ignore NULLs" ) ;
421405 Ok ( ( ) )
422406 }
423407
424408 #[ tokio:: test]
425409 async fn test_max_like_main_test ( ) -> error:: Result < ( ) > {
426- let ctx = ctx_max ( ) ?;
427- let sql = r#"
410+ let query = r#"
428411 SELECT max_by(v, k)
429412 FROM (
430413 VALUES
431- ('a' , 10),
432- ('b' , 5),
433- ('c' , 15),
434- ('d' , 8)
414+ (1 , 10),
415+ (2 , 5),
416+ (3 , 15),
417+ (4 , 8)
435418 ) AS t(v, k)
436419 "# ;
437- let df = ctx. sql ( sql ) . await ?;
438- let got = extract_string ( df) . await ?;
439- assert_eq ! ( got , "c" ) ;
420+ let df = ctx ( ) ? . sql ( & query ) . await ?;
421+ let result = extract_single_value :: < i64 , arrow :: array :: Int64Array > ( df) . await ?;
422+ assert_eq ! ( result , 3 ) ;
440423 Ok ( ( ) )
441424 }
442425
@@ -452,21 +435,6 @@ mod tests {
452435 mod min_by {
453436
454437 use super :: * ;
455- async fn extract_string ( df : prelude:: DataFrame ) -> error:: Result < String > {
456- let results = df. collect ( ) . await ?;
457- let col = results[ 0 ] . column ( 0 ) ;
458- let arr = col
459- . as_any ( )
460- . downcast_ref :: < arrow:: array:: StringArray > ( )
461- . unwrap ( ) ;
462- Ok ( arr. value ( 0 ) . to_string ( ) )
463- }
464-
465- fn ctx_min ( ) -> error:: Result < prelude:: SessionContext > {
466- let ctx = prelude:: SessionContext :: new ( ) ;
467- ctx. register_udaf ( MinByFunction :: new ( ) . into ( ) ) ;
468- Ok ( ctx)
469- }
470438
471439 #[ tokio:: test]
472440 async fn test_min_by_string_int ( ) -> error:: Result < ( ) > {
@@ -530,8 +498,7 @@ mod tests {
530498
531499 #[ tokio:: test]
532500 async fn test_min_by_ignores_nulls ( ) -> error:: Result < ( ) > {
533- let ctx = ctx_min ( ) ?;
534- let sql = r#"
501+ let query = r#"
535502 SELECT min_by(v, k)
536503 FROM (
537504 VALUES
@@ -540,16 +507,15 @@ mod tests {
540507 ('c', 2)
541508 ) AS t(v, k)
542509 "# ;
543- let df = ctx. sql ( sql ) . await ?;
544- let got = extract_string ( df) . await ?;
545- assert_eq ! ( got , "a" , "max_by should ignore NULLs" ) ;
510+ let df = ctx ( ) ? . sql ( & query ) . await ?;
511+ let result = extract_single_value :: < String , arrow :: array :: StringArray > ( df) . await ?;
512+ assert_eq ! ( result , "a" , "min_by should ignore NULLs" ) ;
546513 Ok ( ( ) )
547514 }
548515
549516 #[ tokio:: test]
550- async fn test_min_like_main_test ( ) -> error:: Result < ( ) > {
551- let ctx = ctx_min ( ) ?;
552- let sql = r#"
517+ async fn test_min_like_main_test_str ( ) -> error:: Result < ( ) > {
518+ let query = r#"
553519 SELECT min_by(v, k)
554520 FROM (
555521 VALUES
@@ -559,9 +525,27 @@ mod tests {
559525 ('d', 8)
560526 ) AS t(v, k)
561527 "# ;
562- let df = ctx. sql ( sql) . await ?;
563- let got = extract_string ( df) . await ?;
564- assert_eq ! ( got, "b" ) ;
528+ let df = ctx ( ) ?. sql ( & query) . await ?;
529+ let result = extract_single_value :: < String , arrow:: array:: StringArray > ( df) . await ?;
530+ assert_eq ! ( result, "b" ) ;
531+ Ok ( ( ) )
532+ }
533+
534+ #[ tokio:: test]
535+ async fn test_min_like_main_test_int ( ) -> error:: Result < ( ) > {
536+ let query = r#"
537+ SELECT min_by(v, k)
538+ FROM (
539+ VALUES
540+ (1, 10),
541+ (2, 5),
542+ (3, 15),
543+ (4, 8)
544+ ) AS t(v, k)
545+ "# ;
546+ let df = ctx ( ) ?. sql ( & query) . await ?;
547+ let result = extract_single_value :: < i64 , arrow:: array:: Int64Array > ( df) . await ?;
548+ assert_eq ! ( result, 2 ) ;
565549 Ok ( ( ) )
566550 }
567551
0 commit comments