diff --git a/pom.xml b/pom.xml index 13143c9f6f..359521e8c6 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-commons - 4.0.0-SNAPSHOT + 4.0.x-GH-3374-SNAPSHOT Spring Data Core Core Spring concepts underpinning every Spring Data module. diff --git a/src/main/java/org/springframework/data/javapoet/TypeNames.java b/src/main/java/org/springframework/data/javapoet/TypeNames.java index eb9db1a9b9..6112af9e22 100644 --- a/src/main/java/org/springframework/data/javapoet/TypeNames.java +++ b/src/main/java/org/springframework/data/javapoet/TypeNames.java @@ -15,7 +15,12 @@ */ package org.springframework.data.javapoet; +import java.util.Arrays; + import org.springframework.core.ResolvableType; +import org.springframework.javapoet.ArrayTypeName; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.ParameterizedTypeName; import org.springframework.javapoet.TypeName; import org.springframework.util.ClassUtils; @@ -28,6 +33,7 @@ * Mainly for internal use within the framework * * @author Mark Paluch + * @author Christoph Strobl * @since 4.0 */ public abstract class TypeNames { @@ -65,6 +71,42 @@ public static TypeName className(ResolvableType resolvableType) { return TypeName.get(resolvableType.toClass()); } + /** + * Obtain a {@link TypeName} for the underlying type of the given {@link ResolvableType}. Can render a class name, a + * type signature with resolved generics or a generic type variable. + * + * @param resolvableType the resolvable type represent. + * @return the corresponding {@link TypeName}. + */ + public static TypeName resolvedTypeName(ResolvableType resolvableType) { + + if (resolvableType.equals(ResolvableType.NONE)) { + return TypeName.get(Object.class); + } + + if (resolvableType.hasResolvableGenerics()) { + return ParameterizedTypeName.get(ClassName.get(resolvableType.toClass()), + Arrays.stream(resolvableType.getGenerics()).map(TypeNames::resolvedTypeName).toArray(TypeName[]::new)); + } + + if (!resolvableType.hasGenerics()) { + + Class resolvedType = resolvableType.toClass(); + + if (!resolvableType.isArray() || resolvedType.isArray()) { + return TypeName.get(resolvedType); + } + + if (resolvableType.isArray()) { + return ArrayTypeName.of(resolvedType); + } + + return TypeName.get(resolvedType); + } + + return ClassName.get(resolvableType.toClass()); + } + /** * Obtain a {@link TypeName} for the underlying type of the given {@link ResolvableType}. Can render a class name, a * type signature or a generic type variable. @@ -98,7 +140,7 @@ public static TypeName typeNameOrWrapper(Class type) { public static TypeName typeNameOrWrapper(ResolvableType resolvableType) { return ClassUtils.isPrimitiveOrWrapper(resolvableType.toClass()) ? TypeName.get(ClassUtils.resolvePrimitiveIfNecessary(resolvableType.toClass())) - : typeName(resolvableType); + : resolvedTypeName(resolvableType); } private TypeNames() {} diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java index 4534592ba8..24dbf83deb 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryCreator.java @@ -16,11 +16,17 @@ package org.springframework.data.repository.aot.generate; import java.lang.reflect.Method; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.lang.reflect.WildcardType; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.function.Consumer; +import java.util.function.Predicate; import javax.lang.model.element.Modifier; @@ -28,12 +34,14 @@ import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; +import org.springframework.core.ResolvableType; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.support.RepositoryComposition; import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.util.Lazy; +import org.springframework.data.util.TypeInformation; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.FieldSpec; import org.springframework.javapoet.MethodSpec; @@ -290,6 +298,7 @@ private void contributeMethod(Method method, @Nullable MethodContributorFactory MethodContributor contributor = contributorFactory.create(method); if (contributor == null) { + if (logger.isTraceEnabled()) { logger.trace("Skipping method [%s.%s] contribution, no MethodContributor available" .formatted(repositoryInformation.getRepositoryInterface().getName(), method.getName())); @@ -298,11 +307,241 @@ private void contributeMethod(Method method, @Nullable MethodContributorFactory return; } - if (contributor.contributesMethodSpec() && !repositoryInformation.isReactiveRepository()) { - generationMetadata.addRepositoryMethod(method, contributor); - } else { + if (ResolvableGenerics.of(method, repositoryInformation.getRepositoryInterface()).hasUnresolvableGenerics()) { + + if (logger.isTraceEnabled()) { + logger.trace( + "Skipping implementation method [%s.%s] contribution. Method uses generics that currently cannot be resolved." + .formatted(repositoryInformation.getRepositoryInterface().getName(), method.getName())); + } + + generationMetadata.addDelegateMethod(method, contributor); + return; + } + + if (!contributor.contributesMethodSpec() || repositoryInformation.isReactiveRepository()) { + + if (repositoryInformation.isReactiveRepository() && logger.isTraceEnabled()) { + logger.trace( + "Skipping implementation method [%s.%s] contribution. AOT repositories are not supported for reactive repositories." + .formatted(repositoryInformation.getRepositoryInterface().getName(), method.getName())); + } + + if (!contributor.contributesMethodSpec() && logger.isTraceEnabled()) { + logger.trace( + "Skipping implementation method [%s.%s] contribution. Spring Data %s did not provide a method implementation." + .formatted(repositoryInformation.getRepositoryInterface().getName(), method.getName(), moduleName)); + } + generationMetadata.addDelegateMethod(method, contributor); + return; + } + + generationMetadata.addRepositoryMethod(method, contributor); + } + + /** + * Value object to determine whether generics in a given {@link Method} can be resolved. Resolvable generics are e.g. + * declared on the method level (unbounded type variables, type variables using class boundaries). Considers + * collections and map types. + *

+ * Considers resolvable: + *

+ * Considers non-resolvable: + * + *

+ * + * @author Mark Paluch + */ + record ResolvableGenerics(Method method, Class implClass, Set resolvableTypeVariables, + Set unwantedMethodVariables) { + + /** + * Create a new {@code ResolvableGenerics} object for the given {@link Method}. + * + * @param method + * @return + */ + public static ResolvableGenerics of(Method method, Class implClass) { + return new ResolvableGenerics(method, implClass, getResolvableTypeVariables(method), + getUnwantedMethodVariables(method)); + } + + private static Set getResolvableTypeVariables(Method method) { + + Set simpleTypeVariables = new HashSet<>(); + + for (TypeVariable typeParameter : method.getTypeParameters()) { + if (isClassBounded(typeParameter.getBounds())) { + simpleTypeVariables.add(typeParameter); + } + } + + return simpleTypeVariables; + } + + private static Set getUnwantedMethodVariables(Method method) { + + Set unwanted = new HashSet<>(); + + for (TypeVariable typeParameter : method.getTypeParameters()) { + if (!isClassBounded(typeParameter.getBounds())) { + unwanted.add(typeParameter); + } + } + return unwanted; + } + + /** + * Check whether the {@link Method} has unresolvable generics when being considered in the context of the + * implementation class. + * + * @return + */ + public boolean hasUnresolvableGenerics() { + + ResolvableType resolvableType = ResolvableType.forMethodReturnType(method, implClass); + + if (isUnresolvable(resolvableType)) { + return true; + } + + for (int i = 0; i < method.getParameterCount(); i++) { + if (isUnresolvable(ResolvableType.forMethodParameter(method, i, implClass))) { + return true; + } + } + + return false; + } + + private boolean isUnresolvable(TypeInformation typeInformation) { + return isUnresolvable(typeInformation.toResolvableType()); + } + + private boolean isUnresolvable(ResolvableType resolvableType) { + + if (isResolvable(resolvableType)) { + return false; + } + + if (isUnwanted(resolvableType)) { + return true; + } + + if (resolvableType.isAssignableFrom(Class.class)) { + return isUnresolvable(resolvableType.getGeneric(0)); + } + + TypeInformation typeInformation = TypeInformation.of(resolvableType); + if (typeInformation.isMap() || typeInformation.isCollectionLike()) { + + for (ResolvableType type : resolvableType.getGenerics()) { + if (isUnresolvable(type)) { + return true; + } + } + + return false; + } + + if (typeInformation.getActualType() != null && typeInformation.getActualType() != typeInformation) { + return isUnresolvable(typeInformation.getRequiredActualType()); + } + + return resolvableType.hasUnresolvableGenerics(); } + + private boolean isResolvable(Type[] types) { + + for (Type type : types) { + + if (resolvableTypeVariables.contains(type)) { + continue; + } + + if (isClass(type)) { + continue; + } + + return false; + } + + return true; + } + + private boolean isResolvable(ResolvableType resolvableType) { + + return testGenericType(resolvableType, it -> { + + if (resolvableTypeVariables.contains(it)) { + return true; + } + + if (it instanceof WildcardType wt) { + return isClassBounded(wt.getLowerBounds()) && isClassBounded(wt.getUpperBounds()); + } + + return false; + }); + } + + private boolean isUnwanted(ResolvableType resolvableType) { + + return testGenericType(resolvableType, o -> { + + if (o instanceof WildcardType wt) { + return !isResolvable(wt.getLowerBounds()) || !isResolvable(wt.getUpperBounds()); + } + + return unwantedMethodVariables.contains(o); + }); + } + + private static boolean testGenericType(ResolvableType resolvableType, Predicate predicate) { + + if (predicate.test(resolvableType.getType())) { + return true; + } + + ResolvableType[] generics = resolvableType.getGenerics(); + for (ResolvableType generic : generics) { + if (testGenericType(generic, predicate)) { + return true; + } + } + + return false; + } + + private static boolean isClassBounded(Type[] bounds) { + + for (Type bound : bounds) { + + if (isClass(bound)) { + continue; + } + + return false; + } + + return true; + } + + private static boolean isClass(Type type) { + return type instanceof Class; + } + } /** diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java index fe1ca30080..c775c355d5 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryFragmentMetadata.java @@ -27,6 +27,7 @@ import org.jspecify.annotations.Nullable; import org.springframework.core.ResolvableType; +import org.springframework.data.javapoet.TypeNames; import org.springframework.data.repository.core.support.RepositoryFragment; import org.springframework.data.repository.query.QueryMethod; import org.springframework.javapoet.ParameterizedTypeName; @@ -147,26 +148,10 @@ public Map getDelegateMethods() { } static TypeName typeNameOf(ResolvableType type) { + return TypeNames.resolvedTypeName(type); + } - if (type.equals(ResolvableType.NONE)) { - return TypeName.get(Object.class); - } - - if (!type.hasResolvableGenerics()) { - return TypeName.get(type.getType()); - } - - return ParameterizedTypeName.get(type.toClass(), type.resolveGenerics()); - } - - /** - * Constructor argument metadata. - * - * @param parameterName - * @param parameterType - * @param bindToField - */ - public record ConstructorArgument(String parameterName, ResolvableType parameterType, boolean bindToField, + public record ConstructorArgument(String parameterName, ResolvableType parameterType, boolean bindToField, AotRepositoryConstructorBuilder.ParameterOrigin parameterOrigin) { boolean isBoundToField() { diff --git a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilder.java b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilder.java index 49e6bce7c6..e4ae9c13d8 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilder.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilder.java @@ -24,6 +24,7 @@ import javax.lang.model.element.Modifier; +import org.springframework.data.javapoet.TypeNames; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.MethodSpec; import org.springframework.javapoet.ParameterSpec; @@ -101,7 +102,7 @@ public MethodSpec buildMethod() { private MethodSpec.Builder initializeMethodBuilder() { MethodSpec.Builder builder = MethodSpec.methodBuilder(context.getMethod().getName()).addModifiers(Modifier.PUBLIC); - builder.returns(TypeName.get(context.getReturnType().getType())); + builder.returns(TypeNames.resolvedTypeName(context.getTargetMethodMetadata().getReturnType())); TypeVariable[] tvs = context.getMethod().getTypeParameters(); for (TypeVariable tv : tvs) { diff --git a/src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java b/src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java index febbc6d1ac..cebc6013ea 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/MethodMetadata.java @@ -31,6 +31,7 @@ import org.springframework.core.MethodParameter; import org.springframework.core.ParameterNameDiscoverer; import org.springframework.core.ResolvableType; +import org.springframework.data.javapoet.TypeNames; import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.util.TypeInformation; import org.springframework.javapoet.ParameterSpec; @@ -80,11 +81,11 @@ private static void initializeMethodArguments(Method method, ParameterNameDiscov for (Parameter parameter : method.getParameters()) { - MethodParameter methodParameter = MethodParameter.forParameter(parameter); + MethodParameter methodParameter = MethodParameter.forParameter(parameter).withContainingClass(repositoryInterface.resolve()); methodParameter.initParameterNameDiscovery(nameDiscoverer); - ResolvableType resolvableParameterType = ResolvableType.forMethodParameter(methodParameter, repositoryInterface); + ResolvableType resolvableParameterType = ResolvableType.forMethodParameter(methodParameter); - TypeName parameterType = TypeName.get(resolvableParameterType.getType()); + TypeName parameterType = TypeNames.resolvedTypeName(resolvableParameterType); ParameterSpec parameterSpec = ParameterSpec.builder(parameterType, methodParameter.getParameterName()).build(); diff --git a/src/main/java/org/springframework/data/repository/aot/generate/MethodReturn.java b/src/main/java/org/springframework/data/repository/aot/generate/MethodReturn.java index c8b4c4f296..2e3b3098e3 100644 --- a/src/main/java/org/springframework/data/repository/aot/generate/MethodReturn.java +++ b/src/main/java/org/springframework/data/repository/aot/generate/MethodReturn.java @@ -61,7 +61,7 @@ public MethodReturn(ReturnedType returnedType, ResolvableType returnType) { this.returnedType = returnedType; this.returnType = returnType; - this.typeName = TypeNames.typeName(returnType); + this.typeName = TypeNames.resolvedTypeName(returnType); this.className = TypeNames.className(returnType); Class returnClass = returnType.toClass(); @@ -72,7 +72,7 @@ public MethodReturn(ReturnedType returnedType, ResolvableType returnType) { if (actualType != null) { this.actualType = actualType.toResolvableType(); - this.actualTypeName = TypeNames.typeName(this.actualType); + this.actualTypeName = TypeNames.resolvedTypeName(this.actualType); this.actualClassName = TypeNames.className(this.actualType); this.actualReturnClass = actualType.getType(); } else { diff --git a/src/test/java/org/springframework/data/javapoet/TypeNamesUnitTests.java b/src/test/java/org/springframework/data/javapoet/TypeNamesUnitTests.java index 74a28f90c9..af78cd1053 100644 --- a/src/test/java/org/springframework/data/javapoet/TypeNamesUnitTests.java +++ b/src/test/java/org/springframework/data/javapoet/TypeNamesUnitTests.java @@ -15,20 +15,35 @@ */ package org.springframework.data.javapoet; -import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat; +import static org.assertj.core.api.AssertionsForInterfaceTypes.*; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; import java.util.Set; import java.util.stream.Stream; +import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoResult; +import org.springframework.data.geo.Point; +import org.springframework.javapoet.ClassName; import org.springframework.javapoet.ParameterizedTypeName; import org.springframework.javapoet.TypeName; +import org.springframework.javapoet.TypeVariableName; +import org.springframework.util.ReflectionUtils; /** + * Tests for {@link TypeNames}. + * * @author Christoph Strobl + * @author Mark Paluch */ class TypeNamesUnitTests { @@ -59,4 +74,125 @@ void classNames(ResolvableType resolvableType, TypeName expected) { assertThat(TypeNames.className(resolvableType)).isEqualTo(expected); } + @Test // GH-3374 + void resolvedTypeNamesWithoutGenerics() { + + ResolvableType resolvableType = ResolvableType.forClass(List.class); + assertThat(TypeNames.resolvedTypeName(resolvableType)).extracting(TypeName::toString).isEqualTo("java.util.List"); + } + + static List concreteMethods() { + + List methods = new ArrayList<>(); + + ReflectionUtils.doWithMethods(Concrete.class, method -> { + if (!method.getName().contains("baseMethod")) { + return; + } + methods.add(method); + }); + + return methods; + } + + static List otherMethods() { + + List methods = new ArrayList<>(); + + ReflectionUtils.doWithMethods(Concrete.class, method -> { + if (!method.getName().contains("otherMethod")) { + return; + } + methods.add(method); + }); + + return methods; + } + + @ParameterizedTest // GH-3374 + @MethodSource("concreteMethods") + void resolvedTypeNamesForMethodParameters(Method method) { + + MethodParameter refiedObjectMethodParameter = new MethodParameter(method, 0).withContainingClass(Concrete.class); + ResolvableType resolvedObjectParameterType = ResolvableType.forMethodParameter(refiedObjectMethodParameter); + assertThat(TypeNames.typeName(resolvedObjectParameterType)).isEqualTo(TypeVariableName.get("T")); + assertThat(TypeNames.resolvedTypeName(resolvedObjectParameterType)).isEqualTo(TypeName.get(MyType.class)); + + MethodParameter refiedCollectionMethodParameter = new MethodParameter(method, 1) + .withContainingClass(Concrete.class); + ResolvableType resolvedCollectionParameterType = ResolvableType.forMethodParameter(refiedCollectionMethodParameter); + assertThat(TypeNames.typeName(resolvedCollectionParameterType)) + .isEqualTo(ParameterizedTypeName.get(ClassName.get(java.util.List.class), TypeVariableName.get("T"))); + assertThat(TypeNames.resolvedTypeName(resolvedCollectionParameterType)) + .isEqualTo(ParameterizedTypeName.get(java.util.List.class, MyType.class)); + + MethodParameter refiedArrayMethodParameter = new MethodParameter(method, 2).withContainingClass(Concrete.class); + ResolvableType resolvedArrayParameterType = ResolvableType.forMethodParameter(refiedArrayMethodParameter); + assertThat(TypeNames.typeName(resolvedArrayParameterType)).extracting(TypeName::toString).isEqualTo("T[]"); + assertThat(TypeNames.resolvedTypeName(resolvedArrayParameterType)).extracting(TypeName::toString) + .isEqualTo("org.springframework.data.javapoet.TypeNamesUnitTests.MyType[]"); + + ResolvableType resolvedReturnType = ResolvableType.forMethodReturnType(method, Concrete.class); + assertThat(TypeNames.typeName(resolvedReturnType)) + .isEqualTo(ParameterizedTypeName.get(ClassName.get(java.util.List.class), TypeVariableName.get("T"))); + assertThat(TypeNames.resolvedTypeName(resolvedReturnType)) + .isEqualTo(ParameterizedTypeName.get(java.util.List.class, MyType.class)); + + } + + @ParameterizedTest // GH-3374 + @MethodSource("otherMethods") + void resolvedTypeNamesForOtherMethodParameters(Method method) { + + MethodParameter refiedObjectMethodParameter = new MethodParameter(method, 0).withContainingClass(Concrete.class); + ResolvableType resolvedObjectParameterType = ResolvableType.forMethodParameter(refiedObjectMethodParameter); + assertThat(TypeNames.typeName(resolvedObjectParameterType)).isEqualTo(TypeVariableName.get("RT")); + assertThat(TypeNames.resolvedTypeName(resolvedObjectParameterType)).isEqualTo(TypeName.get(Object.class)); + + MethodParameter refiedCollectionMethodParameter = new MethodParameter(method, 1) + .withContainingClass(Concrete.class); + ResolvableType resolvedCollectionParameterType = ResolvableType.forMethodParameter(refiedCollectionMethodParameter); + assertThat(TypeNames.typeName(resolvedCollectionParameterType)) + .isEqualTo(ParameterizedTypeName.get(ClassName.get(java.util.List.class), TypeVariableName.get("RT"))); + assertThat(TypeNames.resolvedTypeName(resolvedCollectionParameterType)) + .isEqualTo(ClassName.get(java.util.List.class)); + + MethodParameter refiedArrayMethodParameter = new MethodParameter(method, 2).withContainingClass(Concrete.class); + ResolvableType resolvedArrayParameterType = ResolvableType.forMethodParameter(refiedArrayMethodParameter); + assertThat(TypeNames.typeName(resolvedArrayParameterType)).extracting(TypeName::toString).isEqualTo("RT[]"); + assertThat(TypeNames.resolvedTypeName(resolvedArrayParameterType)).extracting(TypeName::toString) + .isEqualTo("java.lang.Object[]"); + + ResolvableType resolvedReturnType = ResolvableType.forMethodReturnType(method, Concrete.class); + assertThat(TypeNames.typeName(resolvedReturnType)).extracting(TypeName::toString).isEqualTo("RT"); + assertThat(TypeNames.resolvedTypeName(resolvedReturnType)).isEqualTo(TypeName.get(Object.class)); + } + + @Test // GH-3374 + void resolvesTypeNamesForMethodParameters() throws NoSuchMethodException { + + Method method = Concrete.class.getDeclaredMethod("findByLocationNear", Point.class, Distance.class); + + ResolvableType resolvedReturnType = ResolvableType.forMethodReturnType(method, Concrete.class); + + assertThat(TypeNames.typeName(resolvedReturnType)).extracting(TypeName::toString).isEqualTo( + "java.util.List>"); + assertThat(TypeNames.resolvedTypeName(resolvedReturnType)).isEqualTo(ParameterizedTypeName + .get(ClassName.get(java.util.List.class), ParameterizedTypeName.get(GeoResult.class, MyType.class))); + } + + interface GenericBase { + + java.util.List baseMethod(T arg0, java.util.List arg1, T... arg2); + + RT otherMethod(RT arg0, java.util.List arg1, RT... arg2); + } + + interface Concrete extends GenericBase { + + List> findByLocationNear(Point point, Distance maxDistance); + } + + static class MyType {} + } diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryCreatorUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryCreatorUnitTests.java index 5ed0352f7d..393e05bcc4 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryCreatorUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryCreatorUnitTests.java @@ -16,22 +16,30 @@ package org.springframework.data.repository.aot.generate; import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assumptions.*; import static org.mockito.Mockito.*; import example.UserRepository.User; +import java.lang.reflect.Method; +import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.TimeZone; +import java.util.stream.Stream; import javax.lang.model.element.Modifier; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Answers; import org.springframework.aot.generate.Generated; import org.springframework.aot.hint.TypeReference; import org.springframework.core.ResolvableType; +import org.springframework.data.domain.Range; import org.springframework.data.geo.Metric; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.querydsl.QuerydslPredicateExecutor; @@ -40,6 +48,7 @@ import org.springframework.data.repository.core.RepositoryInformation; import org.springframework.data.repository.core.support.AnnotationRepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryFragment; +import org.springframework.data.repository.query.DefaultParameters; import org.springframework.data.repository.query.QueryMethod; import org.springframework.javapoet.ClassName; import org.springframework.javapoet.JavaFile; @@ -64,7 +73,8 @@ void beforeEach() { doReturn(UserRepository.class).when(repositoryInformation).getRepositoryInterface(); } - @Test // GH-3279 + @Test + // GH-3279 void writesClassSkeleton() { AotRepositoryCreator repositoryCreator = AotRepositoryCreator.forRepository(repositoryInformation, "Commons", @@ -77,7 +87,8 @@ void writesClassSkeleton() { .contains("public UserRepositoryImpl"); // default constructor if not arguments to wire } - @Test // GH-3279 + @Test + // GH-3279 void appliesCtorArguments() { AotRepositoryCreator repositoryCreator = AotRepositoryCreator.forRepository(repositoryInformation, "Commons", @@ -101,7 +112,8 @@ void appliesCtorArguments() { .doesNotContain("this.ctorScoped = ctorScoped"); } - @Test // GH-3279 + @Test + // GH-3279 void appliesCtorCodeBlock() { AotRepositoryCreator repositoryCreator = AotRepositoryCreator.forRepository(repositoryInformation, "Commons", @@ -115,7 +127,8 @@ void appliesCtorCodeBlock() { "UserRepositoryImpl() { throw new IllegalStateException(\"initialization error\"); }"); } - @Test // GH-3279 + @Test + // GH-3279 void appliesClassCustomizations() { AotRepositoryCreator repositoryCreator = AotRepositoryCreator.forRepository(repositoryInformation, "Commons", @@ -143,7 +156,8 @@ void appliesClassCustomizations() { .containsIgnoringWhitespaces("void oops() { }"); } - @Test // GH-3279 + @Test + // GH-3279 void appliesQueryMethodContributor() { AotRepositoryInformation repositoryInformation = new AotRepositoryInformation( @@ -172,15 +186,16 @@ public boolean contributesMethodSpec() { .containsIgnoringWhitespaces("void oops() { }"); } - @Test // GH-3279 + @Test + // GH-3279 void shouldContributeFragmentImplementationMetadata() { AotRepositoryInformation repositoryInformation = new AotRepositoryInformation( AnnotationRepositoryMetadata.getMetadata(QuerydslUserRepository.class), CrudRepository.class, List.of(RepositoryFragment.structural(QuerydslPredicateExecutor.class, DummyQuerydslPredicateExecutor.class))); - AotRepositoryCreator creator = AotRepositoryCreator - .forRepository(repositoryInformation, "Commons", new SpelAwareProxyProjectionFactory()); + AotRepositoryCreator creator = AotRepositoryCreator.forRepository(repositoryInformation, "Commons", + new SpelAwareProxyProjectionFactory()); creator.contributeMethods(method -> null); AotRepositoryCreator.AotBundle bundle = doCreate(creator); @@ -192,7 +207,8 @@ void shouldContributeFragmentImplementationMetadata() { assertThat(method.fragment().implementation()).isEqualTo(DummyQuerydslPredicateExecutor.class.getName()); } - @Test // GH-3339 + @Test + // GH-3339 void usesTargetTypeName() { AotRepositoryCreator repositoryCreator = AotRepositoryCreator.forRepository(repositoryInformation, "Commons", @@ -210,7 +226,8 @@ void usesTargetTypeName() { .contains("public %s(Metric param1, String param2, Object ctorScoped)".formatted(targetType.getSimpleName())); } - @Test // GH-3339 + @Test + // GH-3339 void usesGenericConstructorArguments() { AotRepositoryCreator repositoryCreator = AotRepositoryCreator.forRepository(repositoryInformation, "Commons", @@ -229,6 +246,87 @@ void usesGenericConstructorArguments() { "public %s(List param1, String param2, Object ctorScoped)".formatted(targetType.getSimpleName())); } + @Test + // GH-3374 + void skipsMethodWithUnresolvableGenericReturnType() { + + SpelAwareProxyProjectionFactory spelAwareProxyProjectionFactory = new SpelAwareProxyProjectionFactory(); + AotRepositoryInformation repositoryInformation = new AotRepositoryInformation( + AnnotationRepositoryMetadata.getMetadata(UserRepository.class), CrudRepository.class, + List.of(RepositoryFragment.structural(QuerydslPredicateExecutor.class, DummyQuerydslPredicateExecutor.class))); + + AotRepositoryCreator repositoryCreator = AotRepositoryCreator.forRepository(repositoryInformation, "Commons", + spelAwareProxyProjectionFactory); + repositoryCreator.contributeMethods(method -> { + + QueryMethod queryMethod = new QueryMethod(method, repositoryInformation, spelAwareProxyProjectionFactory, + DefaultParameters::new); + return new MethodContributor<>(queryMethod, Map::of) { + @Override + public MethodSpec contribute(AotQueryMethodGenerationContext context) { + return MethodSpec.methodBuilder(context.getMethod().getName()).addCode("// 1 = 1").build(); + } + + @Override + public boolean contributesMethodSpec() { + return true; + } + }; + + }); + + // same package as source repo + String generated = generate(repositoryCreator); + + assertThat(generated).contains("someMethod").contains("findByFirstname").contains("project1ByFirstname") + .contains("project2ByFirstname").contains("geoQuery").contains("rangeQuery"); + + assertThat(generated).doesNotContain("baseProjection").doesNotContain("upperBoundedProjection") + .doesNotContain("lowerBoundedProjection()"); + } + + static Stream declaredUserRepositoryMethods() { + return Arrays.stream(UserRepository.class.getDeclaredMethods()); + } + + static Stream unresolvedRepositoryMethods() { + return Arrays.stream(UserRepository.class.getMethods()) + .filter(it -> it.getDeclaringClass().equals(BaseRepository.class)) + .filter(it -> it.getName().startsWith("upper") || it.getName().startsWith("lower")); + } + + static Stream resolvedRepositoryMethods() { + return Arrays.stream(UserRepository.class.getMethods()) + .filter(it -> it.getDeclaringClass().equals(BaseRepository.class)) + .filter(it -> it.getName().startsWith("parametrized")); + } + + @ParameterizedTest + @MethodSource("declaredUserRepositoryMethods") + void shouldResolveGenerics(Method method) { + + assertThat(AotRepositoryCreator.ResolvableGenerics.of(method, UserRepository.class).hasUnresolvableGenerics()) + .isFalse(); + } + + @ParameterizedTest + @MethodSource("resolvedRepositoryMethods") + void shouldResolveInterfaceGenerics(Method method) { + + assertThat(AotRepositoryCreator.ResolvableGenerics.of(method, UserRepository.class).hasUnresolvableGenerics()) + .isFalse(); + } + + @ParameterizedTest + @MethodSource("unresolvedRepositoryMethods") + void shouldReportUnresolvedGenerics(Method method) { + + assumeThat(method.getDeclaringClass()).isEqualTo(BaseRepository.class); + + assertThat(AotRepositoryCreator.ResolvableGenerics.of(method, UserRepository.class).hasUnresolvableGenerics()) + .isTrue(); + } + private AotRepositoryCreator.AotBundle doCreate(AotRepositoryCreator creator) { return creator.create(getTypeSpecBuilder(creator)); } @@ -260,11 +358,36 @@ private static TypeSpec.Builder getTypeSpecBuilder(ClassName className) { return TypeSpec.classBuilder(className).addAnnotation(Generated.class); } - interface UserRepository extends org.springframework.data.repository.Repository { + interface BaseRepository extends org.springframework.data.repository.Repository { + +

List

upperBoundedProjection(String firstname, Class type); + + List lowerBoundedProjection(String firstname, Class type); + + List parametrizedListProjection(String firstname, Class type); + + T parametrizedSelection(String firstname); + + } + + interface UserRepository extends BaseRepository { String someMethod(); + + List findByFirstname(String firstname, Class type); + + List project1ByFirstname(String firstname, Class type); + + List project2ByFirstname(String firstname, Class type); + + List geoQuery(GeoJson geoJson); + + List rangeQuery(Range geoJson); + } + public interface GeoJson> {} + interface QuerydslUserRepository extends org.springframework.data.repository.Repository, QuerydslPredicateExecutor { diff --git a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java index d2d2510d39..9c8abafdfc 100644 --- a/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java +++ b/src/test/java/org/springframework/data/repository/aot/generate/AotRepositoryMethodBuilderUnitTests.java @@ -60,7 +60,6 @@ void generatesMethodSkeletonBasedOnGenerationMetadata() throws NoSuchMethodExcep Method method = UserRepository.class.getMethod("findByFirstname", String.class); when(methodGenerationContext.getMethod()).thenReturn(method); - when(methodGenerationContext.getReturnType()).thenReturn(ResolvableType.forClass(User.class)); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any()); MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); @@ -76,11 +75,10 @@ void generatesMethodWithGenerics() throws NoSuchMethodException { Method method = UserRepository.class.getMethod("findByFirstnameIn", List.class); when(methodGenerationContext.getMethod()).thenReturn(method); - when(methodGenerationContext.getReturnType()) - .thenReturn(ResolvableType.forClassWithGenerics(List.class, User.class)); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any()); - MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method); + MethodMetadata methodMetadata = spy(new MethodMetadata(repositoryInformation, method)); + when(methodMetadata.getReturnType()).thenReturn(ResolvableType.forClassWithGenerics(List.class, User.class)); when(methodGenerationContext.getTargetMethodMetadata()).thenReturn(methodMetadata); AotRepositoryMethodBuilder builder = new AotRepositoryMethodBuilder(methodGenerationContext); @@ -95,7 +93,6 @@ void generatesExpressionMarkerIfInUse(ExpressionMarker expressionMarker) throws Method method = UserRepository.class.getMethod("findByFirstname", String.class); when(methodGenerationContext.getMethod()).thenReturn(method); - when(methodGenerationContext.getReturnType()).thenReturn(ResolvableType.forClass(User.class)); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnType(any()); doReturn(TypeInformation.of(User.class)).when(repositoryInformation).getReturnedDomainTypeInformation(any()); MethodMetadata methodMetadata = new MethodMetadata(repositoryInformation, method);