Skip to content

Commit 6a26a73

Browse files
committed
Refine JavaPoet usage.
See #4007
1 parent de7c568 commit 6a26a73

File tree

2 files changed

+69
-81
lines changed

2 files changed

+69
-81
lines changed

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/JpaCodeBlocks.java

Lines changed: 68 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import java.util.Arrays;
2424
import java.util.Collection;
2525
import java.util.List;
26-
import java.util.Optional;
2726
import java.util.function.LongSupplier;
2827

2928
import org.jspecify.annotations.Nullable;
@@ -36,6 +35,8 @@
3635
import org.springframework.data.domain.SliceImpl;
3736
import org.springframework.data.domain.Sort;
3837
import org.springframework.data.domain.Vector;
38+
import org.springframework.data.javapoet.LordOfTheStrings;
39+
import org.springframework.data.javapoet.TypeNames;
3940
import org.springframework.data.jpa.repository.Modifying;
4041
import org.springframework.data.jpa.repository.NativeQuery;
4142
import org.springframework.data.jpa.repository.QueryHints;
@@ -45,9 +46,8 @@
4546
import org.springframework.data.jpa.repository.query.ParameterBinding;
4647
import org.springframework.data.jpa.repository.support.JpqlQueryTemplates;
4748
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
48-
import org.springframework.data.repository.query.ReturnedType;
49+
import org.springframework.data.repository.aot.generate.MethodReturn;
4950
import org.springframework.data.support.PageableExecutionUtils;
50-
import org.springframework.data.util.ReflectionUtils;
5151
import org.springframework.javapoet.CodeBlock;
5252
import org.springframework.javapoet.CodeBlock.Builder;
5353
import org.springframework.javapoet.TypeName;
@@ -88,7 +88,7 @@ static class QueryBlockBuilder {
8888
private final AotQueryMethodGenerationContext context;
8989
private final JpaQueryMethod queryMethod;
9090
private final String parameterNames;
91-
private String queryVariableName;
91+
private final String queryVariableName;
9292
private @Nullable AotQueries queries;
9393
private MergedAnnotation<QueryHints> queryHints = MergedAnnotation.missing();
9494
private @Nullable AotEntityGraph entityGraph;
@@ -111,12 +111,6 @@ private QueryBlockBuilder(AotQueryMethodGenerationContext context, JpaQueryMetho
111111
}
112112
}
113113

114-
public QueryBlockBuilder usingQueryVariableName(String queryVariableName) {
115-
116-
this.queryVariableName = context.localVariable(queryVariableName);
117-
return this;
118-
}
119-
120114
public QueryBlockBuilder filter(AotQueries query) {
121115
this.queries = query;
122116
return this;
@@ -160,7 +154,8 @@ public CodeBlock build() {
160154

161155
Assert.notNull(queries, "Queries must not be null");
162156

163-
boolean isProjecting = context.getReturnedType().isProjecting();
157+
MethodReturn methodReturn = context.getMethodReturn();
158+
boolean isProjecting = methodReturn.isProjecting();
164159

165160
String dynamicReturnType = null;
166161
if (queryMethod.getParameters().hasDynamicProjection()) {
@@ -266,7 +261,8 @@ private CodeBlock applyRewrite(@Nullable String sort, @Nullable String dynamicRe
266261
dynamicReturnType);
267262
} else if (hasSort) {
268263

269-
Object actualReturnType = isProjecting ? context.getActualReturnTypeName() : context.getDomainType();
264+
Object actualReturnType = isProjecting ? context.getMethodReturn().getActualClassName()
265+
: context.getDomainType();
270266

271267
builder.addStatement("$L = rewriteQuery($L, $L, $T.class)", queryString, context.localVariable("declaredQuery"),
272268
sort, actualReturnType);
@@ -291,7 +287,6 @@ private CodeBlock applyLimits(boolean exists, @Nullable String pageable) {
291287

292288
if (exists) {
293289
builder.addStatement("$L.setMaxResults(1)", queryVariableName);
294-
295290
return builder.build();
296291
}
297292

@@ -434,7 +429,7 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
434429
@Nullable String pageable,
435430
@Nullable Class<?> queryReturnType) {
436431

437-
ReturnedType returnedType = context.getReturnedType();
432+
MethodReturn methodReturn = context.getMethodReturn();
438433
Builder builder = CodeBlock.builder();
439434
String queryStringNameToUse = queryStringName;
440435

@@ -478,16 +473,14 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
478473
return builder.build();
479474
}
480475

481-
if (sq.hasConstructorExpressionOrDefaultProjection() && !count && returnedType.isProjecting()
482-
&& returnedType.getReturnedType().isInterface()) {
476+
if (sq.hasConstructorExpressionOrDefaultProjection() && !count && methodReturn.isInterfaceProjection()) {
483477
builder.addStatement("$T $L = this.$L.createQuery($L)", Query.class, queryVariableName,
484478
context.fieldNameOf(EntityManager.class), queryStringNameToUse);
485479
} else {
486480

487481
String createQueryMethod = query.isNative() ? "createNativeQuery" : "createQuery";
488482

489-
if (!sq.hasConstructorExpressionOrDefaultProjection() && !count && returnedType.isProjecting()
490-
&& returnedType.getReturnedType().isInterface()) {
483+
if (!sq.hasConstructorExpressionOrDefaultProjection() && !count && methodReturn.isInterfaceProjection()) {
491484
builder.addStatement("$T $L = this.$L.$L($L, $T.class)", Query.class, queryVariableName,
492485
context.fieldNameOf(EntityManager.class), createQueryMethod, queryStringNameToUse, Tuple.class);
493486
} else {
@@ -501,8 +494,7 @@ private CodeBlock doCreateQuery(boolean count, String queryVariableName,
501494

502495
if (query instanceof NamedAotQuery nq) {
503496

504-
if (!count && !nq.hasConstructorExpressionOrDefaultProjection() && returnedType.isProjecting()
505-
&& returnedType.getReturnedType().isInterface()) {
497+
if (!count && !nq.hasConstructorExpressionOrDefaultProjection() && methodReturn.isInterfaceProjection()) {
506498
queryReturnType = Tuple.class;
507499
}
508500

@@ -571,9 +563,9 @@ private CodeBlock applyEntityGraph(AotEntityGraph entityGraph, String queryVaria
571563
} else {
572564

573565
builder.addStatement("$T<$T> $L = $L.createEntityGraph($T.class)",
574-
jakarta.persistence.EntityGraph.class, context.getActualReturnType().getType(),
566+
jakarta.persistence.EntityGraph.class, context.getDomainType(),
575567
context.localVariable("entityGraph"),
576-
context.fieldNameOf(EntityManager.class), context.getActualReturnType().getType());
568+
context.fieldNameOf(EntityManager.class), context.getDomainType());
577569

578570
for (String attributePath : entityGraph.attributePaths()) {
579571

@@ -618,8 +610,8 @@ static class QueryExecutionBlockBuilder {
618610

619611
private final AotQueryMethodGenerationContext context;
620612
private final JpaQueryMethod queryMethod;
613+
private final String queryVariableName;
621614
private @Nullable AotQuery aotQuery;
622-
private String queryVariableName;
623615
private @Nullable String pageable;
624616
private MergedAnnotation<Modifying> modifying = MergedAnnotation.missing();
625617

@@ -631,12 +623,6 @@ private QueryExecutionBlockBuilder(AotQueryMethodGenerationContext context, JpaQ
631623
this.pageable = context.getPageableParameterName() != null ? context.localVariable("pageable") : null;
632624
}
633625

634-
public QueryExecutionBlockBuilder referencing(String queryVariableName) {
635-
636-
this.queryVariableName = context.localVariable(queryVariableName);
637-
return this;
638-
}
639-
640626
public QueryExecutionBlockBuilder query(AotQuery aotQuery) {
641627

642628
this.aotQuery = aotQuery;
@@ -658,20 +644,21 @@ public QueryExecutionBlockBuilder modifying(MergedAnnotation<Modifying> modifyin
658644
public CodeBlock build() {
659645

660646
Builder builder = CodeBlock.builder();
661-
662-
boolean isProjecting = !ObjectUtils.nullSafeEquals(context.getDomainType(), context.getActualReturnTypeName());
663-
TypeName actualReturnType = isProjecting ? context.getActualReturnTypeName()
647+
MethodReturn methodReturn = context.getMethodReturn();
648+
boolean isProjecting = methodReturn.isProjecting()
649+
|| !ObjectUtils.nullSafeEquals(context.getDomainType(), methodReturn.getActualReturnClass())
650+
|| StringUtils.hasText(context.getDynamicProjectionParameterName());
651+
TypeName typeToRead = isProjecting ? methodReturn.getActualTypeName()
664652
: TypeName.get(context.getDomainType());
665653
builder.add("\n");
666654

667-
Class<?> methodReturnType = context.getMethod().getReturnType();
668655
if (modifying.isPresent()) {
669656

670657
if (modifying.getBoolean("flushAutomatically")) {
671658
builder.addStatement("this.$L.flush()", context.fieldNameOf(EntityManager.class));
672659
}
673660

674-
Class<?> returnType = methodReturnType;
661+
Class<?> returnType = methodReturn.toClass();
675662

676663
if (returnsModifying(returnType)) {
677664
builder.addStatement("int $L = $L.executeUpdate()", context.localVariable("result"), queryVariableName);
@@ -694,15 +681,13 @@ public CodeBlock build() {
694681
return builder.build();
695682
}
696683

697-
TypeName queryResultType = TypeName.get(context.getActualReturnType().toClass());
698-
699684
if (aotQuery != null && aotQuery.isDelete()) {
700685

701686
builder.addStatement("$T $L = $L.getResultList()", List.class,
702687
context.localVariable("resultList"), queryVariableName);
703688

704-
boolean returnCount = ClassUtils.isAssignable(Number.class, methodReturnType);
705-
boolean simpleBatch = returnCount || ReflectionUtils.isVoid(methodReturnType);
689+
boolean returnCount = ClassUtils.isAssignable(Number.class, methodReturn.toClass());
690+
boolean simpleBatch = returnCount || methodReturn.isVoid();
706691
boolean collectionQuery = queryMethod.isCollectionQuery();
707692

708693
if (!simpleBatch && !collectionQuery) {
@@ -712,9 +697,6 @@ public CodeBlock build() {
712697
IncorrectResultSizeDataAccessException.class,
713698
"Delete query returned more than one element: expected 1, actual ", context.localVariable("resultList"));
714699
builder.endControlFlow();
715-
716-
builder.addStatement("$L.forEach($L::remove)", context.localVariable("resultList"),
717-
context.fieldNameOf(EntityManager.class));
718700
}
719701

720702
builder.addStatement("$L.forEach($L::remove)", context.localVariable("resultList"),
@@ -724,40 +706,52 @@ public CodeBlock build() {
724706
builder.addStatement("return ($T) $L", List.class, context.localVariable("resultList"));
725707

726708
} else if (returnCount) {
727-
builder.addStatement("return $T.valueOf($L.size())", methodReturnType,
709+
builder.addStatement("return $T.valueOf($L.size())", methodReturn.getActualClassName(),
728710
context.localVariable("resultList"));
729711
} else {
730712

731-
if (Optional.class.isAssignableFrom(methodReturnType)) {
732-
builder.addStatement("return ($1T) $1T.ofNullable($2L.isEmpty() ? null : $2L.iterator().next())",
733-
Optional.class, context.localVariable("resultList"));
734-
} else {
735-
builder.addStatement("return ($1T) ($2L.isEmpty() ? null : $2L.iterator().next())", actualReturnType,
736-
context.localVariable("resultList"));
737-
}
713+
builder.addStatement(LordOfTheStrings.returning(methodReturn.toClass())
714+
.optional("($1T) ($2L.isEmpty() ? null : $2L.iterator().next())", typeToRead,
715+
context.localVariable("resultList")) //
716+
.build());
738717
}
739718
} else if (aotQuery != null && aotQuery.isExists()) {
740719
builder.addStatement("return !$L.getResultList().isEmpty()", queryVariableName);
741720
} else if (aotQuery != null) {
742721

743-
if (context.getReturnedType().isProjecting()) {
722+
if (isProjecting) {
723+
724+
TypeName returnType = TypeNames.typeNameOrWrapper(methodReturn.getActualType());
725+
CodeBlock convertTo;
726+
if (StringUtils.hasText(context.getDynamicProjectionParameterName())) {
727+
convertTo = CodeBlock.of("$L", context.getDynamicProjectionParameterName());
728+
} else {
729+
730+
if (methodReturn.isArray() && methodReturn.getActualType().toClass().equals(byte.class)) {
731+
returnType = TypeName.get(byte[].class);
732+
convertTo = CodeBlock.of("$T.class", returnType);
733+
} else {
734+
convertTo = CodeBlock.of("$T.class", TypeNames.classNameOrWrapper(methodReturn.getActualType()));
735+
}
736+
}
744737

745738
if (queryMethod.isCollectionQuery()) {
746-
builder.addStatement("return ($T) convertMany($L.getResultList(), $L, $T.class)",
747-
context.getReturnTypeName(), queryVariableName, aotQuery.isNative(), queryResultType);
739+
builder.addStatement("return ($T) convertMany($L.getResultList(), $L, $L)", methodReturn.getTypeName(),
740+
queryVariableName, aotQuery.isNative(), convertTo);
748741
} else if (queryMethod.isStreamQuery()) {
749-
builder.addStatement("return ($T) convertMany($L.getResultStream(), $L, $T.class)",
750-
context.getReturnTypeName(), queryVariableName, aotQuery.isNative(), queryResultType);
742+
builder.addStatement("return ($T) convertMany($L.getResultStream(), $L, $L)", methodReturn.getTypeName(),
743+
queryVariableName, aotQuery.isNative(), convertTo);
751744
} else if (queryMethod.isPageQuery()) {
752745
builder.addStatement(
753-
"return $T.getPage(($T<$T>) convertMany($L.getResultList(), $L, $T.class), $L, $L)",
754-
PageableExecutionUtils.class, List.class, actualReturnType, queryVariableName, aotQuery.isNative(),
755-
queryResultType, pageable, context.localVariable("countAll"));
746+
"return $T.getPage(($T<$T>) convertMany($L.getResultList(), $L, $L), $L, $L)",
747+
PageableExecutionUtils.class, List.class, TypeNames.typeNameOrWrapper(methodReturn.getActualType()),
748+
queryVariableName, aotQuery.isNative(), convertTo, pageable, context.localVariable("countAll"));
756749
} else if (queryMethod.isSliceQuery()) {
757-
builder.addStatement("$T<$T> $L = ($T<$T>) convertMany($L.getResultList(), $L, $T.class)", List.class,
758-
actualReturnType, context.localVariable("resultList"), List.class, actualReturnType, queryVariableName,
750+
builder.addStatement("$T<$T> $L = ($T<$T>) convertMany($L.getResultList(), $L, $L)", List.class,
751+
TypeNames.typeNameOrWrapper(methodReturn.getActualType()), context.localVariable("resultList"),
752+
List.class, typeToRead, queryVariableName,
759753
aotQuery.isNative(),
760-
queryResultType);
754+
convertTo);
761755
builder.addStatement("boolean $L = $L.isPaged() && $L.size() > $L.getPageSize()",
762756
context.localVariable("hasNext"), pageable, context.localVariable("resultList"), pageable);
763757
builder.addStatement(
@@ -766,27 +760,24 @@ public CodeBlock build() {
766760
pageable, context.localVariable("resultList"), pageable, context.localVariable("hasNext"));
767761
} else {
768762

769-
if (Optional.class.isAssignableFrom(context.getReturnType().toClass())) {
770-
builder.addStatement("return $T.ofNullable(($T) convertOne($L.getSingleResultOrNull(), $L, $T.class))",
771-
Optional.class, actualReturnType, queryVariableName, aotQuery.isNative(), queryResultType);
772-
} else {
773-
builder.addStatement("return ($T) convertOne($L.getSingleResultOrNull(), $L, $T.class)",
774-
context.getReturnTypeName(), queryVariableName, aotQuery.isNative(), queryResultType);
775-
}
763+
builder.addStatement(LordOfTheStrings.returning(methodReturn.toClass())
764+
.optional("($T) convertOne($L.getSingleResultOrNull(), $L, $L)", returnType, queryVariableName,
765+
aotQuery.isNative(), convertTo) //
766+
.build());
776767
}
777768

778769
} else {
779770

780771
if (queryMethod.isCollectionQuery()) {
781-
builder.addStatement("return ($T) $L.getResultList()", context.getReturnTypeName(), queryVariableName);
772+
builder.addStatement("return ($T) $L.getResultList()", methodReturn.getTypeName(), queryVariableName);
782773
} else if (queryMethod.isStreamQuery()) {
783-
builder.addStatement("return ($T) $L.getResultStream()", context.getReturnTypeName(), queryVariableName);
774+
builder.addStatement("return ($T) $L.getResultStream()", methodReturn.getTypeName(), queryVariableName);
784775
} else if (queryMethod.isPageQuery()) {
785776
builder.addStatement("return $T.getPage(($T<$T>) $L.getResultList(), $L, $L)",
786-
PageableExecutionUtils.class, List.class, actualReturnType, queryVariableName,
777+
PageableExecutionUtils.class, List.class, typeToRead, queryVariableName,
787778
pageable, context.localVariable("countAll"));
788779
} else if (queryMethod.isSliceQuery()) {
789-
builder.addStatement("$T<$T> $L = $L.getResultList()", List.class, actualReturnType,
780+
builder.addStatement("$T<$T> $L = $L.getResultList()", List.class, typeToRead,
790781
context.localVariable("resultList"), queryVariableName);
791782
builder.addStatement("boolean $L = $L.isPaged() && $L.size() > $L.getPageSize()",
792783
context.localVariable("hasNext"), pageable, context.localVariable("resultList"), pageable);
@@ -796,15 +787,11 @@ public CodeBlock build() {
796787
pageable, context.localVariable("resultList"), pageable, context.localVariable("hasNext"));
797788
} else {
798789

799-
if (Optional.class.isAssignableFrom(context.getReturnType().toClass())) {
800-
builder.addStatement("return $T.ofNullable(($T) convertOne($L.getSingleResultOrNull(), $L, $T.class))",
801-
Optional.class, actualReturnType, queryVariableName, aotQuery.isNative(),
802-
queryResultType);
803-
} else {
804-
builder.addStatement("return ($T) convertOne($L.getSingleResultOrNull(), $L, $T.class)",
805-
context.getReturnTypeName(), queryVariableName, aotQuery.isNative(),
806-
context.getReturnType().toClass());
807-
}
790+
builder.addStatement(LordOfTheStrings.returning(methodReturn.toClass())
791+
.optional("($T) convertOne($L.getSingleResultOrNull(), $L, $T.class)",
792+
TypeNames.typeNameOrWrapper(methodReturn.getActualType()), queryVariableName, aotQuery.isNative(),
793+
TypeNames.classNameOrWrapper(methodReturn.getActualType())) //
794+
.build());
808795
}
809796
}
810797
}
@@ -820,4 +807,5 @@ public static boolean returnsModifying(Class<?> returnType) {
820807

821808
}
822809

810+
823811
}

spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/aot/JpaRepositoryContributor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ private Optional<Class<QueryEnhancerSelector>> getQueryEnhancerSelectorClass() {
202202
queryMethod);
203203

204204
// no KeysetScrolling for now.
205-
if (parameters.hasScrollPositionParameter()) {
205+
if (parameters.hasScrollPositionParameter() || queryMethod.isScrollQuery()) {
206206
return MethodContributor.forQueryMethod(queryMethod)
207207
.metadataOnly(aotQueries.toMetadata(queryMethod.isPageQuery()));
208208
}

0 commit comments

Comments
 (0)