From c3b913eb68544c6638b1370378f4ba67063379f4 Mon Sep 17 00:00:00 2001
From: Andrei Dan <andrei.dan@elastic.co>
Date: Tue, 26 Mar 2024 18:04:35 +0000
Subject: [PATCH] Auto sharding uses the sum of shards write loads

Data stream auto sharding uses the index write load to decide the
optimal number of shards. We previously read this from the indexing
stats output, using the `total/write_load` value; however, this
proved to be wrong, as that value takes into account the search shard
write load (which will always be 0).
Moreover, the `total/write_load` value averages the write loads of
every shard, so you can end up with an index that has only one primary
and one replica, where the primary shard has a write load of 1.7 while
`total/write_load` is reported as `0.8`.

For data stream auto sharding we're interested in the **total** index
write load, defined as the sum of all the shards' write loads (yes, we
can include the replica shards' write loads in this sum, as they're 0).

This PR changes the rollover write load computation to sum all the shard
write loads of the data stream write index. Likewise, when
`DataStreamAutoShardingService` looks at the historic write load over
the cooldown period, it now sums the write loads of every shard in the
index metadata/stats.
---
 .../rollover/TransportRolloverAction.java     | 13 +++++++-----
 .../DataStreamAutoShardingService.java        | 21 ++-----------------
 .../DataStreamAutoShardingServiceTests.java   |  9 ++++----
 3 files changed, 14 insertions(+), 29 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java
index 774bfae53fb94..e678efcbca8cb 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java
@@ -250,11 +250,14 @@ protected void masterOperation(
                     final Optional<IndexStats> indexStats = Optional.ofNullable(statsResponse)
                         .map(stats -> stats.getIndex(dataStream.getWriteIndex().getName()));
 
-                    Double writeLoad = indexStats.map(stats -> stats.getTotal().getIndexing())
-                        .map(indexing -> indexing.getTotal().getWriteLoad())
-                        .orElse(null);
-
-                    rolloverAutoSharding = dataStreamAutoShardingService.calculate(clusterState, dataStream, writeLoad);
+                    Double indexWriteLoad = indexStats.map(
+                        stats -> Arrays.stream(stats.getShards())
+                            .filter(shardStats -> shardStats.getStats().indexing != null)
+                            .map(shardStats -> shardStats.getStats().indexing.getTotal().getWriteLoad())
+                            .reduce(0.0, Double::sum)
+                    ).orElse(null);
+
+                    rolloverAutoSharding = dataStreamAutoShardingService.calculate(clusterState, dataStream, indexWriteLoad);
                     logger.debug("auto sharding result for data stream [{}] is [{}]", dataStream.getName(), rolloverAutoSharding);
 
                     // if auto sharding recommends increasing the number of shards we want to trigger a rollover even if there are no
diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java b/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
index 06aec69bc97da..a045c73cc83a1 100644
--- a/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
+++ b/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
@@ -29,7 +29,6 @@
 import java.util.List;
 import java.util.Objects;
 import java.util.OptionalDouble;
-import java.util.OptionalLong;
 import java.util.function.Function;
 import java.util.function.LongSupplier;
 
@@ -381,27 +380,11 @@ static double getMaxIndexLoadWithinCoolingPeriod(
         // assume the current write index load is the highest observed and look back to find the actual maximum
         double maxIndexLoadWithinCoolingPeriod = writeIndexLoad;
         for (IndexWriteLoad writeLoad : writeLoadsWithinCoolingPeriod) {
-            // the IndexWriteLoad stores _for each shard_ a shard average write load ( calculated using : shard indexing time / shard
-            // uptime ) and its corresponding shard uptime
-            //
-            // to reconstruct the average _index_ write load we recalculate the shard indexing time by multiplying the shard write load
-            // to its uptime, and then, having the indexing time and uptime for each shard we calculate the average _index_ write load using
-            // (indexingTime_shard0 + indexingTime_shard1) / (uptime_shard0 + uptime_shard1)
-            // as {@link org.elasticsearch.index.shard.IndexingStats#add} does
-            double totalShardIndexingTime = 0;
-            long totalShardUptime = 0;
+            double totalIndexLoad = 0;
             for (int shardId = 0; shardId < writeLoad.numberOfShards(); shardId++) {
                 final OptionalDouble writeLoadForShard = writeLoad.getWriteLoadForShard(shardId);
-                final OptionalLong uptimeInMillisForShard = writeLoad.getUptimeInMillisForShard(shardId);
-                if (writeLoadForShard.isPresent()) {
-                    assert uptimeInMillisForShard.isPresent();
-                    double shardIndexingTime = writeLoadForShard.getAsDouble() * uptimeInMillisForShard.getAsLong();
-                    long shardUptimeInMillis = uptimeInMillisForShard.getAsLong();
-                    totalShardIndexingTime += shardIndexingTime;
-                    totalShardUptime += shardUptimeInMillis;
-                }
+                totalIndexLoad += writeLoadForShard.orElse(0);
             }
-            double totalIndexLoad = totalShardUptime == 0 ? 0.0 : (totalShardIndexingTime / totalShardUptime);
             if (totalIndexLoad > maxIndexLoadWithinCoolingPeriod) {
                 maxIndexLoadWithinCoolingPeriod = totalIndexLoad;
             }
diff --git a/server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java b/server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java
index bc1ec6788eec6..155cd82f63b98 100644
--- a/server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java
+++ b/server/src/test/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingServiceTests.java
@@ -51,9 +51,7 @@
 import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.INCREASE_SHARDS;
 import static org.elasticsearch.action.datastreams.autosharding.AutoShardingType.NO_CHANGE_REQUIRED;
 import static org.elasticsearch.test.ClusterServiceUtils.createClusterService;
-import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
-import static org.hamcrest.Matchers.lessThan;
 
 public class DataStreamAutoShardingServiceTests extends ESTestCase {
 
@@ -649,7 +647,7 @@ public void testGetMaxIndexLoadWithinCoolingPeriod() {
         assertThat(maxIndexLoadWithinCoolingPeriod, is(lastIndexBeforeCoolingPeriodHasLowWriteLoad ? 5.0 : 999.0));
     }
 
-    public void testIndexLoadWithinCoolingPeriodIsShardLoadsAvg() {
+    public void testIndexLoadWithinCoolingPeriodIsSumOfShardsLoads() {
         final TimeValue coolingPeriod = TimeValue.timeValueDays(3);
 
         final Metadata.Builder metadataBuilder = Metadata.builder();
@@ -658,6 +656,8 @@ public void testIndexLoadWithinCoolingPeriodIsShardLoadsAvg() {
         final String dataStreamName = "logs";
         long now = System.currentTimeMillis();
 
+        double expectedIsSumOfShardLoads = 0.5 + 3.0 + 0.3333;
+
         for (int i = 0; i < numberOfBackingIndicesWithinCoolingPeriod; i++) {
             final long createdAt = now - (coolingPeriod.getMillis() / 2);
             IndexMetadata indexMetadata;
@@ -705,8 +705,7 @@ public void testIndexLoadWithinCoolingPeriodIsShardLoadsAvg() {
             coolingPeriod,
             () -> now
         );
-        assertThat(maxIndexLoadWithinCoolingPeriod, is(greaterThan(0.499)));
-        assertThat(maxIndexLoadWithinCoolingPeriod, is(lessThan(0.5)));
+        assertThat(maxIndexLoadWithinCoolingPeriod, is(expectedIsSumOfShardLoads));
     }
 
     public void testAutoShardingResultValidation() {