Skip to content

Commit

Permalink
Use bulk merging of metrics in PlanNodeStatsSummarizer
Browse files Browse the repository at this point in the history
Reduces memory allocations from merging of TDigestHistogram
  • Loading branch information
raunaqmorarka committed Oct 16, 2023
1 parent 91ee252 commit 382ec92
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import io.trino.spi.metrics.Metrics;

import java.util.List;

import static java.util.Objects.requireNonNull;

class BasicOperatorStats
Expand Down Expand Up @@ -73,4 +75,21 @@ public static BasicOperatorStats merge(BasicOperatorStats first, BasicOperatorSt
first.metrics.mergeWith(second.metrics),
first.connectorMetrics.mergeWith(second.connectorMetrics));
}

public static BasicOperatorStats merge(List<BasicOperatorStats> operatorStats)
{
long totalDrivers = 0;
long inputPositions = 0;
double sumSquaredInputPositions = 0;
Metrics.Accumulator metricsAccumulator = Metrics.accumulator();
Metrics.Accumulator connectorMetricsAccumulator = Metrics.accumulator();
for (BasicOperatorStats stats : operatorStats) {
totalDrivers += stats.totalDrivers;
inputPositions += stats.inputPositions;
sumSquaredInputPositions += stats.sumSquaredInputPositions;
metricsAccumulator.add(stats.metrics);
connectorMetricsAccumulator.add(stats.connectorMetrics);
}
return new BasicOperatorStats(totalDrivers, inputPositions, sumSquaredInputPositions, metricsAccumulator.get(), connectorMetricsAccumulator.get());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
*/
package io.trino.sql.planner.planprinter;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import io.airlift.units.DataSize;
import io.airlift.units.Duration;
import io.trino.spi.Mergeable;
import io.trino.sql.planner.plan.PlanNodeId;

import java.util.List;
import java.util.Map;
import java.util.Set;

Expand All @@ -30,6 +33,7 @@
import static java.lang.Math.sqrt;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.NANOSECONDS;

public class PlanNodeStats
implements Mergeable<PlanNodeStats>
Expand Down Expand Up @@ -194,4 +198,59 @@ public PlanNodeStats mergeWith(PlanNodeStats other)
succinctBytes(this.planNodeSpilledDataSize.toBytes() + other.planNodeSpilledDataSize.toBytes()),
operatorStats);
}

@Override
public PlanNodeStats mergeWith(List<PlanNodeStats> others)
{
long planNodeInputPositions = this.planNodeInputPositions;
long planNodeOutputPositions = this.planNodeOutputPositions;
long planNodeInputDataSizeBytes = planNodeInputDataSize.toBytes();
long planNodeOutputDataSizeBytes = planNodeOutputDataSize.toBytes();
long planNodePhysicalInputDataSizeBytes = planNodePhysicalInputDataSize.toBytes();
long planNodeSpilledDataSizeBytes = planNodeSpilledDataSize.toBytes();
long planNodeScheduledTimeMillis = planNodeScheduledTime.toMillis();
long planNodeCpuTimeMillis = planNodeCpuTime.toMillis();
long planNodeBlockedTimeMillis = planNodeBlockedTime.toMillis();
double planNodePhysicalInputReadNanos = planNodePhysicalInputReadTime.getValue(NANOSECONDS);
ListMultimap<String, BasicOperatorStats> groupedOperatorStats = ArrayListMultimap.create();
for (Map.Entry<String, BasicOperatorStats> entry : this.operatorStats.entrySet()) {
groupedOperatorStats.put(entry.getKey(), entry.getValue());
}

for (PlanNodeStats other : others) {
checkArgument(planNodeId.equals(other.getPlanNodeId()), "planNodeIds do not match. %s != %s", planNodeId, other.getPlanNodeId());
planNodeInputPositions += other.planNodeInputPositions;
planNodeOutputPositions += other.planNodeOutputPositions;
planNodeScheduledTimeMillis += other.planNodeScheduledTime.toMillis();
planNodeCpuTimeMillis += other.planNodeCpuTime.toMillis();
planNodeBlockedTimeMillis += other.planNodeBlockedTime.toMillis();
planNodePhysicalInputReadNanos += other.planNodePhysicalInputReadTime.getValue(NANOSECONDS);
planNodePhysicalInputDataSizeBytes += other.planNodePhysicalInputDataSize.toBytes();
planNodeInputDataSizeBytes += other.planNodeInputDataSize.toBytes();
planNodeOutputDataSizeBytes += other.planNodeOutputDataSize.toBytes();
planNodeSpilledDataSizeBytes += other.planNodeSpilledDataSize.toBytes();
for (Map.Entry<String, BasicOperatorStats> entry : other.operatorStats.entrySet()) {
groupedOperatorStats.put(entry.getKey(), entry.getValue());
}
}

ImmutableMap.Builder<String, BasicOperatorStats> mergedOperatorStatsBuilder = ImmutableMap.builder();
for (String key : groupedOperatorStats.keySet()) {
mergedOperatorStatsBuilder.put(key, BasicOperatorStats.merge(groupedOperatorStats.get(key)));
}

return new PlanNodeStats(
planNodeId,
new Duration(planNodeScheduledTimeMillis, MILLISECONDS),
new Duration(planNodeCpuTimeMillis, MILLISECONDS),
new Duration(planNodeBlockedTimeMillis, MILLISECONDS),
planNodeInputPositions,
succinctBytes(planNodeInputDataSizeBytes),
succinctBytes(planNodePhysicalInputDataSizeBytes),
new Duration(planNodePhysicalInputReadNanos, NANOSECONDS),
planNodeOutputPositions,
succinctBytes(planNodeOutputDataSizeBytes),
succinctBytes(planNodeSpilledDataSizeBytes),
mergedOperatorStatsBuilder.buildOrThrow());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
*/
package io.trino.sql.planner.planprinter;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ListMultimap;
import io.airlift.units.Duration;
import io.trino.execution.StageInfo;
import io.trino.execution.TaskInfo;
Expand Down Expand Up @@ -51,15 +53,21 @@ public static Map<PlanNodeId, PlanNodeStats> aggregateStageStats(List<StageInfo>

public static Map<PlanNodeId, PlanNodeStats> aggregateTaskStats(List<TaskInfo> taskInfos)
{
Map<PlanNodeId, PlanNodeStats> aggregatedStats = new HashMap<>();
ListMultimap<PlanNodeId, PlanNodeStats> groupedStats = ArrayListMultimap.create();
List<PlanNodeStats> planNodeStats = taskInfos.stream()
.map(TaskInfo::getStats)
.flatMap(taskStats -> getPlanNodeStats(taskStats).stream())
.collect(toList());
for (PlanNodeStats stats : planNodeStats) {
aggregatedStats.merge(stats.getPlanNodeId(), stats, PlanNodeStats::mergeWith);
groupedStats.put(stats.getPlanNodeId(), stats);
}
return aggregatedStats;

ImmutableMap.Builder<PlanNodeId, PlanNodeStats> aggregatedStatsBuilder = ImmutableMap.builder();
for (PlanNodeId planNodeId : groupedStats.keySet()) {
List<PlanNodeStats> groupedPlanNodeStats = groupedStats.get(planNodeId);
aggregatedStatsBuilder.put(planNodeId, groupedPlanNodeStats.get(0).mergeWith(groupedPlanNodeStats.subList(1, groupedPlanNodeStats.size())));
}
return aggregatedStatsBuilder.buildOrThrow();
}

private static List<PlanNodeStats> getPlanNodeStats(TaskStats taskStats)
Expand Down

0 comments on commit 382ec92

Please sign in to comment.