From b55a4d3908d5f08f237b9cb60a5e9dfe0480b3be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Nicoll?= Date: Tue, 30 Jan 2024 15:30:47 +0100 Subject: [PATCH] Prevent AOT from failing with spring-orm without JPA This commit improves PersistenceManagedTypesBeanRegistrationAotProcessor so that it does not attempt to load JPA classes when checking for the presence of a PersistenceManagedTypes bean. To make it more clear a check on the presence for JPA has been added to prevent the nested classes to be loaded regardless of the presence of the bean. Closes gh-32155 --- ...agedTypesBeanRegistrationAotProcessor.java | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java index 75c0834f19cb..481b37abea29 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/persistenceunit/PersistenceManagedTypesBeanRegistrationAotProcessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -67,34 +67,26 @@ @SuppressWarnings("unchecked") class PersistenceManagedTypesBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor { - private static final List> CALLBACK_TYPES = List.of(PreUpdate.class, - PostUpdate.class, PrePersist.class, PostPersist.class, PreRemove.class, PostRemove.class, PostLoad.class); - - @Nullable - private static Class embeddableInstantiatorClass; - - static { - try { - embeddableInstantiatorClass = (Class) ClassUtils.forName("org.hibernate.annotations.EmbeddableInstantiator", - PersistenceManagedTypesBeanRegistrationAotProcessor.class.getClassLoader()); - } - catch (ClassNotFoundException ex) { - embeddableInstantiatorClass = null; - } - } - + private static final boolean jpaPresent = ClassUtils.isPresent("jakarta.persistence.Entity", + PersistenceManagedTypesBeanRegistrationAotProcessor.class.getClassLoader()); @Nullable @Override public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) { - if (PersistenceManagedTypes.class.isAssignableFrom(registeredBean.getBeanClass())) { - return BeanRegistrationAotContribution.withCustomCodeFragments(codeFragments -> - new JpaManagedTypesBeanRegistrationCodeFragments(codeFragments, registeredBean)); + if (jpaPresent) { + if (PersistenceManagedTypes.class.isAssignableFrom(registeredBean.getBeanClass())) { + return BeanRegistrationAotContribution.withCustomCodeFragments(codeFragments -> + new JpaManagedTypesBeanRegistrationCodeFragments(codeFragments, registeredBean)); + } } return null; } - private static class JpaManagedTypesBeanRegistrationCodeFragments extends BeanRegistrationCodeFragmentsDecorator { + private static final class JpaManagedTypesBeanRegistrationCodeFragments extends BeanRegistrationCodeFragmentsDecorator { + + private static final List> CALLBACK_TYPES = List.of(PreUpdate.class, + PostUpdate.class, PrePersist.class, PostPersist.class, PreRemove.class, PostRemove.class, PostLoad.class); + private static final ParameterizedTypeName LIST_OF_STRINGS_TYPE = ParameterizedTypeName.get(List.class, String.class); @@ -102,7 +94,7 @@ private static class JpaManagedTypesBeanRegistrationCodeFragments extends BeanRe private final BindingReflectionHintsRegistrar bindingRegistrar = new BindingReflectionHintsRegistrar(); - public JpaManagedTypesBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments, + private JpaManagedTypesBeanRegistrationCodeFragments(BeanRegistrationCodeFragments codeFragments, RegisteredBean registeredBean) { super(codeFragments); this.registeredBean = registeredBean; @@ -114,7 +106,8 @@ public CodeBlock generateInstanceSupplierCode(GenerationContext generationContex boolean allowDirectSupplierShortcut) { PersistenceManagedTypes persistenceManagedTypes = this.registeredBean.getBeanFactory() .getBean(this.registeredBean.getBeanName(), PersistenceManagedTypes.class); - contributeHints(generationContext.getRuntimeHints(), persistenceManagedTypes.getManagedClassNames()); + contributeHints(generationContext.getRuntimeHints(), + this.registeredBean.getBeanFactory().getBeanClassLoader(), persistenceManagedTypes.getManagedClassNames()); GeneratedMethod generatedMethod = beanRegistrationCode.getMethods() .add("getInstance", method -> { Class beanType = PersistenceManagedTypes.class; @@ -135,7 +128,7 @@ private CodeBlock toCodeBlock(List values) { return CodeBlock.join(values.stream().map(value -> CodeBlock.of("$S", value)).toList(), ", "); } - private void contributeHints(RuntimeHints hints, List managedClassNames) { + private void contributeHints(RuntimeHints hints, @Nullable ClassLoader classLoader, List managedClassNames) { for (String managedClassName : managedClassNames) { try { Class managedClass = ClassUtils.forName(managedClassName, null); @@ -144,7 +137,7 @@ private void contributeHints(RuntimeHints hints, List managedClassNames) contributeIdClassHints(hints, managedClass); contributeConverterHints(hints, managedClass); contributeCallbackHints(hints, managedClass); - contributeHibernateHints(hints, managedClass); + contributeHibernateHints(hints, classLoader, managedClass); } catch (ClassNotFoundException ex) { throw new IllegalArgumentException("Failed to instantiate the managed class: " + managedClassName, ex); @@ -194,7 +187,8 @@ private void contributeCallbackHints(RuntimeHints hints, Class managedClass) } @SuppressWarnings("unchecked") - private void contributeHibernateHints(RuntimeHints hints, Class managedClass) { + private void contributeHibernateHints(RuntimeHints hints, @Nullable ClassLoader classLoader, Class managedClass) { + Class embeddableInstantiatorClass = loadEmbeddableInstantiatorClass(classLoader); if (embeddableInstantiatorClass == null) { return; } @@ -216,5 +210,16 @@ private void registerInstantiatorForReflection(ReflectionHints reflection, @Null Class embeddableInstantiatorClass = (Class) AnnotationUtils.getAnnotationAttributes(annotation).get("value"); reflection.registerType(embeddableInstantiatorClass, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS); } + + @Nullable + private static Class loadEmbeddableInstantiatorClass(@Nullable ClassLoader classLoader) { + try { + return (Class) ClassUtils.forName( + "org.hibernate.annotations.EmbeddableInstantiator", classLoader); + } + catch (ClassNotFoundException ex) { + return null; + } + } } }