diff --git a/plugin-security.policy b/plugin-security.policy index 84248fb0ef..118b52a7b7 100644 --- a/plugin-security.policy +++ b/plugin-security.policy @@ -73,8 +73,7 @@ grant { //Java 9+ permission java.lang.RuntimePermission "accessClassInPackage.com.sun.jndi.*"; - //Enable this permission to debug unauthorized de-serialization attempt - //permission java.io.SerializablePermission "enableSubstitution"; + permission java.io.SerializablePermission "enableSubstitution"; }; grant codeBase "${codebase.netty-common}" { diff --git a/src/main/java/org/opensearch/security/support/Base64Helper.java b/src/main/java/org/opensearch/security/support/Base64Helper.java index 5b7f7d8ba0..19c21404d1 100644 --- a/src/main/java/org/opensearch/security/support/Base64Helper.java +++ b/src/main/java/org/opensearch/security/support/Base64Helper.java @@ -60,6 +60,8 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.regex.Pattern; import org.opensearch.OpenSearchException; @@ -76,6 +78,9 @@ public class Base64Helper { private static final Logger logger = LogManager.getLogger(Base64Helper.class); + private static final String ODFE_PACKAGE = "com.amazon.opendistroforelasticsearch"; + private static final String OS_PACKAGE = "org.opensearch"; + private static final Set> SAFE_CLASSES = ImmutableSet.of( String.class, SocketAddress.class, @@ -109,10 +114,69 @@ private static boolean isSafeClass(Class cls) { SAFE_ASSIGNABLE_FROM_CLASSES.stream().anyMatch(c -> c.isAssignableFrom(cls)); } + private static class DescriptorNameSetter { + private static final Field NAME = getField(); + + private DescriptorNameSetter() { + } + + private static Field getFieldPrivileged() { + try { + final Field field = ObjectStreamClass.class.getDeclaredField("name"); + field.setAccessible(true); + return field; + } catch (NoSuchFieldException | SecurityException e) { + logger.error("Failed to get ObjectStreamClass declared field", e); + if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } else { + throw new RuntimeException(e); + } + } + } + + private static Field getField() { + SpecialPermission.check(); + return AccessController.doPrivileged((PrivilegedAction) () -> getFieldPrivileged()); + } + + public static void setName(ObjectStreamClass desc, String name) { + try { + logger.debug("replacing descriptor name from [{}] to [{}]", desc.getName(), name); + NAME.set(desc, name); + } catch (IllegalAccessException e) { + logger.error("Failed to replace descriptor name from {} to {}", desc.getName(), name, e); + throw new OpenSearchException(e); + } + } + } + + private static class DescriptorReplacer { + private final ConcurrentMap nameToDescriptor = new ConcurrentHashMap<>(); + + public ObjectStreamClass replace(final ObjectStreamClass desc) { + final String name = desc.getName(); + if (name.startsWith(OS_PACKAGE)) { + return nameToDescriptor.computeIfAbsent(name, s -> { + SpecialPermission.check(); + // we can't modify original descriptor as it is cached by ObjectStreamClass, create clone + final ObjectStreamClass clone = AccessController.doPrivileged( + (PrivilegedAction)() -> SerializationUtils.clone(desc) + ); + DescriptorNameSetter.setName(clone, s.replace(OS_PACKAGE, ODFE_PACKAGE)); + return clone; + }); + } + return desc; + } + } + private final static class SafeObjectOutputStream extends ObjectOutputStream { private static final boolean useSafeObjectOutputStream = checkSubstitutionPermission(); + private final DescriptorReplacer descriptorReplacer = new DescriptorReplacer(); + private static boolean checkSubstitutionPermission() { SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -143,7 +207,6 @@ static ObjectOutputStream create(ByteArrayOutputStream out) throws IOException { private SafeObjectOutputStream(OutputStream out) throws IOException { super(out); - //useProtocolVersion(PROTOCOL_VERSION_2); SecurityManager sm = System.getSecurityManager(); if (sm != null) { @@ -156,21 +219,8 @@ private SafeObjectOutputStream(OutputStream out) throws IOException { } @Override - protected void writeClassDescriptor(ObjectStreamClass desc) throws IOException { - if (desc.getName().equals(User.class.getName())) { - final Field name; - try { - desc = SerializationUtils.clone(desc); - name = desc.getClass().getDeclaredField("name"); - name.setAccessible(true); - name.set(desc, "com.amazon.opendistroforelasticsearch.security.user.User"); - logger.warn("Changed desc {}", desc); - } catch (ReflectiveOperationException e) { - logger.error("Failed to change desc {} name", desc, e); - } - //desc = ObjectStreamClass.lookup(com.amazon.opendistroforelasticsearch.security.user.User.class); - } - super.writeClassDescriptor(desc); + protected void writeClassDescriptor(final ObjectStreamClass desc) throws IOException { + super.writeClassDescriptor(descriptorReplacer.replace(desc)); } @Override @@ -230,12 +280,11 @@ protected Class resolveClass(ObjectStreamClass desc) throws IOException, Clas @Override protected ObjectStreamClass readClassDescriptor() throws IOException, ClassNotFoundException { ObjectStreamClass desc = super.readClassDescriptor(); - - if (desc.getName().equals("com.amazon.opendistroforelasticsearch.security.user.User")) { - desc = ObjectStreamClass.lookup(org.opensearch.security.user.User.class); - logger.warn("replaced class desc {}", desc); + final String name = desc.getName(); + if (name.startsWith(ODFE_PACKAGE)) { + desc = ObjectStreamClass.lookup(Class.forName(name.replace(ODFE_PACKAGE, OS_PACKAGE))); + logger.debug("replaced descriptor name from [{}] to [{}]", name, desc.getName()); } - return desc; } } diff --git a/src/main/java/org/opensearch/security/user/User.java b/src/main/java/org/opensearch/security/user/User.java index cc98db2565..8dedcf6946 100644 --- a/src/main/java/org/opensearch/security/user/User.java +++ b/src/main/java/org/opensearch/security/user/User.java @@ -67,7 +67,7 @@ public class User implements Serializable, Writeable, CustomAttributesAware { * roles == backend_roles */ private final Set roles = new HashSet(); - private final Set securityRoles = new HashSet(); + private final Set openDistroSecurityRoles = new HashSet(); private String requestedTenant; private Map attributes = new HashMap<>(); private boolean isInjected = false; @@ -78,7 +78,7 @@ public User(final StreamInput in) throws IOException { roles.addAll(in.readList(StreamInput::readString)); requestedTenant = in.readString(); attributes = in.readMap(StreamInput::readString, StreamInput::readString); - securityRoles.addAll(in.readList(StreamInput::readString)); + openDistroSecurityRoles.addAll(in.readList(StreamInput::readString)); } /** @@ -244,7 +244,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeStringCollection(new ArrayList(roles)); out.writeString(requestedTenant); out.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeString); - out.writeStringCollection(securityRoles ==null?Collections.emptyList():new ArrayList(securityRoles)); + out.writeStringCollection(openDistroSecurityRoles ==null?Collections.emptyList():new ArrayList(openDistroSecurityRoles)); } /** @@ -260,12 +260,12 @@ public synchronized final Map getCustomAttributesMap() { } public final void addSecurityRoles(final Collection securityRoles) { - if(securityRoles != null && this.securityRoles != null) { - this.securityRoles.addAll(securityRoles); + if(securityRoles != null && this.openDistroSecurityRoles != null) { + this.openDistroSecurityRoles.addAll(securityRoles); } } public final Set getSecurityRoles() { - return this.securityRoles == null ? Collections.emptySet() : Collections.unmodifiableSet(this.securityRoles); + return this.openDistroSecurityRoles == null ? Collections.emptySet() : Collections.unmodifiableSet(this.openDistroSecurityRoles); } }