diff --git a/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchClient.java b/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchClient.java index 1b6cd16fcdca..676e1823c828 100644 --- a/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchClient.java +++ b/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchClient.java @@ -31,11 +31,13 @@ import io.prestosql.spi.type.RowType; import io.prestosql.spi.type.RowType.Field; import io.prestosql.spi.type.Type; +import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsRequest; import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse; import org.elasticsearch.action.admin.indices.mapping.get.GetMappingsRequest; import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.cluster.metadata.MappingMetaData; +import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.collect.ImmutableOpenMap; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; @@ -179,16 +181,25 @@ public ElasticsearchTableDescription getTable(String schemaName, String tableNam Optional.of(buildColumns(table))); } - public ClusterSearchShardsResponse getSearchShards(String index) + public List getSearchShards(String index) { try { - return retry() + ClusterSearchShardsResponse result = retry() .maxAttempts(maxAttempts) .exponentialBackoff(maxRetryTime) .run("getSearchShardsResponse", () -> client.admin() .cluster() .searchShards(new ClusterSearchShardsRequest(index)) .actionGet(requestTimeout.toMillis())); + + ImmutableList.Builder shards = ImmutableList.builder(); + DiscoveryNode[] nodes = result.getNodes(); + for (ClusterSearchShardsGroup group : result.getGroups()) { + int nodeIndex = group.getShardId().getId() % nodes.length; + shards.add(new Shard(group.getShardId().getId(), nodes[nodeIndex].getHostName(), nodes[nodeIndex].getAddress().getPort())); + } + + return shards.build(); } catch (Exception e) { throw new RuntimeException(e); diff --git a/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchSplitManager.java b/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchSplitManager.java index 811407be7005..4a1a417e6d83 100644 --- a/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchSplitManager.java +++ b/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/ElasticsearchSplitManager.java @@ -13,21 +13,17 @@ */ package io.prestosql.elasticsearch; -import com.google.common.collect.ImmutableList; import io.prestosql.spi.connector.ConnectorSession; -import io.prestosql.spi.connector.ConnectorSplit; import io.prestosql.spi.connector.ConnectorSplitManager; import io.prestosql.spi.connector.ConnectorSplitSource; import io.prestosql.spi.connector.ConnectorTableHandle; import io.prestosql.spi.connector.ConnectorTransactionHandle; import io.prestosql.spi.connector.FixedSplitSource; -import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsGroup; -import org.elasticsearch.action.admin.cluster.shards.ClusterSearchShardsResponse; -import org.elasticsearch.cluster.node.DiscoveryNode; import javax.inject.Inject; import static com.google.common.base.Verify.verifyNotNull; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.Objects.requireNonNull; public class ElasticsearchSplitManager @@ -48,20 +44,15 @@ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHand ElasticsearchTableDescription tableDescription = client.getTable(tableHandle.getSchemaName(), tableHandle.getTableName()); verifyNotNull(table, "Table no longer exists: %s", tableHandle.toString()); - ImmutableList.Builder splits = ImmutableList.builder(); - String index = tableDescription.getIndex(); - ClusterSearchShardsResponse response = client.getSearchShards(index); - DiscoveryNode[] nodes = response.getNodes(); - for (ClusterSearchShardsGroup group : response.getGroups()) { - int nodeIndex = group.getShardId().getId() % nodes.length; - ElasticsearchSplit split = new ElasticsearchSplit( - index, - tableDescription.getType(), - group.getShardId().getId(), - nodes[nodeIndex].getHostName(), - nodes[nodeIndex].getAddress().getPort()); - splits.add(split); - } - return new FixedSplitSource(splits.build()); + List splits = client.getSearchShards(tableDescription.getIndex()).stream() + .map(shard -> new ElasticsearchSplit( + tableDescription.getIndex(), + tableDescription.getType(), + shard.getId(), + shard.getHost(), + shard.getPort())) + .collect(toImmutableList()); + + return new FixedSplitSource(splits); } } diff --git a/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/Shard.java b/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/Shard.java new file mode 100644 index 000000000000..7b2fc5df75c7 --- /dev/null +++ b/presto-elasticsearch/src/main/java/io/prestosql/elasticsearch/Shard.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.elasticsearch; + +import static java.util.Objects.requireNonNull; + +public class Shard +{ + private final int id; + private final String host; + private final int port; + + public Shard(int id, String host, int port) + { + this.id = id; + this.host = requireNonNull(host, "host is null"); + this.port = port; + } + + public int getId() + { + return id; + } + + public String getHost() + { + return host; + } + + public int getPort() + { + return port; + } +}