2323import java .util .Arrays ;
2424import java .util .Collection ;
2525import java .util .List ;
26- import java .util .Optional ;
2726import java .util .function .LongSupplier ;
2827
2928import org .jspecify .annotations .Nullable ;
3635import org .springframework .data .domain .SliceImpl ;
3736import org .springframework .data .domain .Sort ;
3837import org .springframework .data .domain .Vector ;
38+ import org .springframework .data .javapoet .LordOfTheStrings ;
39+ import org .springframework .data .javapoet .TypeNames ;
3940import org .springframework .data .jpa .repository .Modifying ;
4041import org .springframework .data .jpa .repository .NativeQuery ;
4142import org .springframework .data .jpa .repository .QueryHints ;
4546import org .springframework .data .jpa .repository .query .ParameterBinding ;
4647import org .springframework .data .jpa .repository .support .JpqlQueryTemplates ;
4748import org .springframework .data .repository .aot .generate .AotQueryMethodGenerationContext ;
48- import org .springframework .data .repository .query . ReturnedType ;
49+ import org .springframework .data .repository .aot . generate . MethodReturn ;
4950import org .springframework .data .support .PageableExecutionUtils ;
50- import org .springframework .data .util .ReflectionUtils ;
5151import org .springframework .javapoet .CodeBlock ;
5252import org .springframework .javapoet .CodeBlock .Builder ;
5353import org .springframework .javapoet .TypeName ;
@@ -88,7 +88,7 @@ static class QueryBlockBuilder {
8888 private final AotQueryMethodGenerationContext context ;
8989 private final JpaQueryMethod queryMethod ;
9090 private final String parameterNames ;
91- private String queryVariableName ;
91+ private final String queryVariableName ;
9292 private @ Nullable AotQueries queries ;
9393 private MergedAnnotation <QueryHints > queryHints = MergedAnnotation .missing ();
9494 private @ Nullable AotEntityGraph entityGraph ;
@@ -111,12 +111,6 @@ private QueryBlockBuilder(AotQueryMethodGenerationContext context, JpaQueryMetho
111111 }
112112 }
113113
114- public QueryBlockBuilder usingQueryVariableName (String queryVariableName ) {
115-
116- this .queryVariableName = context .localVariable (queryVariableName );
117- return this ;
118- }
119-
120114 public QueryBlockBuilder filter (AotQueries query ) {
121115 this .queries = query ;
122116 return this ;
@@ -160,7 +154,8 @@ public CodeBlock build() {
160154
161155 Assert .notNull (queries , "Queries must not be null" );
162156
163- boolean isProjecting = context .getReturnedType ().isProjecting ();
157+ MethodReturn methodReturn = context .getMethodReturn ();
158+ boolean isProjecting = methodReturn .isProjecting ();
164159
165160 String dynamicReturnType = null ;
166161 if (queryMethod .getParameters ().hasDynamicProjection ()) {
@@ -266,7 +261,8 @@ private CodeBlock applyRewrite(@Nullable String sort, @Nullable String dynamicRe
266261 dynamicReturnType );
267262 } else if (hasSort ) {
268263
269- Object actualReturnType = isProjecting ? context .getActualReturnTypeName () : context .getDomainType ();
264+ Object actualReturnType = isProjecting ? context .getMethodReturn ().getActualClassName ()
265+ : context .getDomainType ();
270266
271267 builder .addStatement ("$L = rewriteQuery($L, $L, $T.class)" , queryString , context .localVariable ("declaredQuery" ),
272268 sort , actualReturnType );
@@ -291,7 +287,6 @@ private CodeBlock applyLimits(boolean exists, @Nullable String pageable) {
291287
292288 if (exists ) {
293289 builder .addStatement ("$L.setMaxResults(1)" , queryVariableName );
294-
295290 return builder .build ();
296291 }
297292
@@ -434,7 +429,7 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
434429 @ Nullable String pageable ,
435430 @ Nullable Class <?> queryReturnType ) {
436431
437- ReturnedType returnedType = context .getReturnedType ();
432+ MethodReturn methodReturn = context .getMethodReturn ();
438433 Builder builder = CodeBlock .builder ();
439434 String queryStringNameToUse = queryStringName ;
440435
@@ -478,16 +473,14 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
478473 return builder .build ();
479474 }
480475
481- if (sq .hasConstructorExpressionOrDefaultProjection () && !count && returnedType .isProjecting ()
482- && returnedType .getReturnedType ().isInterface ()) {
476+ if (sq .hasConstructorExpressionOrDefaultProjection () && !count && methodReturn .isInterfaceProjection ()) {
483477 builder .addStatement ("$T $L = this.$L.createQuery($L)" , Query .class , queryVariableName ,
484478 context .fieldNameOf (EntityManager .class ), queryStringNameToUse );
485479 } else {
486480
487481 String createQueryMethod = query .isNative () ? "createNativeQuery" : "createQuery" ;
488482
489- if (!sq .hasConstructorExpressionOrDefaultProjection () && !count && returnedType .isProjecting ()
490- && returnedType .getReturnedType ().isInterface ()) {
483+ if (!sq .hasConstructorExpressionOrDefaultProjection () && !count && methodReturn .isInterfaceProjection ()) {
491484 builder .addStatement ("$T $L = this.$L.$L($L, $T.class)" , Query .class , queryVariableName ,
492485 context .fieldNameOf (EntityManager .class ), createQueryMethod , queryStringNameToUse , Tuple .class );
493486 } else {
@@ -501,8 +494,7 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
501494
502495 if (query instanceof NamedAotQuery nq ) {
503496
504- if (!count && !nq .hasConstructorExpressionOrDefaultProjection () && returnedType .isProjecting ()
505- && returnedType .getReturnedType ().isInterface ()) {
497+ if (!count && !nq .hasConstructorExpressionOrDefaultProjection () && methodReturn .isInterfaceProjection ()) {
506498 queryReturnType = Tuple .class ;
507499 }
508500
@@ -571,9 +563,9 @@ private CodeBlock applyEntityGraph(AotEntityGraph entityGraph, String queryVaria
571563 } else {
572564
573565 builder .addStatement ("$T<$T> $L = $L.createEntityGraph($T.class)" ,
574- jakarta .persistence .EntityGraph .class , context .getActualReturnType (). getType (),
566+ jakarta .persistence .EntityGraph .class , context .getDomainType (),
575567 context .localVariable ("entityGraph" ),
576- context .fieldNameOf (EntityManager .class ), context .getActualReturnType (). getType ());
568+ context .fieldNameOf (EntityManager .class ), context .getDomainType ());
577569
578570 for (String attributePath : entityGraph .attributePaths ()) {
579571
@@ -618,8 +610,8 @@ static class QueryExecutionBlockBuilder {
618610
619611 private final AotQueryMethodGenerationContext context ;
620612 private final JpaQueryMethod queryMethod ;
613+ private final String queryVariableName ;
621614 private @ Nullable AotQuery aotQuery ;
622- private String queryVariableName ;
623615 private @ Nullable String pageable ;
624616 private MergedAnnotation <Modifying > modifying = MergedAnnotation .missing ();
625617
@@ -631,12 +623,6 @@ private QueryExecutionBlockBuilder(AotQueryMethodGenerationContext context, JpaQ
631623 this .pageable = context .getPageableParameterName () != null ? context .localVariable ("pageable" ) : null ;
632624 }
633625
634- public QueryExecutionBlockBuilder referencing (String queryVariableName ) {
635-
636- this .queryVariableName = context .localVariable (queryVariableName );
637- return this ;
638- }
639-
640626 public QueryExecutionBlockBuilder query (AotQuery aotQuery ) {
641627
642628 this .aotQuery = aotQuery ;
@@ -658,20 +644,21 @@ public QueryExecutionBlockBuilder modifying(MergedAnnotation<Modifying> modifyin
658644 public CodeBlock build () {
659645
660646 Builder builder = CodeBlock .builder ();
661-
662- boolean isProjecting = !ObjectUtils .nullSafeEquals (context .getDomainType (), context .getActualReturnTypeName ());
663- TypeName actualReturnType = isProjecting ? context .getActualReturnTypeName ()
647+ MethodReturn methodReturn = context .getMethodReturn ();
648+ boolean isProjecting = methodReturn .isProjecting ()
649+ || !ObjectUtils .nullSafeEquals (context .getDomainType (), methodReturn .getActualReturnClass ())
650+ || StringUtils .hasText (context .getDynamicProjectionParameterName ());
651+ TypeName typeToRead = isProjecting ? methodReturn .getActualTypeName ()
664652 : TypeName .get (context .getDomainType ());
665653 builder .add ("\n " );
666654
667- Class <?> methodReturnType = context .getMethod ().getReturnType ();
668655 if (modifying .isPresent ()) {
669656
670657 if (modifying .getBoolean ("flushAutomatically" )) {
671658 builder .addStatement ("this.$L.flush()" , context .fieldNameOf (EntityManager .class ));
672659 }
673660
674- Class <?> returnType = methodReturnType ;
661+ Class <?> returnType = methodReturn . toClass () ;
675662
676663 if (returnsModifying (returnType )) {
677664 builder .addStatement ("int $L = $L.executeUpdate()" , context .localVariable ("result" ), queryVariableName );
@@ -694,15 +681,13 @@ public CodeBlock build() {
694681 return builder .build ();
695682 }
696683
697- TypeName queryResultType = TypeName .get (context .getActualReturnType ().toClass ());
698-
699684 if (aotQuery != null && aotQuery .isDelete ()) {
700685
701686 builder .addStatement ("$T $L = $L.getResultList()" , List .class ,
702687 context .localVariable ("resultList" ), queryVariableName );
703688
704- boolean returnCount = ClassUtils .isAssignable (Number .class , methodReturnType );
705- boolean simpleBatch = returnCount || ReflectionUtils .isVoid (methodReturnType );
689+ boolean returnCount = ClassUtils .isAssignable (Number .class , methodReturn . toClass () );
690+ boolean simpleBatch = returnCount || methodReturn .isVoid ();
706691 boolean collectionQuery = queryMethod .isCollectionQuery ();
707692
708693 if (!simpleBatch && !collectionQuery ) {
@@ -712,9 +697,6 @@ public CodeBlock build() {
712697 IncorrectResultSizeDataAccessException .class ,
713698 "Delete query returned more than one element: expected 1, actual " , context .localVariable ("resultList" ));
714699 builder .endControlFlow ();
715-
716- builder .addStatement ("$L.forEach($L::remove)" , context .localVariable ("resultList" ),
717- context .fieldNameOf (EntityManager .class ));
718700 }
719701
720702 builder .addStatement ("$L.forEach($L::remove)" , context .localVariable ("resultList" ),
@@ -724,40 +706,52 @@ public CodeBlock build() {
724706 builder .addStatement ("return ($T) $L" , List .class , context .localVariable ("resultList" ));
725707
726708 } else if (returnCount ) {
727- builder .addStatement ("return $T.valueOf($L.size())" , methodReturnType ,
709+ builder .addStatement ("return $T.valueOf($L.size())" , methodReturn . getActualClassName () ,
728710 context .localVariable ("resultList" ));
729711 } else {
730712
731- if (Optional .class .isAssignableFrom (methodReturnType )) {
732- builder .addStatement ("return ($1T) $1T.ofNullable($2L.isEmpty() ? null : $2L.iterator().next())" ,
733- Optional .class , context .localVariable ("resultList" ));
734- } else {
735- builder .addStatement ("return ($1T) ($2L.isEmpty() ? null : $2L.iterator().next())" , actualReturnType ,
736- context .localVariable ("resultList" ));
737- }
713+ builder .addStatement (LordOfTheStrings .returning (methodReturn .toClass ())
714+ .optional ("($1T) ($2L.isEmpty() ? null : $2L.iterator().next())" , typeToRead ,
715+ context .localVariable ("resultList" )) //
716+ .build ());
738717 }
739718 } else if (aotQuery != null && aotQuery .isExists ()) {
740719 builder .addStatement ("return !$L.getResultList().isEmpty()" , queryVariableName );
741720 } else if (aotQuery != null ) {
742721
743- if (context .getReturnedType ().isProjecting ()) {
722+ if (isProjecting ) {
723+
724+ TypeName returnType = TypeNames .typeNameOrWrapper (methodReturn .getActualType ());
725+ CodeBlock convertTo ;
726+ if (StringUtils .hasText (context .getDynamicProjectionParameterName ())) {
727+ convertTo = CodeBlock .of ("$L" , context .getDynamicProjectionParameterName ());
728+ } else {
729+
730+ if (methodReturn .isArray () && methodReturn .getActualType ().toClass ().equals (byte .class )) {
731+ returnType = TypeName .get (byte [].class );
732+ convertTo = CodeBlock .of ("$T.class" , returnType );
733+ } else {
734+ convertTo = CodeBlock .of ("$T.class" , TypeNames .classNameOrWrapper (methodReturn .getActualType ()));
735+ }
736+ }
744737
745738 if (queryMethod .isCollectionQuery ()) {
746- builder .addStatement ("return ($T) convertMany($L.getResultList(), $L, $T.class)" ,
747- context . getReturnTypeName (), queryVariableName , aotQuery .isNative (), queryResultType );
739+ builder .addStatement ("return ($T) convertMany($L.getResultList(), $L, $L)" , methodReturn . getTypeName () ,
740+ queryVariableName , aotQuery .isNative (), convertTo );
748741 } else if (queryMethod .isStreamQuery ()) {
749- builder .addStatement ("return ($T) convertMany($L.getResultStream(), $L, $T.class)" ,
750- context . getReturnTypeName (), queryVariableName , aotQuery .isNative (), queryResultType );
742+ builder .addStatement ("return ($T) convertMany($L.getResultStream(), $L, $L)" , methodReturn . getTypeName () ,
743+ queryVariableName , aotQuery .isNative (), convertTo );
751744 } else if (queryMethod .isPageQuery ()) {
752745 builder .addStatement (
753- "return $T.getPage(($T<$T>) convertMany($L.getResultList(), $L, $T.class ), $L, $L)" ,
754- PageableExecutionUtils .class , List .class , actualReturnType , queryVariableName , aotQuery . isNative ( ),
755- queryResultType , pageable , context .localVariable ("countAll" ));
746+ "return $T.getPage(($T<$T>) convertMany($L.getResultList(), $L, $L ), $L, $L)" ,
747+ PageableExecutionUtils .class , List .class , TypeNames . typeNameOrWrapper ( methodReturn . getActualType () ),
748+ queryVariableName , aotQuery . isNative (), convertTo , pageable , context .localVariable ("countAll" ));
756749 } else if (queryMethod .isSliceQuery ()) {
757- builder .addStatement ("$T<$T> $L = ($T<$T>) convertMany($L.getResultList(), $L, $T.class)" , List .class ,
758- actualReturnType , context .localVariable ("resultList" ), List .class , actualReturnType , queryVariableName ,
750+ builder .addStatement ("$T<$T> $L = ($T<$T>) convertMany($L.getResultList(), $L, $L)" , List .class ,
751+ TypeNames .typeNameOrWrapper (methodReturn .getActualType ()), context .localVariable ("resultList" ),
752+ List .class , typeToRead , queryVariableName ,
759753 aotQuery .isNative (),
760- queryResultType );
754+ convertTo );
761755 builder .addStatement ("boolean $L = $L.isPaged() && $L.size() > $L.getPageSize()" ,
762756 context .localVariable ("hasNext" ), pageable , context .localVariable ("resultList" ), pageable );
763757 builder .addStatement (
@@ -766,27 +760,24 @@ public CodeBlock build() {
766760 pageable , context .localVariable ("resultList" ), pageable , context .localVariable ("hasNext" ));
767761 } else {
768762
769- if (Optional .class .isAssignableFrom (context .getReturnType ().toClass ())) {
770- builder .addStatement ("return $T.ofNullable(($T) convertOne($L.getSingleResultOrNull(), $L, $T.class))" ,
771- Optional .class , actualReturnType , queryVariableName , aotQuery .isNative (), queryResultType );
772- } else {
773- builder .addStatement ("return ($T) convertOne($L.getSingleResultOrNull(), $L, $T.class)" ,
774- context .getReturnTypeName (), queryVariableName , aotQuery .isNative (), queryResultType );
775- }
763+ builder .addStatement (LordOfTheStrings .returning (methodReturn .toClass ())
764+ .optional ("($T) convertOne($L.getSingleResultOrNull(), $L, $L)" , returnType , queryVariableName ,
765+ aotQuery .isNative (), convertTo ) //
766+ .build ());
776767 }
777768
778769 } else {
779770
780771 if (queryMethod .isCollectionQuery ()) {
781- builder .addStatement ("return ($T) $L.getResultList()" , context . getReturnTypeName (), queryVariableName );
772+ builder .addStatement ("return ($T) $L.getResultList()" , methodReturn . getTypeName (), queryVariableName );
782773 } else if (queryMethod .isStreamQuery ()) {
783- builder .addStatement ("return ($T) $L.getResultStream()" , context . getReturnTypeName (), queryVariableName );
774+ builder .addStatement ("return ($T) $L.getResultStream()" , methodReturn . getTypeName (), queryVariableName );
784775 } else if (queryMethod .isPageQuery ()) {
785776 builder .addStatement ("return $T.getPage(($T<$T>) $L.getResultList(), $L, $L)" ,
786- PageableExecutionUtils .class , List .class , actualReturnType , queryVariableName ,
777+ PageableExecutionUtils .class , List .class , typeToRead , queryVariableName ,
787778 pageable , context .localVariable ("countAll" ));
788779 } else if (queryMethod .isSliceQuery ()) {
789- builder .addStatement ("$T<$T> $L = $L.getResultList()" , List .class , actualReturnType ,
780+ builder .addStatement ("$T<$T> $L = $L.getResultList()" , List .class , typeToRead ,
790781 context .localVariable ("resultList" ), queryVariableName );
791782 builder .addStatement ("boolean $L = $L.isPaged() && $L.size() > $L.getPageSize()" ,
792783 context .localVariable ("hasNext" ), pageable , context .localVariable ("resultList" ), pageable );
@@ -796,15 +787,11 @@ public CodeBlock build() {
796787 pageable , context .localVariable ("resultList" ), pageable , context .localVariable ("hasNext" ));
797788 } else {
798789
799- if (Optional .class .isAssignableFrom (context .getReturnType ().toClass ())) {
800- builder .addStatement ("return $T.ofNullable(($T) convertOne($L.getSingleResultOrNull(), $L, $T.class))" ,
801- Optional .class , actualReturnType , queryVariableName , aotQuery .isNative (),
802- queryResultType );
803- } else {
804- builder .addStatement ("return ($T) convertOne($L.getSingleResultOrNull(), $L, $T.class)" ,
805- context .getReturnTypeName (), queryVariableName , aotQuery .isNative (),
806- context .getReturnType ().toClass ());
807- }
790+ builder .addStatement (LordOfTheStrings .returning (methodReturn .toClass ())
791+ .optional ("($T) convertOne($L.getSingleResultOrNull(), $L, $T.class)" ,
792+ TypeNames .typeNameOrWrapper (methodReturn .getActualType ()), queryVariableName , aotQuery .isNative (),
793+ TypeNames .classNameOrWrapper (methodReturn .getActualType ())) //
794+ .build ());
808795 }
809796 }
810797 }
@@ -820,4 +807,5 @@ public static boolean returnsModifying(Class<?> returnType) {
820807
821808 }
822809
810+
823811}
0 commit comments