diff --git a/pom.xml b/pom.xml index 98f945a6b2..0b2a3448c3 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-relational-parent - 4.0.0-SNAPSHOT + 4.0.0-2011-embedded-SNAPSHOT pom Spring Data Relational Parent diff --git a/spring-data-jdbc-distribution/pom.xml b/spring-data-jdbc-distribution/pom.xml index b3c39e64c3..6e8f165747 100644 --- a/spring-data-jdbc-distribution/pom.xml +++ b/spring-data-jdbc-distribution/pom.xml @@ -14,7 +14,7 @@ org.springframework.data spring-data-relational-parent - 4.0.0-SNAPSHOT + 4.0.0-2011-embedded-SNAPSHOT ../pom.xml diff --git a/spring-data-jdbc/pom.xml b/spring-data-jdbc/pom.xml index e61fd64020..5a9c5eb8e9 100644 --- a/spring-data-jdbc/pom.xml +++ b/spring-data-jdbc/pom.xml @@ -6,7 +6,7 @@ 4.0.0 spring-data-jdbc - 4.0.0-SNAPSHOT + 4.0.0-2011-embedded-SNAPSHOT Spring Data JDBC Spring Data module for JDBC repositories. @@ -15,7 +15,7 @@ org.springframework.data spring-data-relational-parent - 4.0.0-SNAPSHOT + 4.0.0-2011-embedded-SNAPSHOT diff --git a/spring-data-r2dbc/pom.xml b/spring-data-r2dbc/pom.xml index 64ff1ebcb3..3be76bc14d 100644 --- a/spring-data-r2dbc/pom.xml +++ b/spring-data-r2dbc/pom.xml @@ -6,7 +6,7 @@ 4.0.0 spring-data-r2dbc - 4.0.0-SNAPSHOT + 4.0.0-2011-embedded-SNAPSHOT Spring Data R2DBC Spring Data module for R2DBC @@ -15,7 +15,7 @@ org.springframework.data spring-data-relational-parent - 4.0.0-SNAPSHOT + 4.0.0-2011-embedded-SNAPSHOT diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java index 82f96e1e30..a0198bbee0 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java @@ -57,6 +57,7 @@ * * @author Mark Paluch * @author Oliver Drotbohm + * @author Jens Schauder */ public class MappingR2dbcConverter extends MappingRelationalConverter implements R2dbcConverter { @@ -189,8 +190,17 @@ private void writeInternal(Object source, OutboundRow sink, Class userClass) writeProperties(sink, entity, propertyAccessor); } + /** + * write the values of the properties of an {@link RelationalPersistentEntity} to an {@link OutboundRow}. + * + * @param sink must not be {@literal null}. + * @param entity must not be {@literal null}. + * @param accessor used for accessing the property values of {@literal entity}. May be {@literal null}. A + * {@literal null} value is used when this is an embedded {@literal null} entity, resulting in all its + * property values to be {@literal null} as well. + */ private void writeProperties(OutboundRow sink, RelationalPersistentEntity entity, - PersistentPropertyAccessor accessor) { + @Nullable PersistentPropertyAccessor accessor) { for (RelationalPersistentProperty property : entity) { @@ -200,11 +210,27 @@ private void writeProperties(OutboundRow sink, RelationalPersistentEntity ent Object value; - if (property.isIdProperty()) { - IdentifierAccessor identifierAccessor = entity.getIdentifierAccessor(accessor.getBean()); - value = identifierAccessor.getIdentifier(); + if (accessor == null) { + value = null; } else { - value = accessor.getProperty(property); + if (property.isIdProperty()) { + IdentifierAccessor identifierAccessor = entity.getIdentifierAccessor(accessor.getBean()); + value = identifierAccessor.getIdentifier(); + } else { + value = accessor.getProperty(property); + } + } + + if (property.isEmbedded()) { + + RelationalPersistentEntity embeddedEntity = getMappingContext().getRequiredPersistentEntity(property); + PersistentPropertyAccessor embeddedAccessor = null; + if (value != null) { + embeddedAccessor = embeddedEntity.getPropertyAccessor(value); + } + writeProperties(sink, embeddedEntity, embeddedAccessor); + + continue; } if (value == null) { diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategy.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategy.java index d655464e82..c1f918f23e 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategy.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategy.java @@ -43,6 +43,7 @@ import org.springframework.data.relational.core.dialect.ArrayColumns; import org.springframework.data.relational.core.dialect.Dialect; import org.springframework.data.relational.core.dialect.RenderContextFactory; +import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.sql.SqlIdentifier; @@ -66,7 +67,7 @@ public class DefaultReactiveDataAccessStrategy implements ReactiveDataAccessStra private final R2dbcDialect dialect; private final R2dbcConverter converter; private final UpdateMapper updateMapper; - private final MappingContext, ? extends RelationalPersistentProperty> mappingContext; + private final RelationalMappingContext mappingContext; private final StatementMapper statementMapper; private final NamedParameterExpander expander = new NamedParameterExpander(); @@ -119,7 +120,6 @@ public static R2dbcConverter createConverter(R2dbcDialect dialect, Collection * @param dialect the {@link R2dbcDialect} to use. * @param converter must not be {@literal null}. */ - @SuppressWarnings("unchecked") public DefaultReactiveDataAccessStrategy(R2dbcDialect dialect, R2dbcConverter converter) { Assert.notNull(dialect, "Dialect must not be null"); @@ -127,8 +127,7 @@ public DefaultReactiveDataAccessStrategy(R2dbcDialect dialect, R2dbcConverter co this.converter = converter; this.updateMapper = new UpdateMapper(dialect, converter); - this.mappingContext = (MappingContext, ? extends RelationalPersistentProperty>) this.converter - .getMappingContext(); + this.mappingContext = (RelationalMappingContext) this.converter.getMappingContext(); this.dialect = dialect; RenderContextFactory factory = new RenderContextFactory(dialect); @@ -141,13 +140,22 @@ public List getAllColumns(Class entityType) { RelationalPersistentEntity persistentEntity = getPersistentEntity(entityType); + return getAllColumns(persistentEntity); + } + + private List getAllColumns(@Nullable RelationalPersistentEntity persistentEntity) { + if (persistentEntity == null) { return Collections.singletonList(SqlIdentifier.unquoted("*")); } List columnNames = new ArrayList<>(); for (RelationalPersistentProperty property : persistentEntity) { - columnNames.add(property.getColumnName()); + if (property.isEmbedded()) { + columnNames.addAll(getAllColumns(mappingContext.getRequiredPersistentEntity(property))); + } else { + columnNames.add(property.getColumnName()); + } } return columnNames; @@ -159,12 +167,8 @@ public List getIdentifierColumns(Class entityType) { RelationalPersistentEntity persistentEntity = getRequiredPersistentEntity(entityType); List columnNames = new ArrayList<>(); - for (RelationalPersistentProperty property : persistentEntity) { - - if (property.isIdProperty()) { - columnNames.add(property.getColumnName()); - } - } + mappingContext.getAggregatePath(persistentEntity).getTableInfo().idColumnInfos() + .forEach((__, ci) -> columnNames.add(ci.name())); return columnNames; } diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultStatementMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultStatementMapper.java index a7fcf2a13e..24f7f46a5b 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultStatementMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/DefaultStatementMapper.java @@ -143,7 +143,7 @@ protected List getSelectList(SelectSpec selectSpec, @Nullable Relati List mapped = new ArrayList<>(selectList.size()); for (Expression expression : selectList) { - mapped.add(updateMapper.getMappedObject(expression, entity)); + mapped.addAll(updateMapper.getMappedObjects(expression, entity)); } return mapped; diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java index f277c6266f..353c550fdb 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java @@ -23,6 +23,7 @@ import reactor.core.publisher.Mono; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -96,6 +97,7 @@ * @author Robert Heim * @author Sebastian Wieland * @author Mikhail Polivakha + * @author Jens Schauder * @since 1.1 */ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAware, ApplicationContextAware { @@ -350,8 +352,8 @@ > P doSelect(Query query, Class entityClass, SqlIde return (P) ((Flux) result).concatMap(it -> maybeCallAfterConvert(it, tableName)); } - private RowsFetchSpec doSelect(Query query, Class entityType, SqlIdentifier tableName, - Class returnType, Function filterFunction) { + private RowsFetchSpec doSelect(Query query, Class entityType, SqlIdentifier tableName, Class returnType, + Function filterFunction) { StatementMapper statementMapper = dataAccessStrategy.getStatementMapper().forType(entityType); @@ -378,11 +380,8 @@ private RowsFetchSpec doSelect(Query query, Class entityType, SqlIdent PreparedOperation operation = statementMapper.getMappedObject(selectSpec); - return getRowsFetchSpec( - databaseClient.sql(operation).filter(statementFilterFunction.andThen(filterFunction)), - entityType, - returnType - ); + return getRowsFetchSpec(databaseClient.sql(operation).filter(statementFilterFunction.andThen(filterFunction)), + entityType, returnType); } @Override @@ -622,8 +621,13 @@ private Mono doUpdate(T entity, SqlIdentifier tableName) { return maybeCallBeforeSave(entityToUse, outboundRow, tableName) // .flatMap(onBeforeSave -> { - SqlIdentifier idColumn = persistentEntity.getRequiredIdProperty().getColumnName(); - Parameter id = outboundRow.remove(idColumn); + Map idValues = new LinkedHashMap<>(); + List identifierColumns = dataAccessStrategy.getIdentifierColumns(persistentEntity.getType()); + Assert.state(!identifierColumns.isEmpty(), entityToUse + " has no Identifier. Update is not possible."); + + identifierColumns.forEach(sqlIdentifier -> { + idValues.put(sqlIdentifier, outboundRow.remove(sqlIdentifier)); + }); persistentEntity.forEach(p -> { if (p.isInsertOnly()) { @@ -631,7 +635,14 @@ private Mono doUpdate(T entity, SqlIdentifier tableName) { } }); - Criteria criteria = Criteria.where(dataAccessStrategy.toSql(idColumn)).is(id); + Criteria criteria = null; + for (Map.Entry idAndValue : idValues.entrySet()) { + if (criteria == null) { + criteria = Criteria.where(dataAccessStrategy.toSql(idAndValue.getKey())).is(idAndValue.getValue()); + } else { + criteria = criteria.and(dataAccessStrategy.toSql(idAndValue.getKey())).is(idAndValue.getValue()); + } + } if (matchingVersionCriteria != null) { criteria = criteria.and(matchingVersionCriteria); diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java index f6ee60dd02..709e72a90a 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/QueryMapper.java @@ -24,6 +24,8 @@ import org.springframework.data.domain.Sort; import org.springframework.data.mapping.MappingException; +import org.springframework.data.mapping.PersistentProperty; +import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.mapping.PersistentPropertyPath; import org.springframework.data.mapping.PropertyPath; import org.springframework.data.mapping.PropertyReferenceException; @@ -112,22 +114,42 @@ public List getMappedSort(Table table, Sort sort, @Nullable Relati SqlSort.validate(order); - OrderByField simpleOrderByField = createSimpleOrderByField(table, entity, order); - OrderByField orderBy = simpleOrderByField.withNullHandling(order.getNullHandling()); - mappedOrder.add(order.isAscending() ? orderBy.asc() : orderBy.desc()); + List simpleOrderByFields = createSimpleOrderByFields(table, entity, order); + + simpleOrderByFields.forEach(field -> { + + OrderByField orderBy = field.withNullHandling(order.getNullHandling()); + mappedOrder.add(order.isAscending() ? orderBy.asc() : orderBy.desc()); + }); } return mappedOrder; } - private OrderByField createSimpleOrderByField(Table table, RelationalPersistentEntity entity, Sort.Order order) { + private List createSimpleOrderByFields(Table table, @Nullable RelationalPersistentEntity entity, + Sort.Order order) { if (order instanceof SqlSort.SqlOrder sqlOrder && sqlOrder.isUnsafe()) { - return OrderByField.from(Expressions.just(sqlOrder.getProperty())); + return List.of(OrderByField.from(Expressions.just(sqlOrder.getProperty()))); } Field field = createPropertyField(entity, SqlIdentifier.unquoted(order.getProperty()), this.mappingContext); - return OrderByField.from(table.column(field.getMappedColumnName())); + + if (field.isEmbedded() && entity != null) { + + RelationalPersistentEntity embeddedEntity = getMappingContext() + .getRequiredPersistentEntity(field.getRequiredProperty()); + + List fields = new ArrayList<>(); + + for (RelationalPersistentProperty embeddedProperty : embeddedEntity) { + fields.addAll(createSimpleOrderByFields(table, embeddedEntity, order.withProperty(embeddedProperty.getName()))); + } + + return fields; + } + + return List.of(OrderByField.from(table.column(field.getMappedColumnName()))); } /** @@ -140,9 +162,28 @@ private OrderByField createSimpleOrderByField(Table table, RelationalPersistentE */ public Expression getMappedObject(Expression expression, @Nullable RelationalPersistentEntity entity) { + List mappedObjects = getMappedObjects(expression, entity); + + if (mappedObjects.isEmpty()) { + throw new IllegalArgumentException(String.format("Cannot map %s", expression)); + } + + return mappedObjects.get(0); + } + + /** + * Map the {@link Expression} object to apply field name mapping using {@link Class the type to read}. + * + * @param expression must not be {@literal null}. + * @param entity related {@link RelationalPersistentEntity}, can be {@literal null}. + * @return the mapped {@link Expression}s. + * @since 4.0 + */ + public List getMappedObjects(Expression expression, @Nullable RelationalPersistentEntity entity) { + if (entity == null || expression instanceof AsteriskFromTable || expression instanceof Expressions.SimpleExpression) { - return expression; + return List.of(expression); } if (expression instanceof Column column) { @@ -150,8 +191,22 @@ public Expression getMappedObject(Expression expression, @Nullable RelationalPer Field field = createPropertyField(entity, column.getName()); TableLike table = column.getTable(); + if (field.isEmbedded()) { + + RelationalPersistentEntity embeddedEntity = getMappingContext() + .getRequiredPersistentEntity(field.getRequiredProperty()); + + List expressions = new ArrayList<>(); + + for (RelationalPersistentProperty embeddedProperty : embeddedEntity) { + expressions.addAll(getMappedObjects(Column.create(embeddedProperty.getName(), table), embeddedEntity)); + } + + return expressions; + } + Column columnFromTable = table.column(field.getMappedColumnName()); - return column instanceof Aliased ? columnFromTable.as(((Aliased) column).getAlias()) : columnFromTable; + return List.of(column instanceof Aliased ? columnFromTable.as(((Aliased) column).getAlias()) : columnFromTable); } if (expression instanceof SimpleFunction function) { @@ -160,12 +215,12 @@ public Expression getMappedObject(Expression expression, @Nullable RelationalPer List mappedArguments = new ArrayList<>(arguments.size()); for (Expression argument : arguments) { - mappedArguments.add(getMappedObject(argument, entity)); + mappedArguments.addAll(getMappedObjects(argument, entity)); } SimpleFunction mappedFunction = SimpleFunction.create(function.getFunctionName(), mappedArguments); - return function instanceof Aliased ? mappedFunction.as(((Aliased) function).getAlias()) : mappedFunction; + return List.of(function instanceof Aliased ? mappedFunction.as(((Aliased) function).getAlias()) : mappedFunction); } throw new IllegalArgumentException(String.format("Cannot map %s", expression)); @@ -297,6 +352,43 @@ private Condition mapCondition(CriteriaDefinition criteria, MutableBindings bind @Nullable RelationalPersistentEntity entity) { Field propertyField = createPropertyField(entity, criteria.getColumn(), this.mappingContext); + + if (propertyField.isEmbedded() && entity != null) { + + Object value = criteria.getValue(); + + RelationalPersistentEntity embeddedEntity = mappingContext + .getRequiredPersistentEntity(propertyField.getRequiredProperty()); + PersistentPropertyAccessor propertyAccessor = getEmbeddedPropertyAccessor(value, embeddedEntity, + propertyField); + + Condition condition = Conditions.unrestricted(); + + for (RelationalPersistentProperty embeddedProperty : embeddedEntity) { + + Object propertyValue = propertyAccessor.getProperty(embeddedProperty); + + CriteriaWrapper cw = new CriteriaWrapper(criteria) { + + @Override + public SqlIdentifier getColumn() { + return SqlIdentifier.unquoted(embeddedProperty.getName()); + } + + @Nullable + @Override + public Object getValue() { + return propertyValue; + } + }; + + Condition mapped = mapCondition(cw, bindings, table, embeddedEntity); + condition = condition.and(mapped); + } + + return condition; + } + Column column = table.column(propertyField.getMappedColumnName()); TypeInformation actualType = propertyField.getTypeHint().getRequiredActualType(); @@ -321,6 +413,39 @@ private Condition mapCondition(CriteriaDefinition criteria, MutableBindings bind } return createCondition(column, mappedValue, typeHint, bindings, comparator, criteria.isIgnoreCase()); + + } + + static PersistentPropertyAccessor getEmbeddedPropertyAccessor(@Nullable Object value, + RelationalPersistentEntity embeddedEntity, Field propertyField) { + + if (value != null) { + + Class propertyType = embeddedEntity.getType(); + if (!propertyType.isInstance(value)) { + throw new IllegalArgumentException("Value of property " + propertyField.getRequiredProperty().getName() + + " is not an instance of " + embeddedEntity.getType().getName() + " but " + value.getClass().getName()); + } + + return embeddedEntity.getPropertyAccessor(value); + } + + return new PersistentPropertyAccessor<>() { + @Override + public void setProperty(PersistentProperty property, @org.jspecify.annotations.Nullable Object value) { + + } + + @Override + public @org.jspecify.annotations.Nullable Object getProperty(PersistentProperty property) { + return null; + } + + @Override + public Object getBean() { + return null; + } + }; } private Escaper getEscaper(Comparator comparator) { @@ -587,6 +712,25 @@ public SqlIdentifier getMappedColumnName() { public TypeInformation getTypeHint() { return TypeInformation.OBJECT; } + + public boolean isEmbedded() { + return false; + } + + public @org.jspecify.annotations.Nullable RelationalPersistentProperty getProperty() { + return null; + } + + public RelationalPersistentProperty getRequiredProperty() { + + RelationalPersistentProperty property = getProperty(); + + if (property == null) { + throw new IllegalStateException("No property found for field: " + this.name); + } + + return property; + } } /** @@ -633,13 +777,34 @@ protected MetadataBackedField(SqlIdentifier name, RelationalPersistentEntity this.mappingContext = context; this.path = getPath(name.getReference()); - this.property = this.path == null ? property : this.path.getLeafProperty(); + + RelationalPersistentProperty persistentProperty = null; + if (this.path != null) { + + RelationalPersistentEntity currentEntity = entity; + RelationalPersistentProperty currentProperty = null; + for (RelationalPersistentProperty p : path) { + + currentProperty = currentEntity.getPersistentProperty(p.getName()); + + if (currentProperty == null) { + break; + } + + if (currentProperty.isEntity()) { + currentEntity = mappingContext.getRequiredPersistentEntity(currentProperty); + } + } + + persistentProperty = currentProperty; + } + + this.property = persistentProperty; } @Override public SqlIdentifier getMappedColumnName() { - return this.path == null || this.path.getLeafProperty() == null ? super.getMappedColumnName() - : this.path.getLeafProperty().getColumnName(); + return this.property == null ? super.getMappedColumnName() : this.property.getColumnName(); } /** @@ -700,5 +865,91 @@ public TypeInformation getTypeHint() { return this.property.getTypeInformation(); } + + @Override + public boolean isEmbedded() { + return this.property != null && this.property.isEmbedded(); + } + + @Override + public @org.jspecify.annotations.Nullable RelationalPersistentProperty getProperty() { + return this.property; + } + } + + abstract static class CriteriaWrapper extends AbstractCriteria { + + private final CriteriaDefinition delegate; + + public CriteriaWrapper(CriteriaDefinition delegate) { + this.delegate = delegate; + } + + @Nullable + @Override + public Comparator getComparator() { + return delegate.getComparator(); + } + + @Override + public boolean isIgnoreCase() { + return delegate.isIgnoreCase(); + } + } + + abstract static class AbstractCriteria implements CriteriaDefinition { + @Override + public boolean isGroup() { + return false; + } + + @Override + public List getGroup() { + return List.of(); + } + + @Nullable + @Override + public SqlIdentifier getColumn() { + return null; + } + + @Nullable + @Override + public Comparator getComparator() { + return null; + } + + @Nullable + @Override + public Object getValue() { + return null; + } + + @Override + public boolean isIgnoreCase() { + return false; + } + + @Nullable + @Override + public CriteriaDefinition getPrevious() { + return null; + } + + @Override + public boolean hasPrevious() { + return false; + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public Combinator getCombinator() { + return null; + } } } diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java index fb9eec7ed4..4070959e2c 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/query/UpdateMapper.java @@ -16,13 +16,16 @@ package org.springframework.data.r2dbc.query; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; +import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.r2dbc.convert.R2dbcConverter; import org.springframework.data.r2dbc.dialect.R2dbcDialect; import org.springframework.data.relational.core.dialect.Escaper; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.query.Update; import org.springframework.data.relational.core.query.ValueFunction; import org.springframework.data.relational.core.sql.AssignValue; @@ -94,23 +97,42 @@ public BoundAssignments getMappedObject(BindMarkers markers, Map result = new ArrayList<>(); assignments.forEach((column, value) -> { - Assignment assignment = getAssignment(column, value, bindings, table, entity); - result.add(assignment); + result.addAll(getAssignments(column, value, bindings, table, entity)); }); return new BoundAssignments(bindings, result); } - private Assignment getAssignment(SqlIdentifier columnName, Object value, MutableBindings bindings, Table table, - @Nullable RelationalPersistentEntity entity) { + private Collection getAssignments(SqlIdentifier columnName, Object value, MutableBindings bindings, + Table table, @Nullable RelationalPersistentEntity entity) { Field propertyField = createPropertyField(entity, columnName, getMappingContext()); + + if (propertyField.isEmbedded() && entity != null) { + + RelationalPersistentEntity embeddedEntity = getMappingContext() + .getRequiredPersistentEntity(propertyField.getRequiredProperty()); + PersistentPropertyAccessor propertyAccessor = getEmbeddedPropertyAccessor(value, embeddedEntity, + propertyField); + + List assignments = new ArrayList<>(); + + for (RelationalPersistentProperty embeddedProperty : embeddedEntity) { + + Object propertyValue = propertyAccessor.getProperty(embeddedProperty); + + assignments.addAll(getAssignments(SqlIdentifier.unquoted(embeddedProperty.getName()), propertyValue, bindings, + table, embeddedEntity)); + } + + return assignments; + } + Column column = table.column(propertyField.getMappedColumnName()); TypeInformation actualType = propertyField.getTypeHint().getRequiredActualType(); Object mappedValue; Class typeHint; - if (value instanceof Parameter parameter) { mappedValue = convertValue(parameter.getValue(), propertyField.getTypeHint()); @@ -121,7 +143,7 @@ private Assignment getAssignment(SqlIdentifier columnName, Object value, Mutable mappedValue = valueFunction.map(v -> convertValue(v, propertyField.getTypeHint())).apply(Escaper.DEFAULT); if (mappedValue == null) { - return Assignments.value(column, SQL.nullLiteral()); + return List.of(Assignments.value(column, SQL.nullLiteral())); } typeHint = actualType.getType(); @@ -130,13 +152,13 @@ private Assignment getAssignment(SqlIdentifier columnName, Object value, Mutable mappedValue = convertValue(value, propertyField.getTypeHint()); if (mappedValue == null) { - return Assignments.value(column, SQL.nullLiteral()); + return List.of(Assignments.value(column, SQL.nullLiteral())); } typeHint = actualType.getType(); } - return createAssignment(column, mappedValue, typeHint, bindings); + return List.of(createAssignment(column, mappedValue, typeHint, bindings)); } private Assignment createAssignment(Column column, Object value, Class type, MutableBindings bindings) { diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverterUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverterUnitTests.java index 2bed00a42a..5c56f45ccb 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverterUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverterUnitTests.java @@ -43,6 +43,7 @@ import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.mapping.OutboundRow; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; +import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.r2dbc.core.Parameter; @@ -261,6 +262,53 @@ void writeShouldObtainIdFromIdentifierAccessor() { assertThat(row).containsEntry(SqlIdentifier.unquoted("id"), Parameter.from(42L)); } + @Test // GH-2096 + void shouldWriteSingleLevelEmbeddedEntity() { + + Level1 entity = new Level1("root", new Level2("child", 23)); + + OutboundRow row = new OutboundRow(); + converter.write(entity, row); + + assertThat(row).containsExactlyInAnyOrderEntriesOf(Map.of( + SqlIdentifier.unquoted("name"), Parameter.from("root"), + SqlIdentifier.unquoted("level2_name"), Parameter.from("child"), + SqlIdentifier.unquoted("level2_number"), Parameter.from(23) + )); + } + + @Test // GH-2096 + void shouldWriteMultiLevelEmbeddedEntity() { + + WithEmbedded entity = new WithEmbedded(4711L, new Level1("level1", new Level2("child", 23))); + + OutboundRow row = new OutboundRow(); + converter.write(entity, row); + + assertThat(row).containsExactlyInAnyOrderEntriesOf(Map.of( + SqlIdentifier.unquoted("id"), Parameter.from(4711L), + SqlIdentifier.unquoted("level1_name"), Parameter.from("level1"), + SqlIdentifier.unquoted("level1_level2_name"), Parameter.from("child"), + SqlIdentifier.unquoted("level1_level2_number"), Parameter.from(23) + )); + } + + @Test // GH-2096 + void shouldWriteNullEmbeddedEntity() { + + WithEmbedded entity = new WithEmbedded(4711L, null); + + OutboundRow row = new OutboundRow(); + converter.write(entity, row); + + assertThat(row).containsExactlyInAnyOrderEntriesOf(Map.of( + SqlIdentifier.unquoted("id"), Parameter.from(4711L), + SqlIdentifier.unquoted("level1_name"), Parameter.empty(String.class), + SqlIdentifier.unquoted("level1_level2_name"), Parameter.empty(String.class), + SqlIdentifier.unquoted("level1_level2_number"), Parameter.empty(Integer.class) + )); + } + static class Person { @Id String id; String firstname, lastname; @@ -326,6 +374,13 @@ public PersonWithConversions(String id, Map nested, NonMappableE record WithPrimitiveId(@Id long id) { } + record WithEmbedded(@Id long id, @Embedded.Empty(prefix = "level1_") Level1 one){} + + record Level1(String name, @Embedded.Empty(prefix = "level2_") Level2 two) { + + } + record Level2(String name, Integer number){} + static class CustomConversionPerson { String foo; diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategyUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategyUnitTests.java new file mode 100644 index 0000000000..e5db9514da --- /dev/null +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/core/DefaultReactiveDataAccessStrategyUnitTests.java @@ -0,0 +1,65 @@ +package org.springframework.data.r2dbc.core; + +import static org.assertj.core.api.Assertions.*; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import org.springframework.data.annotation.Id; +import org.springframework.data.r2dbc.dialect.H2Dialect; +import org.springframework.data.relational.core.mapping.Embedded; +import org.springframework.data.relational.core.sql.SqlIdentifier; + +/** + * Unit tests for {@link DefaultReactiveDataAccessStrategy}. + * + * @author Jens Schauder + */ +class DefaultReactiveDataAccessStrategyUnitTests { + + DefaultReactiveDataAccessStrategy dataAccessStrategy = new DefaultReactiveDataAccessStrategy(H2Dialect.INSTANCE); + + @ParameterizedTest + @MethodSource("fixtures") + void shouldReportAllColumns(Fixture fixture) { + + List sqlIdentifiers = Arrays.stream(fixture.allColumns()).map(SqlIdentifier::quoted).toList(); + + assertThat(dataAccessStrategy.getAllColumns(fixture.entityType())) + .containsExactlyInAnyOrder(sqlIdentifiers.toArray(new SqlIdentifier[0])); + } + + static Stream fixtures() { + return Stream.of(new Fixture(SimpleEntity.class, "ID", "NAME"), + new Fixture(WithEmbedded.class, "ID", "L1_NAME", "L1_L2_NAME", "L1_L2_NUMBER"), + new Fixture(WithEmbeddedId.class, "ID_NAME", "ID_NUMBER", "NAME")); + } + + record Fixture(Class entityType, String... allColumns) { + + @Override + public String toString() { + return entityType.getSimpleName(); + } + } + + record SimpleEntity(int id, String name) { + } + + record WithEmbedded(int id, @Embedded.Empty(prefix = "L1_") Level1 level1) { + } + + record Level1(String name, @Embedded.Empty(prefix = "L2_") Level2 l2) { + } + + record Level2(String name, Integer number) { + } + + record WithEmbeddedId(@Id @Embedded.Empty(prefix = "ID_") Level2 id, String name) { + } + +} diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/QueryMapperUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/QueryMapperUnitTests.java index cb7ef38a1f..71e2d78350 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/QueryMapperUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/QueryMapperUnitTests.java @@ -26,6 +26,7 @@ import java.util.Objects; import org.junit.jupiter.api.Test; + import org.springframework.core.convert.converter.Converter; import org.springframework.data.domain.Sort; import org.springframework.data.r2dbc.convert.MappingR2dbcConverter; @@ -36,6 +37,7 @@ import org.springframework.data.r2dbc.dialect.R2dbcDialect; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; import org.springframework.data.relational.core.mapping.Column; +import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.query.Criteria; import org.springframework.data.relational.core.sql.Expression; import org.springframework.data.relational.core.sql.Functions; @@ -45,6 +47,7 @@ import org.springframework.r2dbc.core.Parameter; import org.springframework.r2dbc.core.binding.BindMarkersFactory; import org.springframework.r2dbc.core.binding.BindTarget; + import org.testcontainers.shaded.com.fasterxml.jackson.databind.JsonNode; import org.testcontainers.shaded.com.fasterxml.jackson.databind.node.TextNode; @@ -541,12 +544,94 @@ void shouldMapJsonNodeListToString() { assertThat(bindings.getBindings().iterator().next().getValue()).isEqualTo("foo"); } + @Test // GH-2096 + void shouldMapPathToEmbeddable() { + + Criteria criteria = Criteria.where("home").is(new Address(new Country("DE"))); + + BoundCondition bindings = map(criteria, WithEmbeddable.class); + + assertThat(bindings.getCondition()) + .hasToString("withembeddable.home_country_name = ?[$1] AND withembeddable.home_street = ?[$2]"); + } + + @Test // GH-2096 + void shouldMapPathToNestedEmbeddable() { + + Criteria criteria = Criteria.where("home.country").is(new Country("DE")); + + BoundCondition bindings = map(criteria, WithEmbeddable.class); + + assertThat(bindings.getCondition()).hasToString("withembeddable.home_country_name = ?[$1]"); + } + + @Test // GH-2096 + void shouldMapPathIntoEmbeddable() { + + Criteria criteria = Criteria.where("home.country.name").is("DE"); + + BoundCondition bindings = map(criteria, WithEmbeddable.class); + + assertThat(bindings.getCondition()).hasToString("withembeddable.home_country_name = ?[$1]"); + } + + @Test // GH-2096 + void shouldMapSortPathForEmbeddable() { + + List orderByFields = map(Sort.by("home"), WithEmbeddable.class); + + Table table = Table.create("withembeddable"); + assertThat(orderByFields).contains(OrderByField.from(table.column("home_country_name"), Sort.Direction.ASC)) + .contains(OrderByField.from(table.column("home_street"), Sort.Direction.ASC)); + } + + @Test // GH-2096 + void shouldMapSortPathIntoNestedEmbeddable() { + + List orderByFields = map(Sort.by("home.country"), WithEmbeddable.class); + + Table table = Table.create("withembeddable"); + assertThat(orderByFields).contains(OrderByField.from(table.column("home_country_name"), Sort.Direction.ASC)); + } + + @Test // GH-2096 + void shouldMapSortPathIntoEmbeddable() { + + List orderByFields = map(Sort.by("home.country.name"), WithEmbeddable.class); + + Table table = Table.create("withembeddable"); + assertThat(orderByFields).contains(OrderByField.from(table.column("home_country_name"), Sort.Direction.ASC)); + } + + @Test // GH-2096 + void shouldMapSelectionForEmbeddable() { + + Table table = Table.create("my_table").as("my_aliased_table"); + + List mappedObject = mapper.getMappedObjects(table.column("home"), + mapper.getMappingContext().getRequiredPersistentEntity(WithEmbeddable.class)); + + assertThat(mappedObject).extracting(Expression::toString) // + .hasSize(2) // + .contains("my_aliased_table.home_street", "my_aliased_table.home_country_name"); + } + private BoundCondition map(Criteria criteria) { + return map(criteria, Person.class); + } + + private BoundCondition map(Criteria criteria, Class entityType) { BindMarkersFactory markers = BindMarkersFactory.indexed("$", 1); - return mapper.getMappedObject(markers.create(), criteria, Table.create("person"), - mapper.getMappingContext().getRequiredPersistentEntity(Person.class)); + return mapper.getMappedObject(markers.create(), criteria, Table.create(entityType.getSimpleName().toLowerCase()), + mapper.getMappingContext().getRequiredPersistentEntity(entityType)); + } + + private List map(Sort sort, Class entityType) { + + return mapper.getMappedSort(Table.create(entityType.getSimpleName().toLowerCase()), sort, + mapper.getMappingContext().getRequiredPersistentEntity(entityType)); } static class Person { @@ -560,6 +645,32 @@ static class Person { JsonNode jsonNode; } + static class WithEmbeddable { + + @Embedded.Nullable(prefix = "home_") Address home; + + @Embedded.Nullable(prefix = "work_") Address work; + } + + static class Address { + + @Embedded.Nullable(prefix = "country_") Country country; + String street; + + public Address(Country country) { + this.country = country; + } + } + + static class Country { + + String name; + + public Country(String name) { + this.name = name; + } + } + enum MyEnum { ONE, TWO, } diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java index 60100dd713..a3fa731b91 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/query/UpdateMapperUnitTests.java @@ -27,6 +27,7 @@ import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; import org.springframework.data.relational.core.mapping.Column; +import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.query.Update; import org.springframework.data.relational.core.sql.AssignValue; import org.springframework.data.relational.core.sql.Expression; @@ -100,7 +101,8 @@ void shouldMapMultipleFields() { BoundAssignments mapped = map(update); - Map assignments = mapped.getAssignments().stream().map(it -> (AssignValue) it) + Map assignments = mapped.getAssignments().stream() // + .map(it -> (AssignValue) it) // .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); assertThat(update.getAssignments()).hasSize(3); @@ -108,12 +110,78 @@ void shouldMapMultipleFields() { .containsEntry(SqlIdentifier.unquoted("c2"), SQL.bindMarker("$2")); } + @Test // GH-2096 + void shouldMapPathToEmbeddable() { + + Update update = Update.update("home", new Address(new Country("DE"), "foo")); + + BoundAssignments mapped = map(update, WithEmbeddable.class); + + Map assignments = mapped.getAssignments().stream() // + .map(it -> (AssignValue) it) // + .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); + + assertThat(assignments).containsExactlyInAnyOrderEntriesOf( // + Map.of( // + SqlIdentifier.unquoted("home_country_name"), SQL.bindMarker("$1"), // + SqlIdentifier.unquoted("home_street"), SQL.bindMarker("$2")// + )); + + mapped.getBindings().forEach(it -> { + assertThat(it.getValue()).isIn("DE", "foo"); + }); + } + + @Test // GH-2096 + void shouldMapPathToNestedEmbeddable() { + + Update update = Update.update("home.country", new Country("DE")); + + BoundAssignments mapped = map(update, WithEmbeddable.class); + + Map assignments = mapped.getAssignments().stream() // + .map(it -> (AssignValue) it) // + .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); + + assertThat(assignments) // + .hasSize(1) // + .containsEntry(SqlIdentifier.unquoted("home_country_name"), SQL.bindMarker("$1")); + + mapped.getBindings().forEach(it -> { + assertThat(it.getValue()).isEqualTo("DE"); + }); + } + + @Test // GH-2096 + void shouldMapPathIntoEmbeddable() { + + Update update = Update.update("home.country.name", "DE"); + + BoundAssignments mapped = map(update, WithEmbeddable.class); + + Map assignments = mapped.getAssignments().stream()// + .map(it -> (AssignValue) it) // + .collect(Collectors.toMap(k -> k.getColumn().getName(), AssignValue::getValue)); + + assertThat(assignments) // + .hasSize(1) // + .containsEntry(SqlIdentifier.unquoted("home_country_name"), SQL.bindMarker("$1")); + + mapped.getBindings().forEach(it -> { + assertThat(it.getValue()).isEqualTo("DE"); + }); + } + private BoundAssignments map(Update update) { + return map(update, Person.class); + } + + private BoundAssignments map(Update update, Class entityType) { BindMarkersFactory markers = BindMarkersFactory.indexed("$", 1); - return mapper.getMappedObject(markers.create(), update, Table.create("person"), - converter.getMappingContext().getRequiredPersistentEntity(Person.class)); + return mapper.getMappedObject(markers.create(), update, Table.create(entityType.getSimpleName().toLowerCase()), + converter.getMappingContext().getRequiredPersistentEntity(entityType)); } static class Person { @@ -121,4 +189,35 @@ static class Person { String name; @Column("another_name") String alternative; } + + static class WithEmbeddable { + + @Embedded.Nullable(prefix = "home_") Address home; + + @Embedded.Nullable(prefix = "work_") Address work; + } + + static class Address { + + @Embedded.Nullable(prefix = "country_") Country country; + String street; + + public Address(Country country) { + this.country = country; + } + + public Address(Country country, String street) { + this.country = country; + this.street = street; + } + } + + static class Country { + + String name; + + public Country(String name) { + this.name = name; + } + } } diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/CompositeIdRepositoryIntegrationTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/CompositeIdRepositoryIntegrationTests.java index 1193773fdc..e0d05b2abe 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/CompositeIdRepositoryIntegrationTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/CompositeIdRepositoryIntegrationTests.java @@ -18,16 +18,21 @@ import static org.assertj.core.api.Assertions.*; import io.r2dbc.spi.ConnectionFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; import reactor.test.StepVerifier; import java.util.Optional; +import java.util.concurrent.atomic.AtomicInteger; import javax.sql.DataSource; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.reactivestreams.Publisher; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.FilterType; @@ -36,12 +41,14 @@ import org.springframework.data.r2dbc.config.AbstractR2dbcConfiguration; import org.springframework.data.r2dbc.convert.R2dbcCustomConversions; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; +import org.springframework.data.r2dbc.mapping.event.BeforeConvertCallback; import org.springframework.data.r2dbc.repository.config.EnableR2dbcRepositories; import org.springframework.data.r2dbc.testing.H2TestSupport; import org.springframework.data.relational.RelationalManagedTypes; import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.mapping.NamingStrategy; import org.springframework.data.relational.core.mapping.Table; +import org.springframework.data.relational.core.sql.SqlIdentifier; import org.springframework.data.repository.reactive.ReactiveCrudRepository; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.test.context.junit.jupiter.SpringExtension; @@ -76,6 +83,24 @@ public R2dbcMappingContext r2dbcMappingContext(Optional namingSt return context; } + + @Bean + BeforeConvertCallback beforeConvertCallback() { + + return new BeforeConvertCallback<>() { + AtomicInteger counter = new AtomicInteger(); + + @Override + public Publisher onBeforeConvert(WithCompositeId entity, SqlIdentifier table) { + + if (entity.pk == null) { + CompositeId pk = new CompositeId(counter.incrementAndGet(), "generated"); + entity = new WithCompositeId(pk, entity.name); + } + return Mono.just(entity); + } + }; + } } @BeforeEach @@ -117,15 +142,71 @@ protected ConnectionFactory createConnectionFactory() { @Test // GH-574 void findAllById() { + repository.findById(new CompositeId(42, "HBAR")) // .as(StepVerifier::create) // - .consumeNextWith(actual -> { + .assertNext(actual -> { assertThat(actual.name).isEqualTo("Walter"); + assertThat(actual.pk.one).isEqualTo(42); + assertThat(actual.pk.two).isEqualTo("HBAR"); }).verifyComplete(); } - interface WithCompositeIdRepository extends ReactiveCrudRepository { + @Test // GH-2096 + void findByName() { + + repository.findByName("Walter") // + .as(StepVerifier::create) // + .assertNext(actual -> { + assertThat(actual.name).isEqualTo("Walter"); + assertThat(actual.pk.one).isEqualTo(42); + assertThat(actual.pk.two).isEqualTo("HBAR"); + }).verifyComplete(); + } + + @Test // GH-2096 + void insert() { + + repository.save(new WithCompositeId(null, "Jane Margolis"))// + .as(StepVerifier::create) // + .assertNext(actual -> assertThat(actual.pk).isNotNull()).verifyComplete(); + } + + @Test // GH-2096 + void update() { + + insert(); + + repository.findByName("Jane Margolis") // + .map(wci -> new WithCompositeId(wci.pk, "Jane")) // + .flatMap(repository::save) // + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + + // nothing to be found under the old name + repository.findByName("Jane Margolis").as(StepVerifier::create).verifyComplete(); + + // but under the new name + repository.findByName("Jane").as(StepVerifier::create).expectNextCount(1).verifyComplete(); + } + @Test + void delete() { + + insert(); + + repository.findByName("Jane Margolis") // + .flatMap(repository::delete) // + .as(StepVerifier::create) // + .verifyComplete(); + + // nothing to be found under the old name + repository.findByName("Jane Margolis").as(StepVerifier::create).verifyComplete(); + } + + interface WithCompositeIdRepository extends ReactiveCrudRepository { + Flux findByName(String name); } @Table("with_composite_id") diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/H2R2dbcRepositoryEmbeddedIntegrationTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/H2R2dbcRepositoryEmbeddedIntegrationTests.java new file mode 100644 index 0000000000..714ecf6231 --- /dev/null +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/H2R2dbcRepositoryEmbeddedIntegrationTests.java @@ -0,0 +1,177 @@ +/* + * Copyright 2018-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.repository; + +import static org.assertj.core.api.Assertions.*; + +import io.r2dbc.spi.ConnectionFactory; +import reactor.core.publisher.Hooks; +import reactor.test.StepVerifier; + +import java.util.Arrays; +import java.util.Optional; +import java.util.Set; + +import javax.sql.DataSource; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.FilterType; +import org.springframework.dao.DataAccessException; +import org.springframework.data.annotation.Id; +import org.springframework.data.domain.Example; +import org.springframework.data.r2dbc.config.AbstractR2dbcConfiguration; +import org.springframework.data.r2dbc.convert.R2dbcCustomConversions; +import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; +import org.springframework.data.r2dbc.repository.config.EnableR2dbcRepositories; +import org.springframework.data.r2dbc.testing.H2TestSupport; +import org.springframework.data.r2dbc.testing.R2dbcIntegrationTestSupport; +import org.springframework.data.relational.RelationalManagedTypes; +import org.springframework.data.relational.core.mapping.Embedded; +import org.springframework.data.relational.core.mapping.NamingStrategy; +import org.springframework.data.repository.query.ReactiveQueryByExampleExecutor; +import org.springframework.data.repository.reactive.ReactiveCrudRepository; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +/** + * Tests for support of embedded entities. + * + * @author Jens Schauder + * @author Mark Paluch + */ +@ExtendWith(SpringExtension.class) +@ContextConfiguration +class H2R2dbcRepositoryEmbeddedIntegrationTests extends R2dbcIntegrationTestSupport { + + static { + Hooks.onOperatorDebug(); + } + + @Autowired private PersonRepository repository; + protected JdbcTemplate jdbc; + + @Configuration + @EnableR2dbcRepositories(considerNestedRepositories = true, + includeFilters = @ComponentScan.Filter(classes = PersonRepository.class, type = FilterType.ASSIGNABLE_TYPE)) + static class IntegrationTestConfiguration extends AbstractR2dbcConfiguration { + + @Bean + @Override + public ConnectionFactory connectionFactory() { + return H2TestSupport.createConnectionFactory(); + } + + @Override + public R2dbcMappingContext r2dbcMappingContext(Optional namingStrategy, + R2dbcCustomConversions r2dbcCustomConversions, RelationalManagedTypes r2dbcManagedTypes) { + + R2dbcMappingContext context = super.r2dbcMappingContext(namingStrategy, r2dbcCustomConversions, + r2dbcManagedTypes); + context.setForceQuote(false); + + return context; + } + + @Bean + public H2R2dbcRepositoryIntegrationTests.AfterConvertCallbackRecorder afterConvertCallbackRecorder() { + return new H2R2dbcRepositoryIntegrationTests.AfterConvertCallbackRecorder(); + } + } + + @BeforeEach + void before() { + + this.jdbc = createJdbcTemplate(createDataSource()); + + try { + this.jdbc.execute("DROP TABLE person"); + } catch (DataAccessException e) {} + + this.jdbc.execute(getCreateTableStatement()); + } + + /** + * Creates a {@link DataSource} to be used in this test. + * + * @return the {@link DataSource} to be used in this test. + */ + DataSource createDataSource() { + return H2TestSupport.createDataSource(); + } + + String getCreateTableStatement() { + return "create table person(id integer AUTO_INCREMENT PRIMARY KEY, name_first varchar(50), name_last varchar(50))"; + } + + @Test // GH-2096 + void shouldInsertNewItems() { + + Person frodo = new Person(null, new Name("Frodo", "Baggins")); + Person sam = new Person(null, new Name("Sam", "Gamgee")); + + repository.saveAll(Arrays.asList(frodo, sam)) // + .as(StepVerifier::create) // + .expectNextMatches(person -> person.id != null) // + .expectNextMatches(person -> person.id != null) // + .verifyComplete(); + } + + @Test // GH-2096 + void shouldReadNewItems() { + + shouldInsertNewItems(); + + Set firstNames = Set.of("Frodo", "Sam"); + + repository.findAll() // + .as(StepVerifier::create) // + .assertNext(p -> firstNames.contains(p.name.first)) // + .assertNext(p -> firstNames.contains(p.name.first)) // + .verifyComplete(); + } + + @Test // GH-2096 + void shouldFindUsingQueryByExample() { + + shouldInsertNewItems(); + + Person probe = new Person(null, new Name("Frodo", "Baggins")); + + repository.findAll(Example.of(probe)) // + .as(StepVerifier::create) // + .assertNext(p -> assertThat(p.name.first).isEqualTo("Frodo")) // + .verifyComplete(); + } + + interface PersonRepository extends ReactiveCrudRepository, ReactiveQueryByExampleExecutor {} + + record Person(@Id Integer id, @Embedded.Empty(prefix = "name_") Name name) { + + } + + record Name(String first, String last) { + + } + +} diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java index 9450c0b1b6..8dbafba7ff 100644 --- a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/PartTreeR2dbcQueryUnitTests.java @@ -50,6 +50,7 @@ import org.springframework.data.r2dbc.dialect.DialectResolver; import org.springframework.data.r2dbc.dialect.R2dbcDialect; import org.springframework.data.r2dbc.mapping.R2dbcMappingContext; +import org.springframework.data.relational.core.mapping.Embedded; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.Table; import org.springframework.data.relational.core.sql.LockMode; @@ -789,6 +790,46 @@ void createsQueryWithoutIdForExistsProjection() throws Exception { .where(TABLE + ".first_name = $1 LIMIT 1"); } + @Test // GH-2096 + void createsQueryForEmbeddable() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithEmbeddableRepository.class, "findByHome", Address.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, new Address(new Country("DE"))); + + PreparedOperationAssert.assertThat(query) // + .selects("with_embeddable.home_country_name", "with_embeddable.work_country_name") // + .from("with_embeddable") // + .where("with_embeddable.home_country_name = $1"); + } + + @Test // GH-2096 + void createsQueryForNestedEmbeddable() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithEmbeddableRepository.class, "findByHomeCountry", Country.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, new Country("DE")); + + PreparedOperationAssert.assertThat(query) // + .selects("with_embeddable.home_country_name", "with_embeddable.work_country_name") // + .from("with_embeddable") // + .where("with_embeddable.home_country_name = $1"); + } + + @Test // GH-2096 + void createsQueryForNestedEmbeddableValue() throws Exception { + + R2dbcQueryMethod queryMethod = getQueryMethod(WithEmbeddableRepository.class, "findByHomeCountryName", + String.class); + PartTreeR2dbcQuery r2dbcQuery = new PartTreeR2dbcQuery(queryMethod, operations, r2dbcConverter, dataAccessStrategy); + PreparedOperation query = createQuery(queryMethod, r2dbcQuery, "DE"); + + PreparedOperationAssert.assertThat(query) // + .selects("with_embeddable.home_country_name", "with_embeddable.work_country_name") // + .from("with_embeddable") // + .where("with_embeddable.home_country_name = $1"); + } + private PreparedOperation createQuery(R2dbcQueryMethod queryMethod, PartTreeR2dbcQuery r2dbcQuery, Object... parameters) { return createQuery(r2dbcQuery, getAccessor(queryMethod, parameters)); @@ -1001,6 +1042,15 @@ interface WithoutIdRepository extends Repository { Mono countByFirstName(String firstName); } + interface WithEmbeddableRepository extends Repository { + + Mono findByHome(Address home); + + Mono findByHomeCountry(Country homeCountry); + + Mono findByHomeCountryName(String homeCountryName); + } + @Table("users") private static class User { @@ -1038,4 +1088,29 @@ static class UserDtoProjection { String firstName; String unknown; } + + static class WithEmbeddable { + + @Embedded.Nullable(prefix = "home_") Address home; + + @Embedded.Nullable(prefix = "work_") Address work; + } + + static class Address { + + @Embedded.Nullable(prefix = "country_") Country country; + + public Address(Country country) { + this.country = country; + } + } + + static class Country { + + String name; + + public Country(String name) { + this.name = name; + } + } } diff --git a/spring-data-relational/pom.xml b/spring-data-relational/pom.xml index 8fd6d7a6f0..776e865c0f 100644 --- a/spring-data-relational/pom.xml +++ b/spring-data-relational/pom.xml @@ -6,7 +6,7 @@ 4.0.0 spring-data-relational - 4.0.0-SNAPSHOT + 4.0.0-2011-embedded-SNAPSHOT Spring Data Relational Spring Data Relational support @@ -14,7 +14,7 @@ org.springframework.data spring-data-relational-parent - 4.0.0-SNAPSHOT + 4.0.0-2011-embedded-SNAPSHOT