diff --git a/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java b/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java index 907d605860..d5524e063c 100755 --- a/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java +++ b/src/main/java/com/amazon/dlic/auth/ldap/LdapUser.java @@ -11,6 +11,7 @@ package com.amazon.dlic.auth.ldap; +import java.io.IOException; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -20,6 +21,8 @@ import com.amazon.dlic.auth.ldap.util.Utils; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.security.support.WildcardMatcher; import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.user.User; @@ -45,6 +48,12 @@ public LdapUser( attributes.putAll(extractLdapAttributes(originalUsername, userEntry, customAttrMaxValueLen, allowlistedCustomLdapAttrMatcher)); } + public LdapUser(StreamInput in) throws IOException { + super(in); + userEntry = null; + originalUsername = in.readString(); + } + /** * May return null because ldapEntry is transient * @@ -88,4 +97,10 @@ public static Map extractLdapAttributes( } return Collections.unmodifiableMap(attributes); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(originalUsername); + } } diff --git a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java index e4aa062641..945a38b545 100644 --- a/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java +++ b/src/main/java/org/opensearch/security/auditlog/impl/AbstractAuditLog.java @@ -584,7 +584,7 @@ private Origin getOrigin() { private TransportAddress getRemoteAddress() { TransportAddress address = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS); if(address == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER) != null) { - address = new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER))); + address = new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER), threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION))); } return address; } @@ -592,7 +592,7 @@ private TransportAddress getRemoteAddress() { private String getUser() { User user = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); if(user == null && threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER) != null) { - user = (User) Base64Helper.deserializeObject(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER)); + user = (User) Base64Helper.deserializeObject(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } return user==null?null:user.getName(); } diff --git a/src/main/java/org/opensearch/security/auth/UserInjector.java b/src/main/java/org/opensearch/security/auth/UserInjector.java index 9253023eb3..8ce419f808 100644 --- a/src/main/java/org/opensearch/security/auth/UserInjector.java +++ b/src/main/java/org/opensearch/security/auth/UserInjector.java @@ -26,6 +26,7 @@ package org.opensearch.security.auth; +import java.io.IOException; import java.io.ObjectStreamException; import java.net.InetAddress; import java.net.UnknownHostException; @@ -36,6 +37,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.settings.Settings; import org.opensearch.common.transport.TransportAddress; import org.opensearch.rest.RestRequest; @@ -63,13 +66,18 @@ public UserInjector(Settings settings, ThreadPool threadPool, AuditLog auditLog, } - static class InjectedUser extends User { + public static class InjectedUser extends User { private transient TransportAddress transportAddress; public InjectedUser(String name) { super(name); } + public InjectedUser(StreamInput in) throws IOException { + super(in); + this.setInjected(true); + } + private Object writeReplace() throws ObjectStreamException { User user = new User(getName()); user.addRoles(getRoles()); @@ -96,6 +104,11 @@ public void setTransportAddress(String addr) throws UnknownHostException, Illega this.transportAddress = new TransportAddress(iAdress, port); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + } } public InjectedUser getInjectedUser() { diff --git a/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java b/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java index 947557d342..3f1b00eb21 100644 --- a/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java +++ b/src/main/java/org/opensearch/security/configuration/DlsFlsValveImpl.java @@ -389,7 +389,7 @@ private void setDlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) } } else { if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER) != null) { - Object deserializedDlsQueries = Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER)); + Object deserializedDlsQueries = Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER), threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); if (!dlsQueries.equals(deserializedDlsQueries)) { throw new OpenSearchSecurityException(ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER + " does not match (SG 900D)"); } @@ -437,7 +437,7 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) } else { if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER) != null) { - if (!maskedFieldsMap.equals(Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER)))) { + if (!maskedFieldsMap.equals(Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER), threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)))) { throw new OpenSearchSecurityException(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER + " does not match (SG 901D)"); } else { if (log.isDebugEnabled()) { @@ -463,9 +463,9 @@ private void setFlsHeaders(EvaluatedDlsFlsConfig dlsFls, ActionRequest request) } } else { if (threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER) != null) { - if (!flsFields.equals(Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER)))) { + if (!flsFields.equals(Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)))) { throw new OpenSearchSecurityException(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER + " does not match (SG 901D) " + flsFields - + "---" + Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER))); + + "---" + Base64Helper.deserializeObject(threadContext.getHeader(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER), threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION))); } else { if (log.isDebugEnabled()) { log.debug(ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER + " already set"); diff --git a/src/main/java/org/opensearch/security/support/Base64Helper.java b/src/main/java/org/opensearch/security/support/Base64Helper.java index b20cadfc96..daa4934d15 100644 --- a/src/main/java/org/opensearch/security/support/Base64Helper.java +++ b/src/main/java/org/opensearch/security/support/Base64Helper.java @@ -49,6 +49,8 @@ import java.util.regex.Pattern; import com.google.common.base.Preconditions; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.io.BaseEncoding; @@ -61,7 +63,12 @@ import org.opensearch.OpenSearchException; import org.opensearch.SpecialPermission; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.Writeable; import org.opensearch.core.common.Strings; +import org.opensearch.security.auth.UserInjector; import org.opensearch.security.user.User; public class Base64Helper { @@ -88,10 +95,40 @@ public class Base64Helper { Enum.class ); + + private enum CustomSerializationFormat { + + WRITEABLE(1), + STREAMABLE(2), + GENERIC(3); + + private final int id; + + CustomSerializationFormat(int id) { + this.id = id; + } + + static CustomSerializationFormat fromId(int id) { + switch (id) { + case 1: return WRITEABLE; + case 2: return STREAMABLE; + case 3: return GENERIC; + default: throw new IllegalArgumentException(String.format("%d is not a valid id", id)); + } + } + + } + + private static final ThreadLocal, Integer>> writeableClassToIdMap = ThreadLocal.withInitial(HashBiMap::create); + private static final StreamableRegistry streamableRegistry = StreamableRegistry.getInstance(); private static final Set SAFE_CLASS_NAMES = Collections.singleton( "org.ldaptive.LdapAttribute$LdapAttributeValues" ); + static { + registerAllWriteables(); + } + private static boolean isSafeClass(Class cls) { return cls.isArray() || SAFE_CLASSES.contains(cls) || @@ -156,7 +193,7 @@ protected Object replaceObject(Object obj) throws IOException { } } - public static String serializeObject(final Serializable object) { + private static String serializeObjectJDK(final Serializable object) { Preconditions.checkArgument(object != null, "object must not be null"); @@ -170,7 +207,47 @@ public static String serializeObject(final Serializable object) { return BaseEncoding.base64().encode(bytes); } - public static Serializable deserializeObject(final String string) { + private static String serializeObjectCustom(final Serializable object) { + + Preconditions.checkArgument(object != null, "object must not be null"); + final BytesStreamOutput streamOutput = new BytesStreamOutput(128); + Class clazz = object.getClass(); + try { + CustomSerializationFormat customSerializationFormat = getCustomSerializationMode(clazz); + switch (customSerializationFormat) { + case WRITEABLE: + streamOutput.writeByte((byte) CustomSerializationFormat.WRITEABLE.id); + streamOutput.writeByte((byte) getWriteableClassID(clazz).intValue()); + ((Writeable) object).writeTo(streamOutput); + break; + case STREAMABLE: + streamOutput.writeByte((byte) CustomSerializationFormat.STREAMABLE.id); + streamableRegistry.writeTo(streamOutput, object); + break; + case GENERIC: + streamOutput.writeByte((byte) CustomSerializationFormat.GENERIC.id); + streamOutput.writeGenericValue(object); + break; + default: + throw new IllegalArgumentException(String.format("Could not determine custom serialization mode for class %s", clazz.getName())); + } + } catch (final Exception e) { + throw new OpenSearchException("Instance {} of class {} is not serializable", e, object, object.getClass()); + } + final byte[] bytes = streamOutput.bytes().toBytesRef().bytes; + streamOutput.close(); + return BaseEncoding.base64().encode(bytes); + } + + public static String serializeObject(final Serializable object, final boolean useJDKSerialization) { + return useJDKSerialization ? serializeObjectJDK(object) : serializeObjectCustom(object); + } + + public static String serializeObject(final Serializable object) { + return serializeObjectCustom(object); + } + + private static Serializable deserializeObjectJDK(final String string) { Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty"); @@ -183,6 +260,37 @@ public static Serializable deserializeObject(final String string) { } } + private static Serializable deserializeObjectCustom(final String string) { + + Preconditions.checkArgument(!Strings.isNullOrEmpty(string), "string must not be null or empty"); + final byte[] bytes = BaseEncoding.base64().decode(string); + try (final BytesStreamInput streamInput = new BytesStreamInput(bytes)) { + CustomSerializationFormat serializationFormat = CustomSerializationFormat.fromId(streamInput.readByte()); + switch (serializationFormat) { + case WRITEABLE: + final int classId = streamInput.readByte(); + Class clazz = getWriteableClassFromId(classId); + return (Serializable) clazz.getConstructor(StreamInput.class).newInstance(streamInput); + case STREAMABLE: + return (Serializable) streamableRegistry.readFrom(streamInput); + case GENERIC: + return (Serializable) streamInput.readGenericValue(); + default: + throw new IllegalArgumentException("Could not determine custom deserialization mode"); + } + } catch (final Exception e) { + throw new OpenSearchException(e); + } + } + + public static Serializable deserializeObject(final String string) { + return deserializeObjectCustom(string); + } + + public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) { + return useJDKDeserialization ? deserializeObjectJDK(string) : deserializeObjectCustom(string); + } + private final static class SafeObjectInputStream extends ObjectInputStream { public SafeObjectInputStream(InputStream in) throws IOException { @@ -200,4 +308,66 @@ protected Class resolveClass(ObjectStreamClass desc) throws IOException, Clas throw new InvalidClassException("Unauthorized deserialization attempt ", clazz.getName()); } } + + private static boolean isWriteable(Class clazz) { + return Writeable.class.isAssignableFrom(clazz); + } + + /** + * Returns integer ID for the registered Writeable class + *
+ * Protected for testing + */ + protected static Integer getWriteableClassID(Class clazz) { + if ( !isWriteable(clazz) ) { + throw new OpenSearchException("clazz should implement Writeable ", clazz); + } + if( !writeableClassToIdMap.get().containsKey(clazz) ) { + throw new OpenSearchException("Writeable clazz not registered ", clazz); + } + return writeableClassToIdMap.get().get(clazz); + } + + private static Class getWriteableClassFromId(int id) { + return writeableClassToIdMap.get().inverse().get(id); + } + + /** + * Registers the given Writeable class for custom serialization by assigning an incrementing integer ID + * IDs are stored in two thread local maps + * @param clazz class to be registered + */ + private static void registerWriteable(Class clazz) { + if ( writeableClassToIdMap.get().containsKey(clazz) ) { + throw new OpenSearchException("writeable clazz is already registered ", clazz.getName()); + } + int id = writeableClassToIdMap.get().size() + 1; + writeableClassToIdMap.get().put(clazz, id); + } + + /** + * Registers all Writeable classes for custom serialization support. + * Removing existing classes / changing order of registration will cause a breaking change in the serialization protocol + * as registerWriteable assigns an incrementing integer ID to each of the classes in the order it is called + * starting from 1. + *
+ * New classes can safely be added towards the end. + */ + private static void registerAllWriteables() { + registerWriteable(User.class); + registerWriteable(LdapUser.class); + registerWriteable(UserInjector.InjectedUser.class); + registerWriteable(SourceFieldsContext.class); + } + + private static CustomSerializationFormat getCustomSerializationMode(Class clazz) { + if ( isWriteable(clazz) ) { + return CustomSerializationFormat.WRITEABLE; + } else if (streamableRegistry.isStreamable(clazz) ) { + return CustomSerializationFormat.STREAMABLE; + } else { + return CustomSerializationFormat.GENERIC; + } + } + } diff --git a/src/main/java/org/opensearch/security/support/ConfigConstants.java b/src/main/java/org/opensearch/security/support/ConfigConstants.java index f5c64bccd3..dc93b0d23c 100644 --- a/src/main/java/org/opensearch/security/support/ConfigConstants.java +++ b/src/main/java/org/opensearch/security/support/ConfigConstants.java @@ -79,6 +79,8 @@ public class ConfigConstants { public static final String OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER = OPENDISTRO_SECURITY_CONFIG_PREFIX+"initial_action_class_header"; + public static final String OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT = OPENDISTRO_SECURITY_CONFIG_PREFIX+"source_field_context"; + /** * Set by SSL plugin for https requests only */ @@ -296,6 +298,8 @@ public enum RolesMappingResolution { public static final String TENANCY_GLOBAL_TENANT_NAME = "global"; public static final String TENANCY_GLOBAL_TENANT_DEFAULT_NAME = ""; + public static final String USE_JDK_SERIALIZATION = "plugins.security.use_jdk_serialization"; + public static Set getSettingAsSet(final Settings settings, final String key, final List defaultList, final boolean ignoreCaseForNone) { final List list = settings.getAsList(key, defaultList); if (list.size() == 1 && "NONE".equals(ignoreCaseForNone? list.get(0).toUpperCase() : list.get(0))) { diff --git a/src/main/java/org/opensearch/security/support/HeaderHelper.java b/src/main/java/org/opensearch/security/support/HeaderHelper.java index af8da305d4..f825d73973 100644 --- a/src/main/java/org/opensearch/security/support/HeaderHelper.java +++ b/src/main/java/org/opensearch/security/support/HeaderHelper.java @@ -27,6 +27,8 @@ package org.opensearch.security.support; import java.io.Serializable; +import java.util.Arrays; +import java.util.List; import com.google.common.base.Strings; @@ -68,7 +70,7 @@ public static Serializable deserializeSafeFromHeader(final ThreadContext context final String objectAsBase64 = getSafeFromHeader(context, headerName); if (!Strings.isNullOrEmpty(objectAsBase64)) { - return Base64Helper.deserializeObject(objectAsBase64); + return Base64Helper.deserializeObject(objectAsBase64, context.getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } return null; @@ -77,4 +79,16 @@ public static Serializable deserializeSafeFromHeader(final ThreadContext context public static boolean isTrustedClusterRequest(final ThreadContext context) { return context.getTransient(ConfigConstants.OPENDISTRO_SECURITY_SSL_TRANSPORT_TRUSTED_CLUSTER_REQUEST) == Boolean.TRUE; } + + public static List getAllSerializedHeaderNames() { + return Arrays.asList( + ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_DLS_QUERY_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_FLS_FIELDS_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_DLS_FILTER_LEVEL_QUERY_HEADER, + ConfigConstants.OPENDISTRO_SECURITY_SOURCE_FIELD_CONTEXT + ); + } } diff --git a/src/main/java/org/opensearch/security/support/SourceFieldsContext.java b/src/main/java/org/opensearch/security/support/SourceFieldsContext.java index 8f61bcbaf5..e298a81ef4 100644 --- a/src/main/java/org/opensearch/security/support/SourceFieldsContext.java +++ b/src/main/java/org/opensearch/security/support/SourceFieldsContext.java @@ -26,13 +26,18 @@ package org.opensearch.security.support; +import java.io.IOException; import java.io.Serializable; import java.util.Arrays; +import java.util.Objects; import org.opensearch.action.get.GetRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; -public class SourceFieldsContext implements Serializable { +public class SourceFieldsContext implements Serializable, Writeable { private String[] includes; private String[] excludes; @@ -72,6 +77,18 @@ public SourceFieldsContext(SearchRequest request) { //} } + public SourceFieldsContext(StreamInput in) throws IOException { + includes = in.readStringArray(); + if(includes.length == 0) { + includes = null; + } + excludes = in.readStringArray(); + if(excludes.length == 0) { + excludes = null; + } + fetchSource = in.readBoolean(); + } + public SourceFieldsContext(GetRequest request) { if (request.fetchSourceContext() != null) { includes = request.fetchSourceContext().includes(); @@ -107,4 +124,11 @@ public String toString() { return "SourceFieldsContext [includes=" + Arrays.toString(includes) + ", excludes=" + Arrays.toString(excludes) + ", fetchSource=" + fetchSource + "]"; } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeStringArray(Objects.requireNonNullElseGet(includes, () -> new String[]{})); + streamOutput.writeStringArray(Objects.requireNonNullElseGet(excludes, () -> new String[]{})); + streamOutput.writeBoolean(fetchSource); + } } diff --git a/src/main/java/org/opensearch/security/support/StreamableRegistry.java b/src/main/java/org/opensearch/security/support/StreamableRegistry.java new file mode 100644 index 0000000000..54900d2fa0 --- /dev/null +++ b/src/main/java/org/opensearch/security/support/StreamableRegistry.java @@ -0,0 +1,125 @@ +package org.opensearch.security.support; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.HashMap; +import java.util.Map; + +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; + +import org.opensearch.OpenSearchException; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.io.stream.Writeable; +import org.opensearch.core.common.io.stream.BaseWriteable; + +/** + * Registry for any class that does NOT implement the Writeable interface + * and needs to be serialized over the wire. Supports registration of writer and reader via registerStreamable + * for such classes and provides methods writeTo and readFrom for objects of such registered classes. + *
+ * StreamableRegistry is ThreadLocal singleton, so each thread will have its own instance. + *
+ * Methods are protected and intended to be accessed from only within the package. (mostly by Base64Helper) + */ +public class StreamableRegistry { + + private static final ThreadLocal THREAD_LOCAL = ThreadLocal.withInitial(StreamableRegistry::new); + public final BiMap, Integer> classToIdMap = HashBiMap.create(); + private final Map idToEntryMap = new HashMap<>(); + + private StreamableRegistry() { + registerAllStreamables(); + } + + private static class Entry { + BaseWriteable.Writer writer; + BaseWriteable.Reader reader; + + Entry(BaseWriteable.Writer writer, BaseWriteable.Reader reader) { + this.writer = writer; + this.reader = reader; + } + } + + private BaseWriteable.Writer getWriter(Class clazz) { + if ( !classToIdMap.containsKey(clazz) ) { + throw new OpenSearchException(String.format("No writer registered for class %s", clazz.getName())); + } + return idToEntryMap.get(classToIdMap.get(clazz)).writer; + } + + private BaseWriteable.Reader getReader(int id) { + if ( !idToEntryMap.containsKey(id) ) { + throw new OpenSearchException(String.format("No reader registered for id %s", id)); + } + return idToEntryMap.get(id).reader; + } + + private int getId(Class clazz) { + if ( !classToIdMap.containsKey(clazz) ) { + throw new OpenSearchException(String.format("No writer registered for class %s", clazz.getName())); + } + return classToIdMap.get(clazz); + } + + protected boolean isStreamable(Class clazz) { + return classToIdMap.containsKey(clazz); + } + + protected void writeTo(StreamOutput out, Object object) throws IOException { + out.writeByte((byte) getId(object.getClass())); + getWriter(object.getClass()).write(out, object); + } + + protected Object readFrom(StreamInput in) throws IOException { + int id = in.readByte(); + return getReader(id).read(in); + } + + protected static StreamableRegistry getInstance() { + return THREAD_LOCAL.get(); + } + + protected void registerStreamable(Class clazz, BaseWriteable.Writer writer, BaseWriteable.Reader reader) { + Integer id = classToIdMap.size() + 1; + classToIdMap.put(clazz, id); + idToEntryMap.put(id, new Entry(writer, reader)); + } + + protected int getStreamableID(Class clazz) { + if (!isStreamable(clazz)) { + throw new OpenSearchException(String.format("class %s is in streamable registry", clazz.getName())); + } else { + return classToIdMap.get(clazz); + } + } + + /** + * Register all streamables here. Register new streamables towards the end. + * Removing / reordering a registered streamable will change the typeIDs associated with the streamables + * causing a breaking change in the serialization format. + */ + private void registerAllStreamables() { + + // InetSocketAddress + this.registerStreamable( + InetSocketAddress.class, + (Writeable.Writer) (o, v) -> { + final InetSocketAddress inetSocketAddress = (InetSocketAddress) v; + o.writeString(inetSocketAddress.getHostString()); + o.writeByteArray(inetSocketAddress.getAddress().getAddress()); + o.writeInt(inetSocketAddress.getPort()); + }, + (Writeable.Reader) (i) -> { + String host = i.readString(); + byte[] addressBytes = i.readByteArray(); + int port = i.readInt(); + return new InetSocketAddress(InetAddress.getByAddress(host, addressBytes), port); + }) + ; + } + +} diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index 5456d36d9c..932d4d6fc1 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -38,6 +38,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.Version; import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsAction; import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsResponse; import org.opensearch.action.get.GetRequest; @@ -58,6 +59,7 @@ import org.opensearch.security.ssl.transport.SSLConfig; import org.opensearch.security.support.Base64Helper; import org.opensearch.security.support.ConfigConstants; +import org.opensearch.security.support.HeaderHelper; import org.opensearch.security.user.User; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.Transport.Connection; @@ -127,6 +129,8 @@ public void sendRequestDecorate(AsyncSender sender final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS); final boolean isDebugEnabled = log.isDebugEnabled(); + final boolean useJDKSerialization = connection.getVersion().before(Version.V_3_0_0); + try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) { final TransportResponseHandler restoringHandler = new RestoringTransportResponseHandler(handler, stashedContext); getThreadContext().putHeader("_opendistro_security_remotecn", cs.getClusterName().value()); @@ -195,9 +199,15 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION_HEADER, injectedRolesValidationString); } + if(useJDKSerialization) { + Map jdkSerializedHeaders = new HashMap<>(); + HeaderHelper.getAllSerializedHeaderNames().stream().filter(k -> headerMap.get(k) != null).forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.serializeObject(Base64Helper.deserializeObject(headerMap.get(k)), true))); + headerMap.putAll(jdkSerializedHeaders); + } + getThreadContext().putHeader(headerMap); - ensureCorrectHeaders(remoteAddress0, user0, origin0, injectedUserString, injectedRolesString); + ensureCorrectHeaders(remoteAddress0, user0, origin0, injectedUserString, injectedRolesString, useJDKSerialization); if (isActionTraceEnabled()) { getThreadContext().putHeader("_opendistro_security_trace"+System.currentTimeMillis()+"#"+UUID.randomUUID().toString(), Thread.currentThread().getName()+" IC -> "+action+" "+getThreadContext().getHeaders().entrySet().stream().filter(p->!p.getKey().startsWith("_opendistro_security_trace")).collect(Collectors.toMap(p -> p.getKey(), p -> p.getValue()))); @@ -208,7 +218,8 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL } private void ensureCorrectHeaders(final Object remoteAdr, final User origUser, final String origin, - final String injectedUserString, final String injectedRolesString) { + final String injectedUserString, final String injectedRolesString, + final boolean useJDKSerialization) { // keep original address if(origin != null && !origin.isEmpty() /*&& !Origin.LOCAL.toString().equalsIgnoreCase(origin)*/ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_ORIGIN_HEADER) == null) { @@ -224,7 +235,7 @@ private void ensureCorrectHeaders(final Object remoteAdr, final User origUser, f String remoteAddressHeader = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER); if(remoteAddressHeader == null) { - getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, Base64Helper.serializeObject(((TransportAddress) remoteAdr).address())); + getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER, Base64Helper.serializeObject(((TransportAddress) remoteAdr).address(), useJDKSerialization)); } } @@ -233,7 +244,7 @@ private void ensureCorrectHeaders(final Object remoteAdr, final User origUser, f if(userHeader == null) { if(origUser != null) { - getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(origUser)); + getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(origUser, useJDKSerialization)); } else if(StringUtils.isNotEmpty(injectedRolesString)) { getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_HEADER, injectedRolesString); diff --git a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java index 4a2919fdb2..b2903cec29 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java +++ b/src/main/java/org/opensearch/security/transport/SecurityRequestHandler.java @@ -37,6 +37,7 @@ import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchSecurityException; +import org.opensearch.Version; import org.opensearch.action.bulk.BulkShardRequest; import org.opensearch.action.support.replication.TransportReplicationAction.ConcreteShardRequest; import org.opensearch.cluster.service.ClusterService; @@ -102,6 +103,10 @@ protected void messageReceivedDecorate(final T request, final TransportRequestHa resolvedActionClass = ((ConcreteShardRequest) request).getRequest().getClass().getSimpleName(); } + final boolean useJDKSerialization = transportChannel.getVersion().before(Version.V_3_0_0); + + getThreadContext().putTransient(ConfigConstants.USE_JDK_SERIALIZATION, useJDKSerialization); + String initialActionClassValue = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INITIAL_ACTION_CLASS_HEADER); final ThreadContext.StoredContext sgContext = getThreadContext().newStoredContext(false); @@ -151,13 +156,13 @@ else if(!Strings.isNullOrEmpty(injectedUserHeader)) { getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, injectedUserHeader); } } else { - getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader))); + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader, useJDKSerialization))); } final String originalRemoteAddress = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER); if(!Strings.isNullOrEmpty(originalRemoteAddress)) { - getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress))); + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress, useJDKSerialization))); } final String rolesValidation = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROLES_VALIDATION_HEADER); @@ -248,13 +253,13 @@ else if(!Strings.isNullOrEmpty(injectedUserHeader)) { getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_USER, injectedUserHeader); } } else { - getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader))); + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_USER, Objects.requireNonNull((User) Base64Helper.deserializeObject(userHeader, useJDKSerialization))); } String originalRemoteAddress = getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS_HEADER); if(!Strings.isNullOrEmpty(originalRemoteAddress)) { - getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress))); + getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, new TransportAddress((InetSocketAddress) Base64Helper.deserializeObject(originalRemoteAddress, useJDKSerialization))); } else { getThreadContext().putTransient(ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS, request.remoteAddress()); } diff --git a/src/main/java/org/opensearch/security/user/User.java b/src/main/java/org/opensearch/security/user/User.java index 86089afd35..3eb0873648 100644 --- a/src/main/java/org/opensearch/security/user/User.java +++ b/src/main/java/org/opensearch/security/user/User.java @@ -238,7 +238,7 @@ public final void copyRolesFrom(final User user) { public void writeTo(StreamOutput out) throws IOException { out.writeString(name); out.writeStringCollection(new ArrayList(roles)); - out.writeString(requestedTenant); + out.writeString(requestedTenant == null ? "" : requestedTenant); out.writeMap(attributes, StreamOutput::writeString, StreamOutput::writeString); out.writeStringCollection(securityRoles ==null?Collections.emptyList():new ArrayList(securityRoles)); } diff --git a/src/test/java/org/opensearch/security/support/Base64HelperTest.java b/src/test/java/org/opensearch/security/support/Base64HelperTest.java index 81c2505985..99b92e87bd 100644 --- a/src/test/java/org/opensearch/security/support/Base64HelperTest.java +++ b/src/test/java/org/opensearch/security/support/Base64HelperTest.java @@ -16,14 +16,18 @@ import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.HashMap; -import java.util.regex.Pattern; import com.google.common.io.BaseEncoding; import org.junit.Assert; import org.junit.Test; +import org.ldaptive.LdapEntry; + +import com.amazon.dlic.auth.ldap.LdapUser; import org.opensearch.OpenSearchException; import org.opensearch.action.search.SearchRequest; +import org.opensearch.security.auth.UserInjector; +import org.opensearch.security.user.AuthCredentials; import org.opensearch.security.user.User; import static org.opensearch.security.support.Base64Helper.deserializeObject; @@ -35,6 +39,10 @@ private static final class NotSafeSerializable implements Serializable { private static final long serialVersionUID = 5135559266828470092L; } + private static Serializable dsJDK(Serializable s) { + return deserializeObject(serializeObject(s, true), true); + } + private static Serializable ds(Serializable s) { return deserializeObject(serializeObject(s)); } @@ -43,54 +51,56 @@ private static Serializable ds(Serializable s) { public void testString() { String string = "string"; Assert.assertEquals(string, ds(string)); + Assert.assertEquals(string, dsJDK(string)); } @Test public void testInteger() { Integer integer = Integer.valueOf(0); Assert.assertEquals(integer, ds(integer)); + Assert.assertEquals(integer, dsJDK(integer)); } @Test public void testDouble() { Double number = Double.valueOf(0.); Assert.assertEquals(number, ds(number)); + Assert.assertEquals(number, dsJDK(number)); } @Test public void testInetSocketAddress() { InetSocketAddress inetSocketAddress = new InetSocketAddress(0); Assert.assertEquals(inetSocketAddress, ds(inetSocketAddress)); - } - - @Test - public void testPattern() { - Pattern pattern = Pattern.compile(".*"); - Assert.assertEquals(pattern.pattern(), ((Pattern) ds(pattern)).pattern()); + Assert.assertEquals(inetSocketAddress, dsJDK(inetSocketAddress)); } @Test public void testUser() { User user = new User("user"); Assert.assertEquals(user, ds(user)); + Assert.assertEquals(user, dsJDK(user)); } @Test public void testSourceFieldsContext() { SourceFieldsContext sourceFieldsContext = new SourceFieldsContext(new SearchRequest("")); Assert.assertEquals(sourceFieldsContext.toString(), ds(sourceFieldsContext).toString()); + Assert.assertEquals(sourceFieldsContext.toString(), dsJDK(sourceFieldsContext).toString()); } @Test public void testHashMap() { HashMap map = new HashMap(); Assert.assertEquals(map, ds(map)); + Assert.assertEquals(map, dsJDK(map)); } @Test public void testArrayList() { ArrayList list = new ArrayList(); Assert.assertEquals(list, ds(list)); + Assert.assertEquals(list, dsJDK(list)); } @Test(expected = OpenSearchException.class) @@ -106,4 +116,45 @@ public void notSafeDeserializable() throws Exception { } deserializeObject(BaseEncoding.base64().encode(bos.toByteArray())); } + + @Test + public void testLdapUser() { + LdapUser ldapUser = new LdapUser( + "username", + "originalusername", + new LdapEntry("dn"), + new AuthCredentials("originalusername", "12345"), + 34, + WildcardMatcher.ANY + ); + Assert.assertEquals(ldapUser, ds(ldapUser)); + Assert.assertEquals(ldapUser, dsJDK(ldapUser)); + } + + @Test + public void testGetWriteableClassID() { + // a need to make a change in this test signifies a breaking change in security plugin's custom serialization + // format + Assert.assertEquals(Integer.valueOf(1), Base64Helper.getWriteableClassID(User.class)); + Assert.assertEquals(Integer.valueOf(2), Base64Helper.getWriteableClassID(LdapUser.class)); + Assert.assertEquals(Integer.valueOf(3), Base64Helper.getWriteableClassID(UserInjector.InjectedUser.class)); + Assert.assertEquals(Integer.valueOf(4), Base64Helper.getWriteableClassID(SourceFieldsContext.class)); + } + + @Test + public void testInjectedUser() { + UserInjector.InjectedUser injectedUser = new UserInjector.InjectedUser("username"); + + // we expect to get User object when deserializing InjectedUser via JDK serialization + User user = new User("username"); + User deserializedUser = (User) dsJDK(injectedUser); + Assert.assertEquals(user, deserializedUser); + Assert.assertTrue(deserializedUser.isInjected()); + + // for custom serialization, we expect InjectedUser to be returned on deserialization + UserInjector.InjectedUser deserializedInjecteduser = (UserInjector.InjectedUser) ds(injectedUser); + Assert.assertEquals(injectedUser, deserializedInjecteduser); + Assert.assertTrue(deserializedInjecteduser.isInjected()); + } + } diff --git a/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java b/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java new file mode 100644 index 0000000000..063a439143 --- /dev/null +++ b/src/test/java/org/opensearch/security/support/StreamableRegistryTest.java @@ -0,0 +1,15 @@ +package org.opensearch.security.support; + +import org.junit.Assert; +import org.junit.Test; + +import java.net.InetSocketAddress; + +public class StreamableRegistryTest { + + StreamableRegistry streamableRegistry = StreamableRegistry.getInstance(); + @Test + public void testStreamableTypeIDs() { + Assert.assertEquals(1, streamableRegistry.getStreamableID(InetSocketAddress.class)); + } +}