Skip to content

Commit

Permalink
Ensure conversion/aggregation is only performed once per invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
marcphilipp committed Feb 27, 2025
1 parent d02b53b commit 6de9f4e
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class ContainerTemplateConstructorParameterResolver extends ParameterizedInvocat
private final Class<?> containerTemplateClass;

ContainerTemplateConstructorParameterResolver(ParameterizedContainerClassContext classContext,
EvaluatedArgumentSet arguments, int invocationIndex) {
super(classContext.getResolverFacade(), arguments, invocationIndex);
EvaluatedArgumentSet arguments, int invocationIndex, ResolutionCache resolutionCache) {
super(classContext.getResolverFacade(), arguments, invocationIndex, resolutionCache);
this.containerTemplateClass = classContext.getAnnotatedElement();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,25 @@

class ContainerTemplateInstanceFieldInjectingBeforeEachCallback implements BeforeEachCallback {

private final ParameterizedContainerClassContext classContext;
private final ResolverFacade resolverFacade;
private final EvaluatedArgumentSet arguments;
private final int invocationIndex;
private final ResolutionCache resolutionCache;

ContainerTemplateInstanceFieldInjectingBeforeEachCallback(ParameterizedContainerClassContext classContext,
EvaluatedArgumentSet arguments, int invocationIndex) {
this.classContext = classContext;
ContainerTemplateInstanceFieldInjectingBeforeEachCallback(ResolverFacade resolverFacade,
EvaluatedArgumentSet arguments, int invocationIndex, ResolutionCache resolutionCache) {
this.resolverFacade = resolverFacade;
this.arguments = arguments;
this.invocationIndex = invocationIndex;
this.resolutionCache = resolutionCache;
}

@Override
public void beforeEach(ExtensionContext extensionContext) throws Exception {
extensionContext.getTestInstance() //
.ifPresent(testInstance -> this.classContext.getResolverFacade() //
.resolveAndInjectFields(testInstance, extensionContext, this.arguments, this.invocationIndex));
.ifPresent(testInstance -> this.resolverFacade //
.resolveAndInjectFields(testInstance, extensionContext, this.arguments, this.invocationIndex,
this.resolutionCache));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@

class ContainerTemplateInstanceFieldInjectingPostProcessor implements TestInstancePostProcessor {

private final ParameterizedContainerClassContext classContext;
private final ResolverFacade resolverFacade;
private final EvaluatedArgumentSet arguments;
private final int invocationIndex;
private final ResolutionCache resolutionCache;

ContainerTemplateInstanceFieldInjectingPostProcessor(ParameterizedContainerClassContext classContext,
EvaluatedArgumentSet arguments, int invocationIndex) {
this.classContext = classContext;
ContainerTemplateInstanceFieldInjectingPostProcessor(ResolverFacade resolverFacade, EvaluatedArgumentSet arguments,
int invocationIndex, ResolutionCache resolutionCache) {
this.resolverFacade = resolverFacade;
this.arguments = arguments;
this.invocationIndex = invocationIndex;
this.resolutionCache = resolutionCache;
}

@Override
Expand All @@ -33,8 +35,7 @@ public ExtensionContextScope getTestInstantiationExtensionContextScope(Extension

@Override
public void postProcessTestInstance(Object testInstance, ExtensionContext extensionContext) {
this.classContext.getResolverFacade() //
.resolveAndInjectFields(testInstance, extensionContext, this.arguments, this.invocationIndex);
this.resolverFacade.resolveAndInjectFields(testInstance, extensionContext, this.arguments, this.invocationIndex,
this.resolutionCache);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
class ParameterizedContainerInvocationContext extends ParameterizedInvocationContext<ParameterizedContainerClassContext>
implements ContainerTemplateInvocationContext {

private final ResolutionCache resolutionCache = ResolutionCache.enabled();

ParameterizedContainerInvocationContext(ParameterizedContainerClassContext classContext,
ParameterizedInvocationNameFormatter formatter, Arguments arguments, int invocationIndex) {
super(classContext, formatter, arguments, invocationIndex);
Expand All @@ -53,18 +55,19 @@ private ContainerTemplateConstructorParameterResolver createExtensionForConstruc
Preconditions.condition(this.declarationContext.getTestInstanceLifecycle() == PER_METHOD,
"Constructor injection is only supported for lifecycle PER_METHOD");
return new ContainerTemplateConstructorParameterResolver(this.declarationContext, this.arguments,
this.invocationIndex);
this.invocationIndex, this.resolutionCache);
}

private Extension createExtensionForFieldInjection() {
ResolverFacade resolverFacade = this.declarationContext.getResolverFacade();
TestInstance.Lifecycle lifecycle = this.declarationContext.getTestInstanceLifecycle();
switch (lifecycle) {
case PER_CLASS:
return new ContainerTemplateInstanceFieldInjectingBeforeEachCallback(this.declarationContext,
this.arguments, this.invocationIndex);
return new ContainerTemplateInstanceFieldInjectingBeforeEachCallback(resolverFacade, this.arguments,
this.invocationIndex, this.resolutionCache);
case PER_METHOD:
return new ContainerTemplateInstanceFieldInjectingPostProcessor(this.declarationContext, this.arguments,
this.invocationIndex);
return new ContainerTemplateInstanceFieldInjectingPostProcessor(resolverFacade, this.arguments,
this.invocationIndex, this.resolutionCache);
}
throw new JUnitException("Unsupported lifecycle: " + lifecycle);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@ abstract class ParameterizedInvocationParameterResolver implements ParameterReso
private final ResolverFacade resolverFacade;
private final EvaluatedArgumentSet arguments;
private final int invocationIndex;
private final ResolutionCache resolutionCache;

ParameterizedInvocationParameterResolver(ResolverFacade resolverFacade, EvaluatedArgumentSet arguments,
int invocationIndex) {
int invocationIndex, ResolutionCache resolutionCache) {

this.resolverFacade = resolverFacade;
this.arguments = arguments;
this.invocationIndex = invocationIndex;
this.resolutionCache = resolutionCache;
}

@Override
Expand All @@ -51,7 +53,8 @@ public final boolean supportsParameter(ParameterContext parameterContext, Extens
public final Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
throws ParameterResolutionException {

return this.resolverFacade.resolve(parameterContext, extensionContext, this.arguments, this.invocationIndex);
return this.resolverFacade.resolve(parameterContext, extensionContext, this.arguments, this.invocationIndex,
this.resolutionCache);
}

protected abstract boolean isSupportedOnConstructorOrMethod(Executable declaringExecutable,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2015-2025 the original author or authors.
*
* All rights reserved. This program and the accompanying materials are
* made available under the terms of the Eclipse Public License v2.0 which
* accompanies this distribution and is available at
*
* https://www.eclipse.org/legal/epl-v20.html
*/

package org.junit.jupiter.params;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

import org.junit.jupiter.params.support.ParameterDeclaration;

/**
* @since 5.13
*/
interface ResolutionCache {

static ResolutionCache enabled() {
return new Concurrent();
}

ResolutionCache DISABLED = (__, resolver) -> resolver.get();

Object resolve(ParameterDeclaration declaration, Supplier<Object> resolver);

class Concurrent implements ResolutionCache {

private final Map<ParameterDeclaration, Object> cache = new ConcurrentHashMap<>();

@Override
public Object resolve(ParameterDeclaration declaration, Supplier<Object> resolver) {
return cache.computeIfAbsent(declaration, __ -> resolver.get());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -220,25 +220,28 @@ int determineConsumedArgumentCount(EvaluatedArgumentSet arguments) {
* arguments.
*/
Object resolve(ParameterContext parameterContext, ExtensionContext extensionContext, EvaluatedArgumentSet arguments,
int invocationIndex) {
int invocationIndex, ResolutionCache resolutionCache) {

int parameterIndex = toLogicalIndex(parameterContext);
ParameterDeclaration declaration = this.indexedParameterDeclarations.get(parameterIndex) //
.orElseGet(() -> this.aggregatorParameters.stream().filter(
it -> it.getParameterIndex() == parameterIndex).findFirst() //
.orElseThrow(() -> new ParameterResolutionException(
"Parameter index out of bounds: " + parameterIndex)));
return getResolver(extensionContext, declaration, parameterContext.getParameter()) //
.resolve(parameterContext, parameterIndex, arguments, invocationIndex);
return resolutionCache.resolve(declaration,
() -> getResolver(extensionContext, declaration, parameterContext.getParameter()) //
.resolve(parameterContext, parameterIndex, arguments, invocationIndex));
}

void resolveAndInjectFields(Object testInstance, ExtensionContext extensionContext, EvaluatedArgumentSet arguments,
int invocationIndex) {
int invocationIndex, ResolutionCache resolutionCache) {

if (this.indexedParameterDeclarations.sourceElement.equals(testInstance.getClass())) {
getAllParameterDeclarations() //
.filter(FieldParameterDeclaration.class::isInstance) //
.map(FieldParameterDeclaration.class::cast) //
.forEach(declaration -> setField(testInstance, declaration, extensionContext, arguments,
invocationIndex));
invocationIndex, resolutionCache));
}
}

Expand All @@ -247,15 +250,16 @@ private Stream<ParameterDeclaration> getAllParameterDeclarations() {
aggregatorParameters.stream());
}

private void setField(Object testInstance, FieldParameterDeclaration parameterDeclaration,
ExtensionContext extensionContext, EvaluatedArgumentSet arguments, int invocationIndex) {
Object argument = resolve(parameterDeclaration, extensionContext, arguments, invocationIndex);
private void setField(Object testInstance, FieldParameterDeclaration declaration, ExtensionContext extensionContext,
EvaluatedArgumentSet arguments, int invocationIndex, ResolutionCache resolutionCache) {

Object argument = resolutionCache.resolve(declaration,
() -> resolve(declaration, extensionContext, arguments, invocationIndex));
try {
parameterDeclaration.getField().set(testInstance, argument);
declaration.getField().set(testInstance, argument);
}
catch (Exception e) {
throw new JUnitException("Failed to inject parameter value into field: " + parameterDeclaration.getField(),
e);
throw new JUnitException("Failed to inject parameter value into field: " + declaration.getField(), e);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class TestTemplateMethodParameterResolver extends ParameterizedInvocationParamet

TestTemplateMethodParameterResolver(ParameterizedTestMethodContext methodContext, EvaluatedArgumentSet arguments,
int invocationIndex) {
super(methodContext.getResolverFacade(), arguments, invocationIndex);
super(methodContext.getResolverFacade(), arguments, invocationIndex, ResolutionCache.DISABLED);
this.testTemplateMethod = methodContext.getAnnotatedElement();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS;
Expand Down Expand Up @@ -48,6 +49,8 @@
import java.util.stream.Stream;

import org.assertj.core.api.Condition;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Named;
import org.junit.jupiter.api.Nested;
Expand All @@ -66,6 +69,7 @@
import org.junit.jupiter.params.aggregator.SimpleArgumentsAggregator;
import org.junit.jupiter.params.converter.ArgumentConversionException;
import org.junit.jupiter.params.converter.ConvertWith;
import org.junit.jupiter.params.converter.SimpleArgumentConverter;
import org.junit.jupiter.params.converter.TypedArgumentConverter;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
Expand Down Expand Up @@ -449,6 +453,16 @@ void declaredIndexMustBeUnique() {
containerTemplateClass.getName(), containerTemplateClass.getName()))));
}

@ParameterizedTest
@ValueSource(classes = { ArgumentConversionPerInvocationConstructorInjectionTestCase.class,
ArgumentConversionPerInvocationFieldInjectionTestCase.class })
void argumentConverterIsOnlyCalledOncePerInvocation(Class<?> containerTemplateClass) {

var results = executeTestsForClass(containerTemplateClass);

results.allEvents().assertStatistics(stats -> stats.started(5).succeeded(5));
}

// -------------------------------------------------------------------

private static Stream<String> invocationDisplayNames(EngineExecutionResults results) {
Expand Down Expand Up @@ -1289,4 +1303,82 @@ void test(TestReporter reporter) {
));
}
}

@ParameterizedContainer
@ValueSource(ints = 1)
record ArgumentConversionPerInvocationConstructorInjectionTestCase(
@ConvertWith(Wrapper.Converter.class) Wrapper wrapper) {

static Wrapper instance;

@BeforeAll
@AfterAll
static void clearWrapper() {
instance = null;
}

@Test
void test1() {
setOrCheckWrapper();
}

@Test
void test2() {
setOrCheckWrapper();
}

private void setOrCheckWrapper() {
if (instance == null) {
instance = wrapper;
}
else {
assertSame(instance, wrapper);
}
}
}

@ParameterizedContainer
@ValueSource(ints = 1)
static class ArgumentConversionPerInvocationFieldInjectionTestCase {

static Wrapper instance;

@BeforeAll
@AfterAll
static void clearWrapper() {
instance = null;
}

@Parameter
@ConvertWith(Wrapper.Converter.class)
Wrapper wrapper;

@Test
void test1() {
setOrCheckWrapper();
}

@Test
void test2() {
setOrCheckWrapper();
}

private void setOrCheckWrapper() {
if (instance == null) {
instance = wrapper;
}
else {
assertSame(instance, wrapper);
}
}
}

record Wrapper(int value) {
static class Converter extends SimpleArgumentConverter {
@Override
protected Object convert(Object source, Class<?> targetType) {
return new Wrapper((Integer) source);
}
}
}
}

0 comments on commit 6de9f4e

Please sign in to comment.