Skip to content

Commit

Permalink
Update @query argument conversion to handle Collection<Enum>.
Browse files Browse the repository at this point in the history
+ Copy logic from QueryMapper#convertToJdbcValue to resolve Iterable
  arguments on findBy* query methods to resolve the same for @query.
+ Use parameter ResolvableType instead of Class to retain generics info.

Original pull request #1226
Closes #1212
  • Loading branch information
ctailor2 authored and schauder committed May 17, 2022
1 parent 95f127f commit f97434d
Show file tree
Hide file tree
Showing 21 changed files with 346 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2020-2021 the original author or authors.
* Copyright 2020-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,16 +19,20 @@

import java.lang.reflect.Constructor;
import java.sql.JDBCType;
import java.util.ArrayList;
import java.util.List;

import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.ResolvableType;
import org.springframework.core.convert.converter.Converter;
import org.springframework.data.jdbc.core.convert.JdbcColumnTypes;
import org.springframework.data.jdbc.core.convert.JdbcConverter;
import org.springframework.data.jdbc.core.convert.JdbcValue;
import org.springframework.data.jdbc.support.JdbcUtil;
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
import org.springframework.data.relational.repository.query.RelationalParameterAccessor;
import org.springframework.data.relational.repository.query.RelationalParameters;
import org.springframework.data.relational.repository.query.RelationalParametersParameterAccessor;
import org.springframework.data.repository.query.Parameter;
import org.springframework.data.repository.query.Parameters;
Expand All @@ -53,6 +57,7 @@
* @author Maciej Walkowiak
* @author Mark Paluch
* @author Hebert Coelho
* @author Chirag Tailor
* @since 2.0
*/
public class StringBasedJdbcQuery extends AbstractJdbcQuery {
Expand Down Expand Up @@ -157,11 +162,34 @@ private void convertAndAddParameter(MapSqlParameterSource parameters, Parameter

String parameterName = p.getName().orElseThrow(() -> new IllegalStateException(PARAMETER_NEEDS_TO_BE_NAMED));

Class<?> parameterType = queryMethod.getParameters().getParameter(p.getIndex()).getType();
Class<?> conversionTargetType = JdbcColumnTypes.INSTANCE.resolvePrimitiveType(parameterType);
RelationalParameters.RelationalParameter parameter = queryMethod.getParameters().getParameter(p.getIndex());
ResolvableType resolvableType = parameter.getResolvableType();
Class<?> type = resolvableType.resolve();
Assert.notNull(type, "@Query parameter could not be resolved!");

JdbcValue jdbcValue = converter.writeJdbcValue(value, conversionTargetType,
JdbcUtil.sqlTypeFor(conversionTargetType));
JdbcValue jdbcValue;
if (value instanceof Iterable) {

List<Object> mapped = new ArrayList<>();
SQLType jdbcType = null;

Class<?> elementType = resolvableType.getGeneric(0).resolve();
Assert.notNull(elementType, "@Query Iterable parameter generic type could not be resolved!");
for (Object o : (Iterable<?>) value) {
JdbcValue elementJdbcValue = converter.writeJdbcValue(o, elementType,
JdbcUtil.targetSqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(elementType)));
if (jdbcType == null) {
jdbcType = elementJdbcValue.getJdbcType();
}

mapped.add(elementJdbcValue.getValue());
}

jdbcValue = JdbcValue.of(mapped, jdbcType);
} else {
jdbcValue = converter.writeJdbcValue(value, type,
JdbcUtil.sqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(type)));
}

JDBCType jdbcType = jdbcValue.getJdbcType();
if (jdbcType == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2019-2021 the original author or authors.
* Copyright 2019-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,13 +16,15 @@
package org.springframework.data.jdbc.repository;

import static java.util.Arrays.*;
import static java.util.Collections.*;
import static org.assertj.core.api.Assertions.*;
import static org.assertj.core.api.SoftAssertions.*;
import static org.springframework.test.context.TestExecutionListeners.MergeMode.*;

import java.math.BigDecimal;
import java.sql.JDBCType;
import java.util.Date;
import java.util.List;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
Expand All @@ -36,6 +38,7 @@
import org.springframework.data.convert.WritingConverter;
import org.springframework.data.jdbc.core.convert.JdbcCustomConversions;
import org.springframework.data.jdbc.core.convert.JdbcValue;
import org.springframework.data.jdbc.repository.query.Query;
import org.springframework.data.jdbc.repository.support.JdbcRepositoryFactory;
import org.springframework.data.jdbc.testing.AssumeFeatureTestExecutionListener;
import org.springframework.data.jdbc.testing.TestConfiguration;
Expand All @@ -50,6 +53,7 @@
*
* @author Jens Schauder
* @author Sanghyuk Jung
* @author Chirag Tailor
*/
@ContextConfiguration
@Transactional
Expand All @@ -69,18 +73,19 @@ Class<?> testClass() {
}

@Bean
EntityWithBooleanRepository repository() {
return factory.getRepository(EntityWithBooleanRepository.class);
EntityWithStringyBigDecimalRepository repository() {
return factory.getRepository(EntityWithStringyBigDecimalRepository.class);
}

@Bean
JdbcCustomConversions jdbcCustomConversions() {
return new JdbcCustomConversions(asList(StringToBigDecimalConverter.INSTANCE, BigDecimalToString.INSTANCE,
CustomIdReadingConverter.INSTANCE, CustomIdWritingConverter.INSTANCE));
CustomIdReadingConverter.INSTANCE, CustomIdWritingConverter.INSTANCE, DirectionToIntegerConverter.INSTANCE,
NumberToDirectionConverter.INSTANCE, IntegerToDirectionConverter.INSTANCE));
}
}

@Autowired EntityWithBooleanRepository repository;
@Autowired EntityWithStringyBigDecimalRepository repository;

/**
* In PostrgreSQL this fails if a simple converter like the following is used.
Expand Down Expand Up @@ -143,13 +148,50 @@ public void saveAndLoadAnEntityWithReference() {
});
}

interface EntityWithBooleanRepository extends CrudRepository<EntityWithStringyBigDecimal, CustomId> {}
@Test // GH-1212
void queryByEnumTypeIn() {

EntityWithStringyBigDecimal entityA = new EntityWithStringyBigDecimal();
entityA.direction = Direction.LEFT;
EntityWithStringyBigDecimal entityB = new EntityWithStringyBigDecimal();
entityB.direction = Direction.CENTER;
EntityWithStringyBigDecimal entityC = new EntityWithStringyBigDecimal();
entityC.direction = Direction.RIGHT;
repository.saveAll(asList(entityA, entityB, entityC));

assertThat(repository.findByEnumTypeIn(asList(Direction.LEFT, Direction.RIGHT)))
.extracting(entity -> entity.direction).containsExactlyInAnyOrder(Direction.LEFT, Direction.RIGHT);
}

@Test // GH-1212
void queryByEnumTypeEqual() {

EntityWithStringyBigDecimal entityA = new EntityWithStringyBigDecimal();
entityA.direction = Direction.LEFT;
EntityWithStringyBigDecimal entityB = new EntityWithStringyBigDecimal();
entityB.direction = Direction.CENTER;
EntityWithStringyBigDecimal entityC = new EntityWithStringyBigDecimal();
entityC.direction = Direction.RIGHT;
repository.saveAll(asList(entityA, entityB, entityC));

assertThat(repository.findByEnumTypeIn(singletonList(Direction.CENTER))).extracting(entity -> entity.direction)
.containsExactly(Direction.CENTER);
}

interface EntityWithStringyBigDecimalRepository extends CrudRepository<EntityWithStringyBigDecimal, CustomId> {
@Query("SELECT * FROM ENTITY_WITH_STRINGY_BIG_DECIMAL WHERE DIRECTION IN (:types)")
List<EntityWithStringyBigDecimal> findByEnumTypeIn(List<Direction> types);

@Query("SELECT * FROM ENTITY_WITH_STRINGY_BIG_DECIMAL WHERE DIRECTION = :type")
List<EntityWithStringyBigDecimal> findByEnumType(Direction type);
}

private static class EntityWithStringyBigDecimal {

@Id CustomId id;
String stringyNumber;
String stringyNumber = "1.0";
OtherEntity reference;
Direction direction = Direction.CENTER;
}

private static class CustomId {
Expand All @@ -167,6 +209,10 @@ private static class OtherEntity {
Date created;
}

enum Direction {
LEFT, CENTER, RIGHT
}

@WritingConverter
enum StringToBigDecimalConverter implements Converter<String, JdbcValue> {

Expand Down Expand Up @@ -214,4 +260,64 @@ public CustomId convert(Number source) {
}
}

@WritingConverter
enum DirectionToIntegerConverter implements Converter<Direction, JdbcValue> {

INSTANCE;

@Override
public JdbcValue convert(Direction source) {

int integer;
switch (source) {
case LEFT:
integer = -1;
break;
case CENTER:
integer = 0;
break;
case RIGHT:
integer = 1;
break;
default:
throw new IllegalArgumentException();
}
return JdbcValue.of(integer, JDBCType.INTEGER);
}
}

@ReadingConverter // Needed for Oracle since the JDBC driver returns BigDecimal on read
enum NumberToDirectionConverter implements Converter<Number, Direction> {

INSTANCE;

@Override
public Direction convert(Number source) {
int sourceAsInt = source.intValue();
if (sourceAsInt == 0) {
return Direction.CENTER;
} else if (sourceAsInt < 0) {
return Direction.LEFT;
} else {
return Direction.RIGHT;
}
}
}

@ReadingConverter
enum IntegerToDirectionConverter implements Converter<Integer, Direction> {

INSTANCE;

@Override
public Direction convert(Integer source) {
if (source == 0) {
return Direction.CENTER;
} else if (source < 0) {
return Direction.LEFT;
} else {
return Direction.RIGHT;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@
import static org.assertj.core.api.SoftAssertions.*;
import static org.springframework.test.context.TestExecutionListeners.MergeMode.*;

import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Value;

import java.io.IOException;
import java.sql.ResultSet;
import java.time.Instant;
Expand All @@ -33,6 +29,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -73,11 +70,16 @@
import org.springframework.test.jdbc.JdbcTestUtils;
import org.springframework.transaction.annotation.Transactional;

import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.Value;

/**
* Very simple use cases for creation and usage of JdbcRepositories.
*
* @author Jens Schauder
* @author Mark Paluch
* @author Chirag Tailor
*/
@Transactional
@TestExecutionListeners(value = AssumeFeatureTestExecutionListener.class, mergeMode = MERGE_WITH_DEFAULTS)
Expand Down Expand Up @@ -565,6 +567,38 @@ void nullStringResult() {
assertThat(repository.returnInput(null)).isNull();
}

@Test // GH-1212
void queryByEnumTypeIn() {

DummyEntity dummyA = new DummyEntity("dummyA");
dummyA.setDirection(Direction.LEFT);
DummyEntity dummyB = new DummyEntity("dummyB");
dummyB.setDirection(Direction.CENTER);
DummyEntity dummyC = new DummyEntity("dummyC");
dummyC.setDirection(Direction.RIGHT);
repository.saveAll(asList(dummyA, dummyB, dummyC));

assertThat(repository.findByEnumTypeIn(asList(Direction.LEFT, Direction.RIGHT)))
.extracting(DummyEntity::getDirection)
.containsExactlyInAnyOrder(Direction.LEFT, Direction.RIGHT);
}

@Test // GH-1212
void queryByEnumTypeEqual() {

DummyEntity dummyA = new DummyEntity("dummyA");
dummyA.setDirection(Direction.LEFT);
DummyEntity dummyB = new DummyEntity("dummyB");
dummyB.setDirection(Direction.CENTER);
DummyEntity dummyC = new DummyEntity("dummyC");
dummyC.setDirection(Direction.RIGHT);
repository.saveAll(asList(dummyA, dummyB, dummyC));

assertThat(repository.findByEnumType(Direction.CENTER))
.extracting(DummyEntity::getDirection)
.containsExactlyInAnyOrder(Direction.CENTER);
}

private Instant createDummyBeforeAndAfterNow() {

Instant now = Instant.now();
Expand Down Expand Up @@ -645,6 +679,12 @@ interface DummyEntityRepository extends CrudRepository<DummyEntity, Long> {
@Query("SELECT CAST(:hello AS CHAR(5)) FROM DUMMY_ENTITY")
@Nullable
String returnInput(@Nullable String hello);

@Query("SELECT * FROM DUMMY_ENTITY WHERE DIRECTION IN (:directions)")
List<DummyEntity> findByEnumTypeIn(List<Direction> directions);

@Query("SELECT * FROM DUMMY_ENTITY WHERE DIRECTION = :direction")
List<DummyEntity> findByEnumType(Direction direction);
}

@Configuration
Expand Down Expand Up @@ -698,12 +738,17 @@ static class DummyEntity {
@Id private Long idProp;
boolean flag;
AggregateReference<DummyEntity, Long> ref;
Direction direction;

public DummyEntity(String name) {
this.name = name;
}
}

enum Direction {
LEFT, CENTER, RIGHT
}

interface DummyProjection {

String getName();
Expand Down
Loading

0 comments on commit f97434d

Please sign in to comment.