diff --git a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/support/SimpleJpaRepository.java b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/support/SimpleJpaRepository.java index 7ee27c1304..80bb93431b 100644 --- a/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/support/SimpleJpaRepository.java +++ b/spring-data-jpa/src/main/java/org/springframework/data/jpa/repository/support/SimpleJpaRepository.java @@ -40,8 +40,6 @@ import java.util.Optional; import java.util.function.BiConsumer; import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.StreamSupport; import org.springframework.data.domain.Example; import org.springframework.data.domain.KeysetScrollPosition; @@ -246,14 +244,8 @@ public void deleteAllByIdInBatch(Iterable ids) { /* * Some JPA providers require {@code ids} to be a {@link Collection} so we must convert if it's not already. */ - - if (ids instanceof Collection) { - query.setParameter("ids", ids); - } else { - Collection idsCollection = StreamSupport.stream(ids.spliterator(), false) - .collect(Collectors.toCollection(ArrayList::new)); - query.setParameter("ids", idsCollection); - } + Collection idCollection = toCollection(ids); + query.setParameter("ids", idCollection); applyQueryHints(query); @@ -414,7 +406,7 @@ public List findAllById(Iterable ids) { return results; } - Collection idCollection = Streamable.of(ids).toList(); + Collection idCollection = toCollection(ids); ByIdsSpecification specification = new ByIdsSpecification<>(entityInformation); TypedQuery query = getQuery(specification, Sort.unsorted()); @@ -918,6 +910,11 @@ private ProjectionFactory getProjectionFactory() { return projectionFactory; } + @SuppressWarnings({ "rawtypes", "unchecked" }) + private static Collection toCollection(Iterable ids) { + return ids instanceof Collection c ? c : Streamable.of(ids).toList(); + } + /** * Executes a count query and transparently sums up all values returned. *