diff --git a/core/trino-main/src/main/java/io/trino/metadata/InternalNode.java b/core/trino-main/src/main/java/io/trino/metadata/InternalNode.java
index 6ac20ed2ea37..054ebf6e0c15 100644
--- a/core/trino-main/src/main/java/io/trino/metadata/InternalNode.java
+++ b/core/trino-main/src/main/java/io/trino/metadata/InternalNode.java
@@ -13,6 +13,7 @@
  */
 package io.trino.metadata;
 
+import io.airlift.slice.XxHash64;
 import io.trino.client.NodeVersion;
 import io.trino.spi.HostAddress;
 import io.trino.spi.Node;
@@ -27,6 +28,7 @@
 import static com.google.common.base.Strings.emptyToNull;
 import static com.google.common.base.Strings.nullToEmpty;
 import static io.airlift.node.AddressToHostname.tryDecodeHostnameToAddress;
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.Objects.requireNonNull;
 
 /**
@@ -39,6 +41,7 @@ public class InternalNode
     private final URI internalUri;
     private final NodeVersion nodeVersion;
     private final boolean coordinator;
+    private final long longHashCode;
 
     public InternalNode(String nodeIdentifier, URI internalUri, NodeVersion nodeVersion, boolean coordinator)
     {
@@ -47,6 +50,11 @@ public InternalNode(String nodeIdentifier, URI internalUri, NodeVersion nodeVers
         this.internalUri = requireNonNull(internalUri, "internalUri is null");
         this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null");
         this.coordinator = coordinator;
+        this.longHashCode = new XxHash64(coordinator ? 1 : 0)
+                .update(nodeIdentifier.getBytes(UTF_8))
+                .update(internalUri.toString().getBytes(UTF_8))
+                .update(nodeVersion.getVersion().getBytes(UTF_8))
+                .hash();
     }
 
     @Override
@@ -115,10 +123,15 @@ public boolean equals(Object obj)
                 Objects.equals(nodeVersion, o.nodeVersion);
     }
 
+    public long longHashCode()
+    {
+        return longHashCode;
+    }
+
     @Override
     public int hashCode()
     {
-        return Objects.hash(nodeIdentifier, internalUri, nodeVersion, coordinator);
+        return (int) longHashCode;
     }
 
     @Override
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java
index 64e8b53d3909..c668e2bcee75 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java
@@ -17,6 +17,7 @@
 import com.google.common.collect.HashBiMap;
 import com.google.common.collect.ImmutableList;
 import com.google.inject.Inject;
+import io.airlift.slice.XxHash64;
 import io.trino.Session;
 import io.trino.connector.CatalogServiceProvider;
 import io.trino.execution.scheduler.BucketNodeMap;
@@ -37,17 +38,14 @@
 import io.trino.split.EmptySplit;
 import io.trino.sql.planner.SystemPartitioningHandle.SystemPartitioning;
 
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.ToIntFunction;
 import java.util.stream.IntStream;
-import java.util.stream.Stream;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Verify.verify;
@@ -192,7 +190,7 @@ private NodePartitionMap getNodePartitioningMap(
             CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle);
             bucketToNode = bucketToNodeCache.computeIfAbsent(
                     connectorBucketNodeMap.getBucketCount(),
-                    bucketCount -> createArbitraryBucketToNode(getAllNodes(session, catalogHandle), bucketCount));
+                    bucketCount -> createArbitraryBucketToNode(connectorBucketNodeMap.getCacheKeyHint(), getAllNodes(session, catalogHandle), bucketCount));
         }
     }
 
@@ -250,8 +248,9 @@ public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partit
             return new BucketNodeMap(splitToBucket, getFixedMapping(bucketNodeMap.get()));
         }
 
+        long seed = bucketNodeMap.map(ConnectorBucketNodeMap::getCacheKeyHint).orElse(ThreadLocalRandom.current().nextLong());
         List<InternalNode> nodes = getAllNodes(session, requiredCatalogHandle(partitioningHandle));
-        return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(nodes, bucketCount));
+        return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(seed, nodes, bucketCount));
     }
 
     /**
@@ -350,17 +349,31 @@ private static CatalogHandle requiredCatalogHandle(PartitioningHandle partitioni
                 new IllegalStateException("No catalog handle for partitioning handle: " + partitioningHandle));
     }
 
-    private static List<InternalNode> createArbitraryBucketToNode(List<InternalNode> nodes, int bucketCount)
+    private static List<InternalNode> createArbitraryBucketToNode(long seed, List<InternalNode> nodes, int bucketCount)
     {
-        return cyclingShuffledStream(nodes)
-                .limit(bucketCount)
-                .collect(toImmutableList());
-    }
+        requireNonNull(nodes, "nodes is null");
+        checkArgument(!nodes.isEmpty(), "nodes is empty");
+        checkArgument(bucketCount > 0, "bucketCount must be greater than zero");
+
+        // Assign each bucket to the machine with the highest weight (hash)
+        // This is simple Rendezvous Hashing (Highest Random Weight) algorithm
+        ImmutableList.Builder<InternalNode> bucketAssignments = ImmutableList.builderWithExpectedSize(bucketCount);
+        for (int bucket = 0; bucket < bucketCount; bucket++) {
+            long bucketHash = XxHash64.hash(seed, bucket);
+
+            InternalNode bestNode = null;
+            long highestWeight = Long.MIN_VALUE;
+            for (InternalNode node : nodes) {
+                long weight = XxHash64.hash(node.longHashCode(), bucketHash);
+                if (weight >= highestWeight) {
+                    highestWeight = weight;
+                    bestNode = node;
+                }
+            }
 
-    private static <T> Stream<T> cyclingShuffledStream(Collection<T> collection)
-    {
-        List<T> list = new ArrayList<>(collection);
-        Collections.shuffle(list);
-        return Stream.generate(() -> list).flatMap(List::stream);
+            bucketAssignments.add(requireNonNull(bestNode));
+        }
+
+        return bucketAssignments.build();
     }
 }
diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorBucketNodeMap.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorBucketNodeMap.java
index b8228f732f49..30a9cf0c0e75 100644
--- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorBucketNodeMap.java
+++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorBucketNodeMap.java
@@ -17,6 +17,7 @@
 
 import java.util.List;
 import java.util.Optional;
+import java.util.concurrent.ThreadLocalRandom;
 
 import static java.lang.String.format;
 
@@ -24,18 +25,19 @@ public final class ConnectorBucketNodeMap
 {
     private final int bucketCount;
     private final Optional<List<Node>> bucketToNode;
+    private final long cacheKeyHint;
 
     public static ConnectorBucketNodeMap createBucketNodeMap(int bucketCount)
     {
-        return new ConnectorBucketNodeMap(bucketCount, Optional.empty());
+        return new ConnectorBucketNodeMap(bucketCount, Optional.empty(), ThreadLocalRandom.current().nextLong());
     }
 
     public static ConnectorBucketNodeMap createBucketNodeMap(List<Node> bucketToNode)
     {
-        return new ConnectorBucketNodeMap(bucketToNode.size(), Optional.of(bucketToNode));
+        return new ConnectorBucketNodeMap(bucketToNode.size(), Optional.of(bucketToNode), ThreadLocalRandom.current().nextLong());
    }
 
-    private ConnectorBucketNodeMap(int bucketCount, Optional<List<Node>> bucketToNode)
+    private ConnectorBucketNodeMap(int bucketCount, Optional<List<Node>> bucketToNode, long cacheKeyHint)
     {
         if (bucketCount <= 0) {
             throw new IllegalArgumentException("bucketCount must be positive");
         }
@@ -45,6 +47,7 @@ private ConnectorBucketNodeMap(int bucketCount, Optional<List<Node>> bucketToNod
         }
         this.bucketCount = bucketCount;
         this.bucketToNode = bucketToNode.map(List::copyOf);
+        this.cacheKeyHint = cacheKeyHint;
     }
 
     public int getBucketCount()
     {
@@ -61,4 +64,14 @@ public List<Node> getFixedMapping()
     {
         return bucketToNode.orElseThrow(() -> new IllegalArgumentException("No fixed bucket to node mapping"));
     }
+
+    public long getCacheKeyHint()
+    {
+        return cacheKeyHint;
+    }
+
+    public ConnectorBucketNodeMap withCacheKeyHint(long cacheKeyHint)
+    {
+        return new ConnectorBucketNodeMap(bucketCount, bucketToNode, cacheKeyHint);
+    }
 }
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveNodePartitioningProvider.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveNodePartitioningProvider.java
index e0b8bb734cf3..5406e73ef6b0 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveNodePartitioningProvider.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveNodePartitioningProvider.java
@@ -73,7 +73,7 @@ public Optional<ConnectorBucketNodeMap> getBucketNodeMapping(ConnectorTransactio
     {
         HivePartitioningHandle handle = (HivePartitioningHandle) partitioningHandle;
         if (!handle.isUsePartitionedBucketing()) {
-            return Optional.of(createBucketNodeMap(handle.getBucketCount()));
+            return Optional.of(createBucketNodeMap(handle.getBucketCount()).withCacheKeyHint(handle.getCacheKeyHint()));
         }
         return Optional.empty();
     }
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePartitioningHandle.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePartitioningHandle.java
index 0f241e652960..1de57aa619b4 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePartitioningHandle.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HivePartitioningHandle.java
@@ -15,6 +15,8 @@
 
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.hash.Hasher;
+import com.google.common.hash.Hashing;
 import io.trino.metastore.HiveType;
 import io.trino.plugin.hive.util.HiveBucketing.BucketingVersion;
 import io.trino.spi.connector.ConnectorPartitioningHandle;
@@ -72,6 +74,17 @@ public boolean isUsePartitionedBucketing()
         return usePartitionedBucketing;
     }
 
+    public long getCacheKeyHint()
+    {
+        Hasher hasher = Hashing.goodFastHash(64).newHasher();
+        hasher.putInt(bucketingVersion.getVersion());
+        hasher.putInt(bucketCount);
+        for (HiveType hiveType : hiveTypes) {
+            hasher.putString(hiveType.toString(), java.nio.charset.StandardCharsets.UTF_8);
+        }
+        return hasher.hash().asLong();
+    }
+
     @Override
     public String toString()
     {
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergNodePartitioningProvider.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergNodePartitioningProvider.java
index c7ddabff0e2c..eb526d2bf3ad 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergNodePartitioningProvider.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergNodePartitioningProvider.java
@@ -55,7 +55,7 @@ public Optional<ConnectorBucketNodeMap> getBucketNodeMapping(
         // when there is a single bucket partition function, inform the engine there is a limit on the number of buckets
         // TODO: when there are multiple bucket partition functions, we could compute the product of bucket counts, but this causes the engine to create too many writers
         if (partitionFunctions.size() == 1 && partitionFunctions.getFirst().transform() == BUCKET) {
-            return Optional.of(createBucketNodeMap(partitionFunctions.getFirst().size().orElseThrow()));
+            return Optional.of(createBucketNodeMap(partitionFunctions.getFirst().size().orElseThrow()).withCacheKeyHint(handle.getCacheKeyHint()));
         }
         return Optional.empty();
     }
diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPartitioningHandle.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPartitioningHandle.java
index 4bf7bce0d9f3..b7b486518602 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPartitioningHandle.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergPartitioningHandle.java
@@ -14,6 +14,8 @@
 package io.trino.plugin.iceberg;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.hash.Hasher;
+import com.google.common.hash.Hashing;
 import io.trino.spi.connector.ConnectorPartitioningHandle;
 import io.trino.spi.type.TypeManager;
 import org.apache.iceberg.PartitionField;
@@ -29,6 +31,7 @@
 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static com.google.common.collect.ImmutableSet.toImmutableSet;
 import static io.trino.plugin.iceberg.TypeConverter.toTrinoType;
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static java.util.Objects.requireNonNull;
 
 public record IcebergPartitioningHandle(boolean update, List<IcebergPartitionFunction> partitionFunctions)
@@ -100,4 +103,17 @@ private static boolean buildDataPaths(Set<Integer> partitionFieldIds, Types.Stru
         }
         return hasPartitionFields;
     }
+
+    public long getCacheKeyHint()
+    {
+        Hasher hasher = Hashing.goodFastHash(64).newHasher();
+        hasher.putBoolean(update);
+        for (IcebergPartitionFunction function : partitionFunctions) {
+            hasher.putInt(function.transform().ordinal());
+            function.dataPath().forEach(hasher::putInt);
+            hasher.putString(function.type().getTypeSignature().toString(), UTF_8);
+            function.size().ifPresent(hasher::putInt);
+        }
+        return hasher.hash().asLong();
+    }
 }
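For reference, the core idea behind the new createArbitraryBucketToNode above is rendezvous (highest-random-weight) hashing: each bucket is assigned to whichever node wins the highest hash weight for that bucket, so a fixed seed and node set always yield the same assignment, and removing a node only moves the buckets that node owned. The sketch below is a standalone illustration, not Trino code; the class and method names are made up, and a splitmix64-style mixer stands in for airlift's XxHash64 used in the patch.

import java.util.Arrays;
import java.util.List;

// Standalone sketch (not Trino code): rendezvous / highest-random-weight hashing.
// All names here are illustrative; mix() is a splitmix64-style stand-in for the
// XxHash64.hash(seed, value) calls in the actual change.
public final class RendezvousSketch
{
    private RendezvousSketch() {}

    // Mix a seed and a value into a pseudo-random 64-bit weight
    private static long mix(long seed, long value)
    {
        long z = seed ^ (value * 0x9E3779B97F4A7C15L);
        z = (z ^ (z >>> 30)) * 0xBF58476D1CE4E5B9L;
        z = (z ^ (z >>> 27)) * 0x94D049BB133111EBL;
        return z ^ (z >>> 31);
    }

    // Assign each bucket to the node with the highest weight for that bucket.
    // A fixed seed and node list always produce the same assignment, and dropping
    // one node only reassigns the buckets that node owned.
    public static List<String> assignBuckets(long seed, List<String> nodes, int bucketCount)
    {
        String[] assignments = new String[bucketCount];
        for (int bucket = 0; bucket < bucketCount; bucket++) {
            long bucketHash = mix(seed, bucket);
            String bestNode = null;
            long highestWeight = Long.MIN_VALUE;
            for (String node : nodes) {
                long weight = mix(node.hashCode(), bucketHash);
                if (weight >= highestWeight) {
                    highestWeight = weight;
                    bestNode = node;
                }
            }
            assignments[bucket] = bestNode;
        }
        return Arrays.asList(assignments);
    }

    public static void main(String[] args)
    {
        List<String> nodes = List.of("node-1", "node-2", "node-3");
        // Same seed and node set => identical assignment on every call
        System.out.println(assignBuckets(42, nodes, 8).equals(assignBuckets(42, nodes, 8)));
    }
}

In the patch itself the seed comes from ConnectorBucketNodeMap.getCacheKeyHint(); Hive and Iceberg derive that hint deterministically from the bucketing version, bucket count, and column types, so equivalent partitioning handles can produce the same bucket-to-node mapping across queries (given the same set of nodes) instead of a fresh shuffle each time.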