Skip to content

Commit ef01847

Browse files
Allow Custom Streamable return type.
1 parent cae6e15 commit ef01847

File tree

3 files changed

+48
-3
lines changed

3 files changed

+48
-3
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.springframework.data.util.Streamable;
3939
import org.springframework.javapoet.CodeBlock;
4040
import org.springframework.javapoet.CodeBlock.Builder;
41+
import org.springframework.util.ClassUtils;
4142
import org.springframework.util.NumberUtils;
4243
import org.springframework.util.StringUtils;
4344

@@ -238,9 +239,25 @@ static void appendReadPreference(AotQueryMethodGenerationContext context, Builde
238239
* {@link MethodReturn} indicates so.
239240
*/
240241
public static CodeBlock potentiallyWrapStreamable(MethodReturn methodReturn, CodeBlock returningIterable) {
241-
return methodReturn.toClass().equals(Streamable.class)
242-
? CodeBlock.of("$T.of($L)", Streamable.class, returningIterable)
243-
: returningIterable;
242+
243+
Class<?> returnType = methodReturn.toClass();
244+
if (returnType.equals(Streamable.class)) {
245+
return CodeBlock.of("$T.of($L)", Streamable.class, returningIterable);
246+
}
247+
if (ClassUtils.isAssignable(Streamable.class, returnType)) {
248+
CodeBlock streamable = CodeBlock.of("$T.of($L)", Streamable.class, returningIterable);
249+
if (ClassUtils.hasConstructor(returnType, Streamable.class)) {
250+
return CodeBlock.of("new $T($L)", returnType, streamable);
251+
}
252+
if (ClassUtils.hasAtLeastOneMethodWithName(returnType, "of")) {
253+
return CodeBlock.of("$T.of($L)", returnType, streamable);
254+
}
255+
if (ClassUtils.hasAtLeastOneMethodWithName(returnType, "valueOf")) {
256+
return CodeBlock.of("$T.valueOf($L)", returnType, streamable);
257+
}
258+
}
259+
260+
return returningIterable;
244261
}
245262

246263
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
import org.springframework.data.mongodb.core.query.Query;
7474
import org.springframework.data.mongodb.core.query.Update;
7575
import org.springframework.data.mongodb.repository.Person.Sex;
76+
import org.springframework.data.mongodb.repository.PersonRepository.Persons;
7677
import org.springframework.data.mongodb.repository.SampleEvaluationContextExtension.SampleSecurityContextHolder;
7778
import org.springframework.data.mongodb.test.util.DirtiesStateExtension;
7879
import org.springframework.data.mongodb.test.util.DirtiesStateExtension.DirtiesState;
@@ -324,6 +325,17 @@ void streamPersonByAddressCorrectly() {
324325
assertThat(result).hasSize(1).contains(dave);
325326
}
326327

328+
@Test // GH-5089
329+
void useCustomReturnTypeImplementingStreamable() {
330+
331+
Address address = new Address("Foo Street 1", "C0123", "Bar");
332+
dave.setAddress(address);
333+
repository.save(dave);
334+
335+
Persons result = repository.streamPersonsByAddress(address);
336+
assertThat(result).hasSize(1).contains(dave);
337+
}
338+
327339
@Test // GH-5089
328340
void streamPersonByAddressCorrectlyWhenPaged() {
329341

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.util.Collection;
1919
import java.util.Date;
20+
import java.util.Iterator;
2021
import java.util.List;
2122
import java.util.Optional;
2223
import java.util.UUID;
@@ -214,6 +215,8 @@ Window<Person> findByLastnameLikeOrderByLastnameAscFirstnameAsc(Pattern lastname
214215

215216
Streamable<Person> streamByAddress(Address address);
216217

218+
Persons streamPersonsByAddress(Address address);
219+
217220
Streamable<Person> streamByAddress(Address address, Pageable pageable);
218221

219222
List<Person> findByAddressZipCode(String zipCode);
@@ -502,4 +505,17 @@ Person findPersonByManyArguments(String firstname, String lastname, String email
502505

503506
List<Person> findBySpiritAnimal(User user);
504507

508+
class Persons implements Streamable<Person> {
509+
510+
private final Streamable<Person> streamable;
511+
512+
public Persons(Streamable<Person> streamable) {
513+
this.streamable = streamable;
514+
}
515+
516+
@Override
517+
public Iterator<Person> iterator() {
518+
return streamable.iterator();
519+
}
520+
}
505521
}

0 commit comments

Comments
 (0)