diff --git a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/DataStreamFeatures.java b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/DataStreamFeatures.java
index 734c10570ab2b..06dc8919360f8 100644
--- a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/DataStreamFeatures.java
+++ b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/DataStreamFeatures.java
@@ -9,6 +9,7 @@
 package org.elasticsearch.datastreams;
 
 import org.elasticsearch.action.admin.indices.rollover.LazyRolloverAction;
+import org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService;
 import org.elasticsearch.datastreams.lifecycle.health.DataStreamLifecycleHealthInfoPublisher;
 import org.elasticsearch.features.FeatureSpecification;
 import org.elasticsearch.features.NodeFeature;
@@ -24,7 +25,8 @@ public class DataStreamFeatures implements FeatureSpecification {
     public Set<NodeFeature> getFeatures() {
         return Set.of(
             DataStreamLifecycleHealthInfoPublisher.DSL_HEALTH_INFO_FEATURE, // Added in 8.12
-            LazyRolloverAction.DATA_STREAM_LAZY_ROLLOVER                    // Added in 8.13
+            LazyRolloverAction.DATA_STREAM_LAZY_ROLLOVER,                    // Added in 8.13
+            DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE
         );
     }
 }
diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/AutoShardingCondition.java b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/AutoShardingCondition.java
index 174968747a338..a6e762348b3ce 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/AutoShardingCondition.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/AutoShardingCondition.java
@@ -14,7 +14,6 @@
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.core.TimeValue;
-import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentParser;
 
@@ -22,7 +21,8 @@
 
 import static org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService.AutoShardingResult.CURRENT_NUMBER_OF_SHARDS;
 import static org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService.AutoShardingResult.TARGET_NUMBER_OF_SHARDS;
-import static org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService.AutoShardingType.SCALE_UP;
+import static org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService.AutoShardingResult.WRITE_LOAD;
+import static org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService.AutoShardingType.INCREASE_NUMBER_OF_SHARDS;
 
 /**
  * Condition for automatically increasing the number of shards for a data stream. The value is computed when the condition is
@@ -30,37 +30,22 @@
  */
 public class AutoShardingCondition extends Condition<AutoShardingResult> {
     public static final String NAME = "auto_sharding";
-
-    private static final ParseField WRITE_LOAD = new ParseField("write_load");
-
-    private Double writeIndexLoad;
+    private boolean isConditionMet;
 
     public AutoShardingCondition(AutoShardingResult autoShardingResult) {
-        super(NAME, Type.INTERNAL);
+        super(NAME, Type.AUTOMATIC);
         this.value = autoShardingResult;
-        this.writeIndexLoad = null;
-    }
-
-    public AutoShardingCondition(AutoShardingResult autoShardingResult, Double writeIndexLoad) {
-        super(NAME, Type.INTERNAL);
-        this.value = autoShardingResult;
-        this.writeIndexLoad = writeIndexLoad;
+        this.isConditionMet = value.type() == INCREASE_NUMBER_OF_SHARDS && value.coolDownRemaining().equals(TimeValue.ZERO);
     }
 
     public AutoShardingCondition(StreamInput in) throws IOException {
-        super(NAME, Type.INTERNAL);
+        super(NAME, Type.AUTOMATIC);
         this.value = new AutoShardingResult(in);
-        this.writeIndexLoad = in.readOptionalDouble();
     }
 
     @Override
     public Result evaluate(final Stats stats) {
-        writeIndexLoad = stats.writeIndexLoad();
-        if (value.type() == SCALE_UP && value.coolDownRemaining().equals(TimeValue.ZERO)) {
-            return new Result(this, true);
-        } else {
-            return new Result(this, false);
-        }
+        return new Result(this, isConditionMet);
     }
 
     @Override
@@ -71,17 +56,18 @@ public String getWriteableName() {
     @Override
     public void writeTo(StreamOutput out) throws IOException {
         value.writeTo(out);
-        out.writeOptionalDouble(writeIndexLoad);
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         // we only save this representation in the cluster state as part of meet_conditions when this condition is met
-        if (value != null && value.type().equals(SCALE_UP)) {
+        if (isConditionMet) {
             builder.startObject(NAME);
             builder.field(CURRENT_NUMBER_OF_SHARDS.getPreferredName(), value.currentNumberOfShards());
             builder.field(TARGET_NUMBER_OF_SHARDS.getPreferredName(), value.targetNumberOfShards());
-            builder.field(WRITE_LOAD.getPreferredName(), writeIndexLoad);
+            assert value.writeLoad() != null
+                : "when the condition matches, a change in number of shards is executed and a write load must be present";
+            builder.field(WRITE_LOAD.getPreferredName(), value.writeLoad());
             builder.endObject();
         }
         return builder;
@@ -90,8 +76,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
     public static AutoShardingCondition fromXContent(XContentParser parser) throws IOException {
         if (parser.nextToken() == XContentParser.Token.START_OBJECT) {
             return new AutoShardingCondition(
-                new AutoShardingResult(SCALE_UP, parser.intValue(), parser.intValue(), TimeValue.ZERO),
-                parser.doubleValue()
+                new AutoShardingResult(INCREASE_NUMBER_OF_SHARDS, parser.intValue(), parser.intValue(), TimeValue.ZERO, parser.doubleValue())
             );
         } else {
             throw new IllegalArgumentException("invalid token when parsing " + NAME + " condition: " + parser.currentToken());
diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/Condition.java b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/Condition.java
index 44b797d7f634d..d4c59b69054a2 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/Condition.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/Condition.java
@@ -11,6 +11,7 @@
 import org.elasticsearch.TransportVersion;
 import org.elasticsearch.common.io.stream.NamedWriteable;
 import org.elasticsearch.common.unit.ByteSizeValue;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.xcontent.ToXContentFragment;
 
 import java.util.Objects;
@@ -20,13 +21,14 @@
  */
 public abstract class Condition<T> implements NamedWriteable, ToXContentFragment {
 
-    /**
-     * Describes the type of condition - a min_* condition (MIN), max_* condition (MAX), or an internal (usually) condition
+    /*
+     * Describes the type of condition - a min_* condition (MIN), max_* condition (MAX), or an automatic condition (automatic conditions
+     * are something that the platform configures and manages)
      */
     public enum Type {
         MIN,
         MAX,
-        INTERNAL
+        AUTOMATIC
     }
 
     protected T value;
@@ -91,7 +93,7 @@ public record Stats(
         ByteSizeValue indexSize,
         ByteSizeValue maxPrimaryShardSize,
         long maxPrimaryShardDocs,
-        double writeIndexLoad
+        @Nullable Double writeIndexLoad
     ) {}
 
     /**
diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/LazyRolloverAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/LazyRolloverAction.java
index 9266a320f598c..e44833fbc234b 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/LazyRolloverAction.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/LazyRolloverAction.java
@@ -9,6 +9,7 @@
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.cluster.ClusterState;
@@ -59,6 +60,7 @@ public TransportLazyRolloverAction(
             MetadataRolloverService rolloverService,
             AllocationService allocationService,
             MetadataDataStreamsService metadataDataStreamsService,
+            DataStreamAutoShardingService dataStreamAutoShardingService,
             Client client
         ) {
             super(
@@ -71,7 +73,8 @@ public TransportLazyRolloverAction(
                 rolloverService,
                 client,
                 allocationService,
-                metadataDataStreamsService
+                metadataDataStreamsService,
+                dataStreamAutoShardingService
             );
         }
 
diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/RolloverConditions.java b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/RolloverConditions.java
index 9f947fde14ba9..7bed68ef99a95 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/RolloverConditions.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/RolloverConditions.java
@@ -244,12 +244,12 @@ public boolean areConditionsMet(Map<String, Boolean> conditionResults) {
             .filter(c -> Condition.Type.MAX == c.type())
             .anyMatch(c -> conditionResults.getOrDefault(c.toString(), false));
 
-        boolean anyImplicitConditionsMet = conditions.values()
+        boolean anyInternalConditionsMet = conditions.values()
             .stream()
-            .filter(c -> Condition.Type.INTERNAL == c.type())
+            .filter(c -> Condition.Type.AUTOMATIC == c.type())
             .anyMatch(c -> conditionResults.getOrDefault(c.toString(), false));
 
-        return conditionResults.size() == 0 || (allMinConditionsMet && anyMaxConditionsMet) || anyImplicitConditionsMet;
+        return conditionResults.size() == 0 || (allMinConditionsMet && anyMaxConditionsMet) || anyInternalConditionsMet;
     }
 
     public static RolloverConditions fromXContent(XContentParser parser) throws IOException {
diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java
index 2171aa15e2b70..48048c4e46f0d 100644
--- a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java
+++ b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/TransportRolloverAction.java
@@ -12,6 +12,7 @@
 import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionType;
+import org.elasticsearch.action.admin.indices.create.CreateIndexRequest;
 import org.elasticsearch.action.admin.indices.stats.IndexStats;
 import org.elasticsearch.action.admin.indices.stats.IndicesStatsAction;
 import org.elasticsearch.action.admin.indices.stats.IndicesStatsRequest;
@@ -29,6 +30,7 @@
 import org.elasticsearch.cluster.ClusterStateTaskListener;
 import org.elasticsearch.cluster.block.ClusterBlockException;
 import org.elasticsearch.cluster.block.ClusterBlockLevel;
+import org.elasticsearch.cluster.metadata.DataStream;
 import org.elasticsearch.cluster.metadata.IndexAbstraction;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.IndexMetadataStats;
@@ -43,6 +45,7 @@
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.collect.Iterators;
 import org.elasticsearch.common.inject.Inject;
+import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.core.Nullable;
@@ -64,6 +67,9 @@
 import java.util.Optional;
 import java.util.stream.Collectors;
 
+import static org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService.AutoShardingType.NOT_APPLICABLE;
+import static org.elasticsearch.action.datastreams.autosharding.DataStreamAutoShardingService.AutoShardingType.INCREASE_NUMBER_OF_SHARDS;
+
 /**
  * Main class to swap the index pointed to by an alias, given some conditions
  */
@@ -74,6 +80,7 @@ public class TransportRolloverAction extends TransportMasterNodeAction<RolloverR
     private final Client client;
     private final MasterServiceTaskQueue<RolloverTask> rolloverTaskQueue;
     private final MetadataDataStreamsService metadataDataStreamsService;
+    private final DataStreamAutoShardingService dataStreamAutoShardingService;
 
     @Inject
     public TransportRolloverAction(
@@ -85,7 +92,8 @@ public TransportRolloverAction(
         MetadataRolloverService rolloverService,
         Client client,
         AllocationService allocationService,
-        MetadataDataStreamsService metadataDataStreamsService
+        MetadataDataStreamsService metadataDataStreamsService,
+        DataStreamAutoShardingService dataStreamAutoShardingService
     ) {
         this(
             RolloverAction.INSTANCE,
@@ -97,7 +105,8 @@ public TransportRolloverAction(
             rolloverService,
             client,
             allocationService,
-            metadataDataStreamsService
+            metadataDataStreamsService,
+            dataStreamAutoShardingService
         );
     }
 
@@ -111,7 +120,8 @@ public TransportRolloverAction(
         MetadataRolloverService rolloverService,
         Client client,
         AllocationService allocationService,
-        MetadataDataStreamsService metadataDataStreamsService
+        MetadataDataStreamsService metadataDataStreamsService,
+        DataStreamAutoShardingService dataStreamAutoShardingService
     ) {
         super(
             actionType.name(),
@@ -131,6 +141,7 @@ public TransportRolloverAction(
             new RolloverExecutor(clusterService, allocationService, rolloverService, threadPool)
         );
         this.metadataDataStreamsService = metadataDataStreamsService;
+        this.dataStreamAutoShardingService = dataStreamAutoShardingService;
     }
 
     @Override
@@ -171,15 +182,6 @@ protected void masterOperation(
         MetadataRolloverService.validateIndexName(clusterState, trialRolloverIndexName);
 
         boolean isDataStream = metadata.dataStreams().containsKey(rolloverRequest.getRolloverTarget());
-        final IndexAbstraction indexAbstraction = clusterState.metadata().getIndicesLookup().get(rolloverRequest.getRolloverTarget());
-        if (indexAbstraction.getType().equals(IndexAbstraction.Type.DATA_STREAM)) {
-            RolloverConditions conditionsIncludingImplicit = RolloverConditions.newBuilder(rolloverRequest.getConditions())
-                .addAutoShardingCondition(
-                    new AutoShardingResult(DataStreamAutoShardingService.AutoShardingType.SCALE_UP, 1, 3, TimeValue.ZERO)
-                )
-                .build();
-            rolloverRequest.setConditions(conditionsIncludingImplicit);
-        }
         if (rolloverRequest.isLazy()) {
             if (isDataStream == false || rolloverRequest.getConditions().hasConditions()) {
                 String message;
@@ -234,11 +236,44 @@ protected void masterOperation(
 
             listener.delegateFailureAndWrap((delegate, statsResponse) -> {
 
+                final IndexAbstraction indexAbstraction = clusterState.metadata()
+                    .getIndicesLookup()
+                    .get(rolloverRequest.getRolloverTarget());
+                Condition.Stats stats = buildStats(metadata.index(trialSourceIndexName), statsResponse);
+                if (indexAbstraction.getType().equals(IndexAbstraction.Type.DATA_STREAM)) {
+                    DataStream dataStream = (DataStream) indexAbstraction;
+                    AutoShardingResult autoShardingResult = dataStreamAutoShardingService.calculate(
+                        clusterState,
+                        dataStream,
+                        stats.writeIndexLoad()
+                    );
+                    if (autoShardingResult.type().equals(NOT_APPLICABLE) == false) {
+                        logger.debug("data stream auto sharding result is [{}]", autoShardingResult);
+                        if (autoShardingResult.type().equals(INCREASE_NUMBER_OF_SHARDS)) {
+                            if (autoShardingResult.coolDownRemaining().equals(TimeValue.ZERO)) {
+                                logger.info(
+                                    "Data stream auto sharding changing the number of shards for data stream from [{}] to [{}]",
+                                    autoShardingResult.currentNumberOfShards(),
+                                    autoShardingResult.targetNumberOfShards()
+                                );
+                                CreateIndexRequest createIndexRequest = rolloverRequest.getCreateIndexRequest();
+                                Settings settingsWithAutoSharding = Settings.builder()
+                                    .put(createIndexRequest.settings())
+                                    .put(IndexMetadata.INDEX_NUMBER_OF_SHARDS_SETTING.getKey(), autoShardingResult.targetNumberOfShards())
+                                    .build();
+                                createIndexRequest.settings(settingsWithAutoSharding);
+                            }
+                        }
+
+                        RolloverConditions conditionsIncludingImplicit = RolloverConditions.newBuilder(rolloverRequest.getConditions())
+                            .addAutoShardingCondition(autoShardingResult)
+                            .build();
+                        rolloverRequest.setConditions(conditionsIncludingImplicit);
+                    }
+                }
+
                 // Evaluate the conditions, so that we can tell without a cluster state update whether a rollover would occur.
-                final Map<String, Boolean> trialConditionResults = evaluateConditions(
-                    rolloverRequest.getConditionValues(),
-                    buildStats(metadata.index(trialSourceIndexName), statsResponse)
-                );
+                final Map<String, Boolean> trialConditionResults = evaluateConditions(rolloverRequest.getConditionValues(), stats);
 
                 final RolloverResponse trialRolloverResponse = new RolloverResponse(
                     trialSourceIndexName,
@@ -316,7 +351,7 @@ static Condition.Stats buildStats(@Nullable final IndexMetadata metadata, @Nulla
                 .max()
                 .orElse(0);
 
-            double writeLoad = 0.0;
+            Double writeLoad = null;
             if (statsResponse != null) {
                 IndexingStats indexing = statsResponse.getTotal().getIndexing();
                 if (indexing != null) {
@@ -394,16 +429,6 @@ public ClusterState executeTask(
             final var rolloverTask = rolloverTaskContext.getTask();
             final var rolloverRequest = rolloverTask.rolloverRequest();
 
-            final IndexAbstraction indexAbstraction = currentState.metadata().getIndicesLookup().get(rolloverRequest.getRolloverTarget());
-            if (indexAbstraction.getType().equals(IndexAbstraction.Type.DATA_STREAM)) {
-                RolloverConditions conditionsIncludingImplicit = RolloverConditions.newBuilder(rolloverRequest.getConditions())
-                    .addAutoShardingCondition(
-                        new AutoShardingResult(DataStreamAutoShardingService.AutoShardingType.SCALE_UP, 1, 3, TimeValue.ZERO)
-                    )
-                    .build();
-                rolloverRequest.setConditions(conditionsIncludingImplicit);
-            }
-
             // Regenerate the rollover names, as a rollover could have happened in between the pre-check and the cluster state update
             final var rolloverNames = MetadataRolloverService.resolveRolloverNames(
                 currentState,
@@ -433,6 +458,14 @@ public ClusterState executeTask(
                     ? IndexMetadataStats.fromStatsResponse(rolloverSourceIndex, rolloverTask.statsResponse())
                     : null;
 
+                final IndexAbstraction indexAbstraction = currentState.metadata()
+                    .getIndicesLookup()
+                    .get(rolloverRequest.getRolloverTarget());
+
+                if (indexAbstraction.getType().equals(IndexAbstraction.Type.DATA_STREAM)) {
+                    // TODO: we scale down the number of shards only when rolling over due to other conditions
+                }
+
                 // Perform the actual rollover
                 final var rolloverResult = rolloverService.rolloverClusterState(
                     currentState,
diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java b/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
index eb05a99838dc5..5f966cabacc8d 100644
--- a/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
+++ b/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
@@ -8,6 +8,9 @@
 
 package org.elasticsearch.action.datastreams.autosharding;
 
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.metadata.DataStream;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.metadata.IndexMetadataStats;
@@ -19,7 +22,10 @@
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.features.FeatureService;
+import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.index.Index;
 import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.ToXContentObject;
@@ -36,8 +42,11 @@
  */
 public class DataStreamAutoShardingService {
 
+    private static final Logger logger = LogManager.getLogger(DataStreamAutoShardingService.class);
     public static final String DATA_STREAMS_AUTO_SHARDING_ENABLED = "data_streams.auto_sharding.enabled";
 
+    public static final NodeFeature DATA_STREAM_AUTO_SHARDING_FEATURE = new NodeFeature("data_stream.auto_sharding");
+
     // TODO implement parser and take this setting into account
     public static final Setting<String> DATA_STREAMS_AUTO_SHARDING_EXCLUDES = Setting.simpleString(
         "data_streams.auto_sharding.excludes",
@@ -78,16 +87,18 @@ public class DataStreamAutoShardingService {
         Setting.Property.NodeScope
     );
     private final ClusterService clusterService;
-    private boolean isAutoShardingEnabled;
+    private final boolean isAutoShardingEnabled;
+    private final FeatureService featureService;
+    private final LongSupplier nowSupplier;
     private volatile TimeValue scaleUpCooldown;
     private volatile TimeValue scaleDownCooldown;
     private volatile int minNumberWriteThreads;
     private volatile int maxNumberWriteThreads;
 
     public enum AutoShardingType {
-        SCALE_UP,
-        SCALE_DOWN,
-        NO_SCALING,
+        INCREASE_NUMBER_OF_SHARDS,
+        DECREASES_NUMBER_OF_SHARDS,
+        NO_CHANGE_REQUIRED,
         NOT_APPLICABLE
     }
 
@@ -96,29 +107,38 @@ public enum AutoShardingType {
      * period that needs to lapse before the current recommendation should be applied.
      * <p>
      * If auto sharding is not applicable for a data stream (e.g. due to {@link #DATA_STREAMS_AUTO_SHARDING_EXCLUDES}) the target number
-     * of shards will be 0 and cool down remaining 0.
+     * of shards will be 0 and cool down remaining {@link TimeValue#MAX_VALUE}.
      */
     public record AutoShardingResult(
         AutoShardingType type,
         int currentNumberOfShards,
         int targetNumberOfShards,
-        TimeValue coolDownRemaining
+        TimeValue coolDownRemaining,
+        @Nullable Double writeLoad
     ) implements Writeable, ToXContentObject {
 
         public static final ParseField AUTO_SHARDING_TYPE = new ParseField("type");
         public static final ParseField CURRENT_NUMBER_OF_SHARDS = new ParseField("current_number_of_shards");
         public static final ParseField TARGET_NUMBER_OF_SHARDS = new ParseField("target_number_of_shards");
         public static final ParseField COOLDOWN_REMAINING = new ParseField("cool_down_remaining");
-
-        public AutoShardingResult(AutoShardingType type, int currentNumberOfShards, int targetNumberOfShards, TimeValue coolDownRemaining) {
+        public static final ParseField WRITE_LOAD = new ParseField("write_load");
+
+        public AutoShardingResult(
+            AutoShardingType type,
+            int currentNumberOfShards,
+            int targetNumberOfShards,
+            TimeValue coolDownRemaining,
+            @Nullable Double writeLoad
+        ) {
             this.type = type;
             this.currentNumberOfShards = currentNumberOfShards;
             this.targetNumberOfShards = targetNumberOfShards;
             this.coolDownRemaining = coolDownRemaining;
+            this.writeLoad = writeLoad;
         }
 
         public AutoShardingResult(StreamInput in) throws IOException {
-            this(in.readEnum(AutoShardingType.class), in.readVInt(), in.readVInt(), in.readTimeValue());
+            this(in.readEnum(AutoShardingType.class), in.readVInt(), in.readVInt(), in.readTimeValue(), in.readOptionalDouble());
         }
 
         @Override
@@ -128,6 +148,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
             builder.field(CURRENT_NUMBER_OF_SHARDS.getPreferredName(), currentNumberOfShards);
             builder.field(TARGET_NUMBER_OF_SHARDS.getPreferredName(), targetNumberOfShards);
             builder.field(COOLDOWN_REMAINING.getPreferredName(), coolDownRemaining.toHumanReadableString(2));
+            builder.field(WRITE_LOAD.getPreferredName(), writeLoad);
             builder.endObject();
             return builder;
         }
@@ -142,6 +163,8 @@ public String toString() {
                 + targetNumberOfShards
                 + ", coolDownRemaining: "
                 + coolDownRemaining
+                + ", writeLoad: "
+                + writeLoad
                 + " }";
         }
 
@@ -151,16 +174,24 @@ public void writeTo(StreamOutput out) throws IOException {
             out.writeVInt(currentNumberOfShards);
             out.writeVInt(targetNumberOfShards);
             out.writeTimeValue(coolDownRemaining);
+            out.writeOptionalDouble(writeLoad);
         }
     }
 
-    public DataStreamAutoShardingService(Settings settings, ClusterService clusterService) {
+    public DataStreamAutoShardingService(
+        Settings settings,
+        ClusterService clusterService,
+        FeatureService featureService,
+        LongSupplier nowSupplier
+    ) {
         this.clusterService = clusterService;
         this.isAutoShardingEnabled = settings.getAsBoolean(DATA_STREAMS_AUTO_SHARDING_ENABLED, false);
         this.scaleUpCooldown = DATA_STREAMS_AUTO_SHARDING_SCALE_UP_COOLDOWN.get(settings);
         this.scaleDownCooldown = DATA_STREAMS_AUTO_SHARDING_SCALE_DOWN_COOLDOWN.get(settings);
         this.minNumberWriteThreads = CLUSTER_AUTO_SHARDING_MIN_NUMBER_WRITE_THREADS.get(settings);
         this.maxNumberWriteThreads = CLUSTER_AUTO_SHARDING_MAX_NUMBER_WRITE_THREADS.get(settings);
+        this.featureService = featureService;
+        this.nowSupplier = nowSupplier;
     }
 
     public void init() {
@@ -175,20 +206,45 @@ public void init() {
     }
 
     /**
-     * Computes the optimal number of shards for the provided data stream according to the indexing load.
+     * Computes the optimal number of shards for the provided data stream according to the write index's indexing load (to check if we must
+     * increase the number of shards, whilst the heuristics for decreasing the number of shards _might_ use the provide write indexing
+     * load).
      * The result type will indicate the recommendation of the auto sharding service :
      * - not applicable if the data stream is excluded from auto sharding as configured by {@link #DATA_STREAMS_AUTO_SHARDING_EXCLUDES} or
-     * if the auto sharding functionality is disabled according to {@link #DATA_STREAMS_AUTO_SHARDING_EXCLUDES}
-     * - scale up if the number of shards it deems necessary for the provided data stream is GT the current number of shards
-     * - scale down if the number of shards it deems necessary for the provided data stream is LT the current number of shards
+     * if the auto sharding functionality is disabled according to {@link #DATA_STREAMS_AUTO_SHARDING_EXCLUDES}, or if the cluster
+     * doesn't have the feature available
+     * - increase number of shards if the optimal number of shards it deems necessary for the provided data stream is GT the current number
+     * of shards
+     * - decrease the number of shards if the optimal number of shards it deems necessary for the provided data stream is LT the current
+     * number of shards
      */
-    public AutoShardingResult calculate(Metadata metadata, DataStream dataStream, double writeIndexLoad) {
+    public AutoShardingResult calculate(ClusterState state, DataStream dataStream, @Nullable Double writeIndexLoad) {
+        Metadata metadata = state.metadata();
         if (isAutoShardingEnabled == false) {
-            IndexMetadata writeIndex = metadata.index(dataStream.getWriteIndex());
-            assert writeIndex != null : "the data stream write index must exist in the provided cluster metadata";
-            return new AutoShardingResult(AutoShardingType.NOT_APPLICABLE, writeIndex.getNumberOfShards(), 0, TimeValue.ZERO);
+            logger.debug("Data stream auto sharding service is not enabled.");
+            return new AutoShardingResult(AutoShardingType.NOT_APPLICABLE, 0, 0, TimeValue.MAX_VALUE, null);
         }
-        return null;
+
+        if (featureService.clusterHasFeature(state, DataStreamAutoShardingService.DATA_STREAM_AUTO_SHARDING_FEATURE) == false) {
+            logger.debug(
+                "Data stream auto sharding service cannot compute the optimal number of shards for data stream [{}] because the cluster "
+                    + "doesn't have the auto sharding feature",
+                dataStream.getName()
+            );
+            return new AutoShardingResult(AutoShardingType.NOT_APPLICABLE, 0, 0, TimeValue.MAX_VALUE, null);
+        }
+
+        // TODO validate the data stream against DATA_STREAMS_AUTO_SHARDING_EXCLUDES
+
+        if (writeIndexLoad == null) {
+            logger.debug(
+                "Data stream auto sharding service cannot compute the optimal number of shards for data stream [{}] as the write index "
+                    + "load is not available",
+                dataStream.getName()
+            );
+            return new AutoShardingResult(AutoShardingType.NOT_APPLICABLE, 0, 0, TimeValue.MAX_VALUE, null);
+        }
+        return innerCalculate(metadata, dataStream, writeIndexLoad, nowSupplier);
     }
 
     // visible for testing
@@ -208,10 +264,11 @@ AutoShardingResult innerCalculate(Metadata metadata, DataStream dataStream, doub
             );
 
             return new AutoShardingResult(
-                AutoShardingType.SCALE_UP,
+                AutoShardingType.INCREASE_NUMBER_OF_SHARDS,
                 writeIndex.getNumberOfShards(),
                 Math.toIntExact(optimalNumberOfShards),
-                remainingCoolDown
+                remainingCoolDown,
+                writeIndexLoad
             );
         }
 
@@ -253,10 +310,11 @@ AutoShardingResult innerCalculate(Metadata metadata, DataStream dataStream, doub
 
         if (scaleDownNumberOfShards < writeIndex.getNumberOfShards()) {
             return new AutoShardingResult(
-                AutoShardingType.SCALE_DOWN,
+                AutoShardingType.DECREASES_NUMBER_OF_SHARDS,
                 writeIndex.getNumberOfShards(),
                 Math.toIntExact(scaleDownNumberOfShards),
-                TimeValue.timeValueMillis(Math.max(0L, scaleDownCooldown.millis() - timeSinceLastAutoShardingEvent.millis()))
+                TimeValue.timeValueMillis(Math.max(0L, scaleDownCooldown.millis() - timeSinceLastAutoShardingEvent.millis())),
+                maxIndexLoadWithinCoolingPeriod
             );
         }
 
@@ -282,23 +340,19 @@ List<Index> getIndicesCreatedWithin(Metadata metadata, DataStream dataStream, Ti
         return dataStreamIndices.subList(firstIndexWithinAgeRange, dataStreamIndices.size());
     }
 
-    public void updateAutoShardingEnabled(boolean autoShardingEnabled) {
-        this.isAutoShardingEnabled = autoShardingEnabled;
-    }
-
-    public void updateScaleUpCooldown(TimeValue scaleUpCooldown) {
+    void updateScaleUpCooldown(TimeValue scaleUpCooldown) {
         this.scaleUpCooldown = scaleUpCooldown;
     }
 
-    public void updateScaleDownCooldown(TimeValue scaleDownCooldown) {
+    void updateScaleDownCooldown(TimeValue scaleDownCooldown) {
         this.scaleDownCooldown = scaleDownCooldown;
     }
 
-    public void updateMinNumberWriteThreads(int minNumberWriteThreads) {
+    void updateMinNumberWriteThreads(int minNumberWriteThreads) {
         this.minNumberWriteThreads = minNumberWriteThreads;
     }
 
-    public void updateMaxNumberWriteThreads(int maxNumberWriteThreads) {
+    void updateMaxNumberWriteThreads(int maxNumberWriteThreads) {
         this.maxNumberWriteThreads = maxNumberWriteThreads;
     }
 }