Skip to content

Commit

Permalink
Prevent AOT from failing with spring-orm without JPA
Browse files Browse the repository at this point in the history
This commit improves PersistenceManagedTypesBeanRegistrationAotProcessor
so that it does not attempt to load JPA classes when checking for the
presence of a PersistenceManagedTypes bean. To make it more clear a
check on the presence for JPA has been added to prevent the nested
classes to be loaded regardless of the presence of the bean.

Closes gh-32155
  • Loading branch information
snicoll committed Jan 30, 2024
1 parent db53586 commit b55a4d3
Showing 1 changed file with 31 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2002-2023 the original author or authors.
* Copyright 2002-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -67,42 +67,34 @@
@SuppressWarnings("unchecked")
class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor {

private static final List<Class<? extends Annotation>> CALLBACK_TYPES = List.of(PreUpdate.class,
PostUpdate.class, PrePersist.class, PostPersist.class, PreRemove.class, PostRemove.class, PostLoad.class);

@Nullable
private static Class<? extends Annotation> embeddableInstantiatorClass;

static {
try {
embeddableInstantiatorClass = (Class<? extends Annotation>) ClassUtils.forName("org.hibernate.annotations.EmbeddableInstantiator",
PersistenceManagedTypesBeanRegistrationAotProcessor.class.getClassLoader());
}
catch (ClassNotFoundException ex) {
embeddableInstantiatorClass = null;
}
}

private static final boolean jpaPresent = ClassUtils.isPresent("jakarta.persistence.Entity",
PersistenceManagedTypesBeanRegistrationAotProcessor.class.getClassLoader());

@Nullable
@Override
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
if (PersistenceManagedTypes.class.isAssignableFrom(registeredBean.getBeanClass())) {
return BeanRegistrationAotContribution.withCustomCodeFragments(codeFragments ->
new JpaManagedTypesBeanRegistrationCodeFragments(codeFragments, registeredBean));
if (jpaPresent) {
if (PersistenceManagedTypes.class.isAssignableFrom(registeredBean.getBeanClass())) {
return BeanRegistrationAotContribution.withCustomCodeFragments(codeFragments ->
new JpaManagedTypesBeanRegistrationCodeFragments(codeFragments, registeredBean));
}
}
return null;
}

private static class JpaManagedTypesBeanRegistrationCodeFragments extends BeanRegistrationCodeFragmentsDecorator {
private static final class JpaManagedTypesBeanRegistrationCodeFragments extends BeanRegistrationCodeFragmentsDecorator {

private static final List<Class<? extends Annotation>> CALLBACK_TYPES = List.of(PreUpdate.class,
PostUpdate.class, PrePersist.class, PostPersist.class, PreRemove.class, PostRemove.class, PostLoad.class);


private static final ParameterizedTypeName LIST_OF_STRINGS_TYPE = ParameterizedTypeName.get(List.class, String.class);

private final RegisteredBean registeredBean;

private final BindingReflectionHintsRegistrar bindingRegistrar = new BindingReflectionHintsRegistrar();

public JpaManagedTypesBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments,
private JpaManagedTypesBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments,
RegisteredBean registeredBean) {
super(codeFragments);
this.registeredBean = registeredBean;
Expand All @@ -114,7 +106,8 @@ public CodeBlock generateInstanceSupplierCode(GenerationContext generationContex
boolean allowDirectSupplierShortcut) {
PersistenceManagedTypes persistenceManagedTypes = this.registeredBean.getBeanFactory()
.getBean(this.registeredBean.getBeanName(), PersistenceManagedTypes.class);
contributeHints(generationContext.getRuntimeHints(), persistenceManagedTypes.getManagedClassNames());
contributeHints(generationContext.getRuntimeHints(),
this.registeredBean.getBeanFactory().getBeanClassLoader(), persistenceManagedTypes.getManagedClassNames());
GeneratedMethod generatedMethod = beanRegistrationCode.getMethods()
.add("getInstance", method -> {
Class<?> beanType = PersistenceManagedTypes.class;
Expand All @@ -135,7 +128,7 @@ private CodeBlock toCodeBlock(List<String> values) {
return CodeBlock.join(values.stream().map(value -> CodeBlock.of("$S", value)).toList(), ", ");
}

private void contributeHints(RuntimeHints hints, List<String> managedClassNames) {
private void contributeHints(RuntimeHints hints, @Nullable ClassLoader classLoader, List<String> managedClassNames) {
for (String managedClassName : managedClassNames) {
try {
Class<?> managedClass = ClassUtils.forName(managedClassName, null);
Expand All @@ -144,7 +137,7 @@ private void contributeHints(RuntimeHints hints, List<String> managedClassNames)
contributeIdClassHints(hints, managedClass);
contributeConverterHints(hints, managedClass);
contributeCallbackHints(hints, managedClass);
contributeHibernateHints(hints, managedClass);
contributeHibernateHints(hints, classLoader, managedClass);
}
catch (ClassNotFoundException ex) {
throw new IllegalArgumentException("Failed to instantiate the managed class: " + managedClassName, ex);
Expand Down Expand Up @@ -194,7 +187,8 @@ private void contributeCallbackHints(RuntimeHints hints, Class<?> managedClass)
}

@SuppressWarnings("unchecked")
private void contributeHibernateHints(RuntimeHints hints, Class<?> managedClass) {
private void contributeHibernateHints(RuntimeHints hints, @Nullable ClassLoader classLoader, Class<?> managedClass) {
Class<? extends Annotation> embeddableInstantiatorClass = loadEmbeddableInstantiatorClass(classLoader);
if (embeddableInstantiatorClass == null) {
return;
}
Expand All @@ -216,5 +210,16 @@ private void registerInstantiatorForReflection(ReflectionHints reflection, @Null
Class<?> embeddableInstantiatorClass = (Class<?>) AnnotationUtils.getAnnotationAttributes(annotation).get("value");
reflection.registerType(embeddableInstantiatorClass, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS);
}

@Nullable
private static Class<? extends Annotation> loadEmbeddableInstantiatorClass(@Nullable ClassLoader classLoader) {
try {
return (Class<? extends Annotation>) ClassUtils.forName(
"org.hibernate.annotations.EmbeddableInstantiator", classLoader);
}
catch (ClassNotFoundException ex) {
return null;
}
}
}
}

0 comments on commit b55a4d3

Please sign in to comment.