diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index dc7c6f0b0d..aec2d1b0b6 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -10,6 +10,7 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; +import org.opensearch.common.util.CollectionUtils; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; @@ -184,7 +185,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (actions != null) { builder.field(CONNECTOR_ACTIONS_FIELD, actions); } - if (backendRoles != null) { + if (!CollectionUtils.isEmpty(backendRoles)) { builder.field(BACKEND_ROLES_FIELD, backendRoles); } if (addAllBackendRoles != null) { @@ -224,7 +225,7 @@ public void writeTo(StreamOutput output) throws IOException { } else { output.writeBoolean(false); } - if (backendRoles != null) { + if (!CollectionUtils.isEmpty(backendRoles)) { output.writeBoolean(true); output.writeOptionalStringCollection(backendRoles); } else { @@ -267,4 +268,4 @@ public MLCreateConnectorInput(StreamInput input) throws IOException { } dryRun = input.readBoolean(); } -} \ No newline at end of file +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java index 4b1617e7af..782bd9501e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java @@ -19,6 +19,7 @@ import org.apache.http.protocol.HttpContext; import org.apache.logging.log4j.util.Strings; +import java.net.Inet4Address; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.Arrays; @@ -42,10 +43,7 @@ public int resolve(HttpHost host) throws UnsupportedSchemeException { } }); - builder.setDnsResolver(hostName -> { - validateIp(hostName); - return InetAddress.getAllByName(hostName); - }); + builder.setDnsResolver(MLHttpClientFactory::validateIp); builder.setRedirectStrategy(new LaxRedirectStrategy() { @Override @@ -79,15 +77,51 @@ protected static void validateSchemaAndPort(HttpHost host) { } } - protected static void validateIp(String hostName) throws UnknownHostException { + protected static InetAddress[] validateIp(String hostName) throws UnknownHostException { InetAddress[] addresses = InetAddress.getAllByName(hostName); if (hasPrivateIpAddress(addresses)) { log.error("Remote inference host name has private ip address: " + hostName); throw new IllegalArgumentException(hostName); } + return addresses; } private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) { + for (InetAddress ip : ipAddress) { + if (ip instanceof Inet4Address) { + byte[] bytes = ip.getAddress(); + if (bytes.length != 4) { + return true; + } else { + int firstOctets = bytes[0] & 0xff; + int firstInOctal = parseWithOctal(String.valueOf(firstOctets)); + int firstInHex = Integer.parseInt(String.valueOf(firstOctets), 16); + if (firstInOctal == 127 || firstInHex == 127) { + return bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 1; + } else if (firstInOctal == 10 || firstInHex == 10) { + return true; + } else if (firstInOctal == 172 || firstInHex == 172) { + int secondOctets = bytes[1] & 0xff; + int secondInOctal = parseWithOctal(String.valueOf(secondOctets)); + int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16); + return (secondInOctal >= 16 && secondInOctal <= 32) || (secondInHex >= 16 && secondInHex <= 32); + } else if (firstInOctal == 192 || firstInHex == 192) { + int secondOctets = bytes[1] & 0xff; + int secondInOctal = parseWithOctal(String.valueOf(secondOctets)); + int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16); + return secondInOctal == 168 || secondInHex == 168; + } + } + } + } return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress()); } + + private static int parseWithOctal(String input) { + try { + return Integer.parseInt(input, 8); + } catch (NumberFormatException e) { + return Integer.parseInt(input); + } + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java index 5a1b94c06c..4fbf9888fd 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java @@ -11,9 +11,11 @@ import org.junit.Test; import org.junit.rules.ExpectedException; +import java.net.InetAddress; import java.net.UnknownHostException; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; public class MLHttpClientFactoryTests { @@ -43,6 +45,45 @@ public void test_validateIp_privateIp_throwException() throws UnknownHostExcepti MLHttpClientFactory.validateIp("localhost"); } + @Test + public void test_validateIp_rarePrivateIp_throwException() throws UnknownHostException { + try { + MLHttpClientFactory.validateIp("0254.020.00.01"); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validateIp("172.1048577"); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validateIp("2886729729"); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validateIp("192.11010049"); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validateIp("3232300545"); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + + try { + MLHttpClientFactory.validateIp("0:0:0:0:0:ffff:127.0.0.1"); + } catch (IllegalArgumentException e) { + assertNotNull(e); + } + } + @Test public void test_validateSchemaAndPort_success() { HttpHost httpHost = new HttpHost("api.openai.com", 8080, "https"); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java index a0bf60788e..0408fafe96 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/TransportCreateConnectorAction.java @@ -193,7 +193,7 @@ private void validateRequest4AccessControl(MLCreateConnectorInput input, User us } private void validateSecurityDisabledOrConnectorAccessControlDisabled(MLCreateConnectorInput input) { - if (input.getAccess() != null || input.getAddAllBackendRoles() != null || input.getBackendRoles() != null) { + if (input.getAccess() != null || input.getAddAllBackendRoles() != null || !CollectionUtils.isEmpty(input.getBackendRoles())) { throw new IllegalArgumentException( "You cannot specify connector access control parameters because the Security plugin or connector access control is disabled on your cluster." );