Skip to content

Commit

Permalink
Use stable node to bucket mapping for Hive and Iceberg
Browse files Browse the repository at this point in the history
This improves cache hit rate for file system caching
  • Loading branch information
dain committed Dec 5, 2024
1 parent 0347376 commit 3b98c61
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/
package io.trino.metadata;

import io.airlift.slice.XxHash64;
import io.trino.client.NodeVersion;
import io.trino.spi.HostAddress;
import io.trino.spi.Node;
Expand All @@ -27,6 +28,7 @@
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.base.Strings.nullToEmpty;
import static io.airlift.node.AddressToHostname.tryDecodeHostnameToAddress;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

/**
Expand All @@ -39,6 +41,7 @@ public class InternalNode
private final URI internalUri;
private final NodeVersion nodeVersion;
private final boolean coordinator;
private final long longHashCode;

public InternalNode(String nodeIdentifier, URI internalUri, NodeVersion nodeVersion, boolean coordinator)
{
Expand All @@ -47,6 +50,11 @@ public InternalNode(String nodeIdentifier, URI internalUri, NodeVersion nodeVers
this.internalUri = requireNonNull(internalUri, "internalUri is null");
this.nodeVersion = requireNonNull(nodeVersion, "nodeVersion is null");
this.coordinator = coordinator;
this.longHashCode = new XxHash64(coordinator ? 1 : 0)
.update(nodeIdentifier.getBytes(UTF_8))
.update(internalUri.toString().getBytes(UTF_8))
.update(nodeVersion.getVersion().getBytes(UTF_8))
.hash();
}

@Override
Expand Down Expand Up @@ -115,10 +123,15 @@ public boolean equals(Object obj)
Objects.equals(nodeVersion, o.nodeVersion);
}

public long longHashCode()
{
return longHashCode;
}

@Override
public int hashCode()
{
return Objects.hash(nodeIdentifier, internalUri, nodeVersion, coordinator);
return (int) longHashCode;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableList;
import com.google.inject.Inject;
import io.airlift.slice.XxHash64;
import io.trino.Session;
import io.trino.connector.CatalogServiceProvider;
import io.trino.execution.scheduler.BucketNodeMap;
Expand All @@ -37,17 +38,14 @@
import io.trino.split.EmptySplit;
import io.trino.sql.planner.SystemPartitioningHandle.SystemPartitioning;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.ToIntFunction;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -192,7 +190,7 @@ private NodePartitionMap getNodePartitioningMap(
CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle);
bucketToNode = bucketToNodeCache.computeIfAbsent(
connectorBucketNodeMap.getBucketCount(),
bucketCount -> createArbitraryBucketToNode(getAllNodes(session, catalogHandle), bucketCount));
bucketCount -> createArbitraryBucketToNode(connectorBucketNodeMap.getCacheKeyHint(), getAllNodes(session, catalogHandle), bucketCount));
}
}

Expand Down Expand Up @@ -250,8 +248,9 @@ public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partit
return new BucketNodeMap(splitToBucket, getFixedMapping(bucketNodeMap.get()));
}

long seed = bucketNodeMap.map(ConnectorBucketNodeMap::getCacheKeyHint).orElse(ThreadLocalRandom.current().nextLong());
List<InternalNode> nodes = getAllNodes(session, requiredCatalogHandle(partitioningHandle));
return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(nodes, bucketCount));
return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(seed, nodes, bucketCount));
}

/**
Expand Down Expand Up @@ -350,17 +349,31 @@ private static CatalogHandle requiredCatalogHandle(PartitioningHandle partitioni
new IllegalStateException("No catalog handle for partitioning handle: " + partitioningHandle));
}

private static List<InternalNode> createArbitraryBucketToNode(List<InternalNode> nodes, int bucketCount)
private static List<InternalNode> createArbitraryBucketToNode(long seed, List<InternalNode> nodes, int bucketCount)
{
return cyclingShuffledStream(nodes)
.limit(bucketCount)
.collect(toImmutableList());
}
requireNonNull(nodes, "nodes is null");
checkArgument(!nodes.isEmpty(), "nodes is empty");
checkArgument(bucketCount > 0, "bucketCount must be greater than zero");

// Assign each bucket to the machine with the highest weight (hash)
// This is simple Rendezvous Hashing (Highest Random Weight) algorithm
ImmutableList.Builder<InternalNode> bucketAssignments = ImmutableList.builderWithExpectedSize(bucketCount);
for (int bucket = 0; bucket < bucketCount; bucket++) {
long bucketHash = XxHash64.hash(seed, bucket);

InternalNode bestNode = null;
long highestWeight = Long.MIN_VALUE;
for (InternalNode node : nodes) {
long weight = XxHash64.hash(node.longHashCode(), bucketHash);
if (weight >= highestWeight) {
highestWeight = weight;
bestNode = node;
}
}

private static <T> Stream<T> cyclingShuffledStream(Collection<T> collection)
{
List<T> list = new ArrayList<>(collection);
Collections.shuffle(list);
return Stream.generate(() -> list).flatMap(List::stream);
bucketAssignments.add(requireNonNull(bestNode));
}

return bucketAssignments.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,27 @@

import java.util.List;
import java.util.Optional;
import java.util.concurrent.ThreadLocalRandom;

import static java.lang.String.format;

public final class ConnectorBucketNodeMap
{
private final int bucketCount;
private final Optional<List<Node>> bucketToNode;
private final long cacheKeyHint;

public static ConnectorBucketNodeMap createBucketNodeMap(int bucketCount)
{
return new ConnectorBucketNodeMap(bucketCount, Optional.empty());
return new ConnectorBucketNodeMap(bucketCount, Optional.empty(), ThreadLocalRandom.current().nextLong());
}

public static ConnectorBucketNodeMap createBucketNodeMap(List<Node> bucketToNode)
{
return new ConnectorBucketNodeMap(bucketToNode.size(), Optional.of(bucketToNode));
return new ConnectorBucketNodeMap(bucketToNode.size(), Optional.of(bucketToNode), ThreadLocalRandom.current().nextLong());
}

private ConnectorBucketNodeMap(int bucketCount, Optional<List<Node>> bucketToNode)
private ConnectorBucketNodeMap(int bucketCount, Optional<List<Node>> bucketToNode, long cacheKeyHint)
{
if (bucketCount <= 0) {
throw new IllegalArgumentException("bucketCount must be positive");
Expand All @@ -45,6 +47,7 @@ private ConnectorBucketNodeMap(int bucketCount, Optional<List<Node>> bucketToNod
}
this.bucketCount = bucketCount;
this.bucketToNode = bucketToNode.map(List::copyOf);
this.cacheKeyHint = cacheKeyHint;
}

public int getBucketCount()
Expand All @@ -61,4 +64,14 @@ public List<Node> getFixedMapping()
{
return bucketToNode.orElseThrow(() -> new IllegalArgumentException("No fixed bucket to node mapping"));
}

public long getCacheKeyHint()
{
return cacheKeyHint;
}

public ConnectorBucketNodeMap withCacheKeyHint(long cacheKeyHint)
{
return new ConnectorBucketNodeMap(bucketCount, bucketToNode, cacheKeyHint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public Optional<ConnectorBucketNodeMap> getBucketNodeMapping(ConnectorTransactio
{
HivePartitioningHandle handle = (HivePartitioningHandle) partitioningHandle;
if (!handle.isUsePartitionedBucketing()) {
return Optional.of(createBucketNodeMap(handle.getBucketCount()));
return Optional.of(createBucketNodeMap(handle.getBucketCount()).withCacheKeyHint(handle.getCacheKeyHint()));
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;
import io.trino.metastore.HiveType;
import io.trino.plugin.hive.util.HiveBucketing.BucketingVersion;
import io.trino.spi.connector.ConnectorPartitioningHandle;
Expand Down Expand Up @@ -72,6 +74,17 @@ public boolean isUsePartitionedBucketing()
return usePartitionedBucketing;
}

public long getCacheKeyHint()
{
Hasher hasher = Hashing.goodFastHash(64).newHasher();
hasher.putInt(bucketingVersion.getVersion());
hasher.putInt(bucketCount);
for (HiveType hiveType : hiveTypes) {
hasher.putString(hiveType.toString(), java.nio.charset.StandardCharsets.UTF_8);
}
return hasher.hash().asLong();
}

@Override
public String toString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public Optional<ConnectorBucketNodeMap> getBucketNodeMapping(
// when there is a single bucket partition function, inform the engine there is a limit on the number of buckets
// TODO: when there are multiple bucket partition functions, we could compute the product of bucket counts, but this causes the engine to create too many writers
if (partitionFunctions.size() == 1 && partitionFunctions.getFirst().transform() == BUCKET) {
return Optional.of(createBucketNodeMap(partitionFunctions.getFirst().size().orElseThrow()));
return Optional.of(createBucketNodeMap(partitionFunctions.getFirst().size().orElseThrow()).withCacheKeyHint(handle.getCacheKeyHint()));
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
package io.trino.plugin.iceberg;

import com.google.common.collect.ImmutableList;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;
import io.trino.spi.connector.ConnectorPartitioningHandle;
import io.trino.spi.type.TypeManager;
import org.apache.iceberg.PartitionField;
Expand All @@ -29,6 +31,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static io.trino.plugin.iceberg.TypeConverter.toTrinoType;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

public record IcebergPartitioningHandle(boolean update, List<IcebergPartitionFunction> partitionFunctions)
Expand Down Expand Up @@ -100,4 +103,17 @@ private static boolean buildDataPaths(Set<Integer> partitionFieldIds, Types.Stru
}
return hasPartitionFields;
}

public long getCacheKeyHint()
{
Hasher hasher = Hashing.goodFastHash(64).newHasher();
hasher.putBoolean(update);
for (IcebergPartitionFunction function : partitionFunctions) {
hasher.putInt(function.transform().ordinal());
function.dataPath().forEach(hasher::putInt);
hasher.putString(function.type().getTypeSignature().toString(), UTF_8);
function.size().ifPresent(hasher::putInt);
}
return hasher.hash().asLong();
}
}

0 comments on commit 3b98c61

Please sign in to comment.