diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java index 81762def4cc35..febed3d97efbb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlMetadata.java @@ -57,8 +57,9 @@ public class MlMetadata implements XPackPlugin.XPackMetaDataCustom { public static final String TYPE = "ml"; private static final ParseField JOBS_FIELD = new ParseField("jobs"); private static final ParseField DATAFEEDS_FIELD = new ParseField("datafeeds"); + private static final ParseField LAST_MEMORY_REFRESH_VERSION_FIELD = new ParseField("last_memory_refresh_version"); - public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap()); + public static final MlMetadata EMPTY_METADATA = new MlMetadata(Collections.emptySortedMap(), Collections.emptySortedMap(), null); // This parser follows the pattern that metadata is parsed leniently (to allow for enhancements) public static final ObjectParser LENIENT_PARSER = new ObjectParser<>("ml_metadata", true, Builder::new); @@ -66,15 +67,18 @@ public class MlMetadata implements XPackPlugin.XPackMetaDataCustom { LENIENT_PARSER.declareObjectArray(Builder::putJobs, (p, c) -> Job.LENIENT_PARSER.apply(p, c).build(), JOBS_FIELD); LENIENT_PARSER.declareObjectArray(Builder::putDatafeeds, (p, c) -> DatafeedConfig.LENIENT_PARSER.apply(p, c).build(), DATAFEEDS_FIELD); + LENIENT_PARSER.declareLong(Builder::setLastMemoryRefreshVersion, LAST_MEMORY_REFRESH_VERSION_FIELD); } private final SortedMap jobs; private final SortedMap datafeeds; + private final Long lastMemoryRefreshVersion; private final GroupOrJobLookup groupOrJobLookup; - private MlMetadata(SortedMap jobs, SortedMap datafeeds) { + private MlMetadata(SortedMap jobs, SortedMap datafeeds, Long lastMemoryRefreshVersion) { this.jobs = Collections.unmodifiableSortedMap(jobs); this.datafeeds = Collections.unmodifiableSortedMap(datafeeds); + this.lastMemoryRefreshVersion = lastMemoryRefreshVersion; this.groupOrJobLookup = new GroupOrJobLookup(jobs.values()); } @@ -112,6 +116,10 @@ public Set expandDatafeedIds(String expression, boolean allowNoDatafeeds .expand(expression, allowNoDatafeeds); } + public Long getLastMemoryRefreshVersion() { + return lastMemoryRefreshVersion; + } + @Override public Version getMinimalSupportedVersion() { return Version.V_5_4_0; @@ -145,7 +153,11 @@ public MlMetadata(StreamInput in) throws IOException { datafeeds.put(in.readString(), new DatafeedConfig(in)); } this.datafeeds = datafeeds; - + if (in.getVersion().onOrAfter(Version.V_6_6_0)) { + lastMemoryRefreshVersion = in.readOptionalLong(); + } else { + lastMemoryRefreshVersion = null; + } this.groupOrJobLookup = new GroupOrJobLookup(jobs.values()); } @@ -153,6 +165,9 @@ public MlMetadata(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { writeMap(jobs, out); writeMap(datafeeds, out); + if (out.getVersion().onOrAfter(Version.V_6_6_0)) { + out.writeOptionalLong(lastMemoryRefreshVersion); + } } private static void writeMap(Map map, StreamOutput out) throws IOException { @@ -169,6 +184,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws new DelegatingMapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"), params); mapValuesToXContent(JOBS_FIELD, jobs, builder, extendedParams); mapValuesToXContent(DATAFEEDS_FIELD, datafeeds, builder, extendedParams); + if (lastMemoryRefreshVersion != null) { + builder.field(LAST_MEMORY_REFRESH_VERSION_FIELD.getPreferredName(), lastMemoryRefreshVersion); + } return builder; } @@ -185,30 +203,46 @@ public static class MlMetadataDiff implements NamedDiff { final Diff> jobs; final Diff> datafeeds; + final Long lastMemoryRefreshVersion; MlMetadataDiff(MlMetadata before, MlMetadata after) { this.jobs = DiffableUtils.diff(before.jobs, after.jobs, DiffableUtils.getStringKeySerializer()); this.datafeeds = DiffableUtils.diff(before.datafeeds, after.datafeeds, DiffableUtils.getStringKeySerializer()); + this.lastMemoryRefreshVersion = after.lastMemoryRefreshVersion; } public MlMetadataDiff(StreamInput in) throws IOException { this.jobs = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), Job::new, MlMetadataDiff::readJobDiffFrom); this.datafeeds = DiffableUtils.readJdkMapDiff(in, DiffableUtils.getStringKeySerializer(), DatafeedConfig::new, - MlMetadataDiff::readSchedulerDiffFrom); + MlMetadataDiff::readDatafeedDiffFrom); + if (in.getVersion().onOrAfter(Version.V_6_6_0)) { + lastMemoryRefreshVersion = in.readOptionalLong(); + } else { + lastMemoryRefreshVersion = null; + } } + /** + * Merge the diff with the ML metadata. + * @param part The current ML metadata. + * @return The new ML metadata. + */ @Override public MetaData.Custom apply(MetaData.Custom part) { TreeMap newJobs = new TreeMap<>(jobs.apply(((MlMetadata) part).jobs)); TreeMap newDatafeeds = new TreeMap<>(datafeeds.apply(((MlMetadata) part).datafeeds)); - return new MlMetadata(newJobs, newDatafeeds); + // lastMemoryRefreshVersion always comes from the diff - no need to merge with the old value + return new MlMetadata(newJobs, newDatafeeds, lastMemoryRefreshVersion); } @Override public void writeTo(StreamOutput out) throws IOException { jobs.writeTo(out); datafeeds.writeTo(out); + if (out.getVersion().onOrAfter(Version.V_6_6_0)) { + out.writeOptionalLong(lastMemoryRefreshVersion); + } } @Override @@ -220,7 +254,7 @@ static Diff readJobDiffFrom(StreamInput in) throws IOException { return AbstractDiffable.readDiffFrom(Job::new, in); } - static Diff readSchedulerDiffFrom(StreamInput in) throws IOException { + static Diff readDatafeedDiffFrom(StreamInput in) throws IOException { return AbstractDiffable.readDiffFrom(DatafeedConfig::new, in); } } @@ -233,7 +267,8 @@ public boolean equals(Object o) { return false; MlMetadata that = (MlMetadata) o; return Objects.equals(jobs, that.jobs) && - Objects.equals(datafeeds, that.datafeeds); + Objects.equals(datafeeds, that.datafeeds) && + Objects.equals(lastMemoryRefreshVersion, that.lastMemoryRefreshVersion); } @Override @@ -243,13 +278,14 @@ public final String toString() { @Override public int hashCode() { - return Objects.hash(jobs, datafeeds); + return Objects.hash(jobs, datafeeds, lastMemoryRefreshVersion); } public static class Builder { private TreeMap jobs; private TreeMap datafeeds; + private Long lastMemoryRefreshVersion; public Builder() { jobs = new TreeMap<>(); @@ -263,6 +299,7 @@ public Builder(@Nullable MlMetadata previous) { } else { jobs = new TreeMap<>(previous.jobs); datafeeds = new TreeMap<>(previous.datafeeds); + lastMemoryRefreshVersion = previous.lastMemoryRefreshVersion; } } @@ -382,8 +419,13 @@ private Builder putDatafeeds(Collection datafeeds) { return this; } + public Builder setLastMemoryRefreshVersion(Long lastMemoryRefreshVersion) { + this.lastMemoryRefreshVersion = lastMemoryRefreshVersion; + return this; + } + public MlMetadata build() { - return new MlMetadata(jobs, datafeeds); + return new MlMetadata(jobs, datafeeds, lastMemoryRefreshVersion); } public void markJobAsDeleting(String jobId, PersistentTasksCustomMetaData tasks, boolean allowDeleteOpenJob) { @@ -420,8 +462,6 @@ void checkJobHasNoDatafeed(String jobId) { } } - - public static MlMetadata getMlMetadata(ClusterState state) { MlMetadata mlMetadata = (state == null) ? null : state.getMetaData().custom(TYPE); if (mlMetadata == null) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/Job.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/Job.java index ffe24dff8ced0..fd7fa70bded43 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/Job.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/config/Job.java @@ -142,6 +142,7 @@ private static ObjectParser createParser(boolean ignoreUnknownFie private final Date createTime; private final Date finishedTime; private final Date lastDataTime; + // TODO: Remove in 7.0 private final Long establishedModelMemory; private final AnalysisConfig analysisConfig; private final AnalysisLimits analysisLimits; @@ -439,6 +440,7 @@ public Collection allInputFields() { * program code and stack. * @return an estimate of the memory requirement of this job, in bytes */ + // TODO: remove this method in 7.0 public long estimateMemoryFootprint() { if (establishedModelMemory != null && establishedModelMemory > 0) { return establishedModelMemory + PROCESS_MEMORY_OVERHEAD.getBytes(); @@ -658,6 +660,7 @@ public static class Builder implements Writeable, ToXContentObject { private Date createTime; private Date finishedTime; private Date lastDataTime; + // TODO: remove in 7.0 private Long establishedModelMemory; private ModelPlotConfig modelPlotConfig; private Long renormalizationWindowDays; @@ -1102,10 +1105,6 @@ private void validateGroups() { public Job build(Date createTime) { setCreateTime(createTime); setJobVersion(Version.CURRENT); - // TODO: Maybe we _could_ accept a value for this supplied at create time - it would - // mean cloned jobs that hadn't been edited much would start with an accurate expected size. - // But on the other hand it would mean jobs that were cloned and then completely changed - // would start with a size that was completely wrong. setEstablishedModelMemory(null); return build(); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobTests.java index 62340ba6cf63c..0fae85f6d6b5c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/job/config/JobTests.java @@ -561,7 +561,7 @@ public void testEstimateMemoryFootprint_GivenNoLimitAndNotEstablished() { builder.setEstablishedModelMemory(0L); } assertEquals(ByteSizeUnit.MB.toBytes(AnalysisLimits.PRE_6_1_DEFAULT_MODEL_MEMORY_LIMIT_MB) - + Job.PROCESS_MEMORY_OVERHEAD.getBytes(), builder.build().estimateMemoryFootprint()); + + Job.PROCESS_MEMORY_OVERHEAD.getBytes(), builder.build().estimateMemoryFootprint()); } public void testEarliestValidTimestamp_GivenEmptyDataCounts() { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 77f898dcb8773..9dbf5cc9f6d16 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -181,6 +181,7 @@ import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerFactory; import org.elasticsearch.xpack.ml.job.process.normalizer.NormalizerProcessFactory; import org.elasticsearch.xpack.ml.notifications.Auditor; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.process.NativeController; import org.elasticsearch.xpack.ml.process.NativeControllerHolder; import org.elasticsearch.xpack.ml.rest.RestDeleteExpiredDataAction; @@ -278,6 +279,7 @@ public class MachineLearning extends Plugin implements ActionPlugin, AnalysisPlu private final SetOnce autodetectProcessManager = new SetOnce<>(); private final SetOnce datafeedManager = new SetOnce<>(); + private final SetOnce memoryTracker = new SetOnce<>(); public MachineLearning(Settings settings, Path configPath) { this.settings = settings; @@ -420,6 +422,8 @@ public Collection createComponents(Client client, ClusterService cluster this.datafeedManager.set(datafeedManager); MlLifeCycleService mlLifeCycleService = new MlLifeCycleService(environment, clusterService, datafeedManager, autodetectProcessManager); + MlMemoryTracker memoryTracker = new MlMemoryTracker(clusterService, threadPool, jobManager, jobResultsProvider); + this.memoryTracker.set(memoryTracker); // This object's constructor attaches to the license state, so there's no need to retain another reference to it new InvalidLicenseEnforcer(getLicenseState(), threadPool, datafeedManager, autodetectProcessManager); @@ -438,7 +442,8 @@ public Collection createComponents(Client client, ClusterService cluster jobDataCountsPersister, datafeedManager, auditor, - new MlAssignmentNotifier(auditor, clusterService) + new MlAssignmentNotifier(auditor, clusterService), + memoryTracker ); } @@ -449,7 +454,8 @@ public List> getPersistentTasksExecutor(ClusterServic } return Arrays.asList( - new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get()), + new TransportOpenJobAction.OpenJobPersistentTasksExecutor(settings, clusterService, autodetectProcessManager.get(), + memoryTracker.get()), new TransportStartDatafeedAction.StartDatafeedPersistentTasksExecutor(settings, datafeedManager.get()) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java index 761c21b63f165..10e6b8093f763 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportDeleteJobAction.java @@ -69,6 +69,7 @@ import org.elasticsearch.xpack.ml.job.persistence.JobDataDeleter; import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; import org.elasticsearch.xpack.ml.notifications.Auditor; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.utils.MlIndicesUtils; import java.util.ArrayList; @@ -94,6 +95,7 @@ public class TransportDeleteJobAction extends TransportMasterNodeAction(); } @@ -211,6 +215,9 @@ private void normalDeleteJob(ParentTaskAssigningClient parentTaskClient, DeleteJ ActionListener listener) { String jobId = request.getJobId(); + // We clean up the memory tracker on delete rather than close as close is not a master node action + memoryTracker.removeJob(jobId); + // Step 4. When the job has been removed from the cluster state, return a response // ------- CheckedConsumer apiResponseHandler = jobDeleted -> { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java index 5b4be78596a64..ce820db0babec 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportOpenJobAction.java @@ -69,6 +69,7 @@ import org.elasticsearch.xpack.ml.job.persistence.JobConfigProvider; import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcessManager; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import java.io.IOException; import java.util.ArrayList; @@ -95,20 +96,23 @@ To ensure that a subsequent close job call will see that same task status (and s */ public class TransportOpenJobAction extends TransportMasterNodeAction { + private static final PersistentTasksCustomMetaData.Assignment AWAITING_LAZY_ASSIGNMENT = + new PersistentTasksCustomMetaData.Assignment(null, "persistent task is awaiting node assignment."); + private final XPackLicenseState licenseState; private final PersistentTasksService persistentTasksService; private final Client client; private final JobResultsProvider jobResultsProvider; private final JobConfigProvider jobConfigProvider; - private static final PersistentTasksCustomMetaData.Assignment AWAITING_LAZY_ASSIGNMENT = - new PersistentTasksCustomMetaData.Assignment(null, "persistent task is awaiting node assignment."); + private final MlMemoryTracker memoryTracker; @Inject public TransportOpenJobAction(Settings settings, TransportService transportService, ThreadPool threadPool, XPackLicenseState licenseState, ClusterService clusterService, PersistentTasksService persistentTasksService, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, Client client, - JobResultsProvider jobResultsProvider, JobConfigProvider jobConfigProvider) { + JobResultsProvider jobResultsProvider, JobConfigProvider jobConfigProvider, + MlMemoryTracker memoryTracker) { super(settings, OpenJobAction.NAME, transportService, clusterService, threadPool, actionFilters, indexNameExpressionResolver, OpenJobAction.Request::new); this.licenseState = licenseState; @@ -116,6 +120,7 @@ public TransportOpenJobAction(Settings settings, TransportService transportServi this.client = client; this.jobResultsProvider = jobResultsProvider; this.jobConfigProvider = jobConfigProvider; + this.memoryTracker = memoryTracker; } /** @@ -144,6 +149,7 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j int maxConcurrentJobAllocations, int fallbackMaxNumberOfOpenJobs, int maxMachineMemoryPercent, + MlMemoryTracker memoryTracker, Logger logger) { String resultsIndexName = job != null ? job.getResultsIndexName() : null; List unavailableIndices = verifyIndicesPrimaryShardsAreActive(resultsIndexName, clusterState); @@ -154,10 +160,38 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j return new PersistentTasksCustomMetaData.Assignment(null, reason); } + // Try to allocate jobs according to memory usage, but if that's not possible (maybe due to a mixed version cluster or maybe + // because of some weird OS problem) then fall back to the old mechanism of only considering numbers of assigned jobs + boolean allocateByMemory = true; + + if (memoryTracker.isRecentlyRefreshed() == false) { + + boolean scheduledRefresh = memoryTracker.asyncRefresh(ActionListener.wrap( + acknowledged -> { + if (acknowledged) { + logger.trace("Job memory requirement refresh request completed successfully"); + } else { + logger.warn("Job memory requirement refresh request completed but did not set time in cluster state"); + } + }, + e -> logger.error("Failed to refresh job memory requirements", e) + )); + if (scheduledRefresh) { + String reason = "Not opening job [" + jobId + "] because job memory requirements are stale - refresh requested"; + logger.debug(reason); + return new PersistentTasksCustomMetaData.Assignment(null, reason); + } else { + allocateByMemory = false; + logger.warn("Falling back to allocating job [{}] by job counts because a memory requirement refresh could not be scheduled", + jobId); + } + } + List reasons = new LinkedList<>(); long maxAvailableCount = Long.MIN_VALUE; + long maxAvailableMemory = Long.MIN_VALUE; DiscoveryNode minLoadedNodeByCount = null; - + DiscoveryNode minLoadedNodeByMemory = null; PersistentTasksCustomMetaData persistentTasks = clusterState.getMetaData().custom(PersistentTasksCustomMetaData.TYPE); for (DiscoveryNode node : clusterState.getNodes()) { Map nodeAttributes = node.getAttributes(); @@ -197,10 +231,9 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j } } - long numberOfAssignedJobs = 0; int numberOfAllocatingJobs = 0; - + long assignedJobMemory = 0; if (persistentTasks != null) { // find all the job tasks assigned to this node Collection> assignedTasks = persistentTasks.findTasks( @@ -231,6 +264,15 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j if (jobState.isAnyOf(JobState.CLOSED, JobState.FAILED) == false) { // Don't count CLOSED or FAILED jobs, as they don't consume native memory ++numberOfAssignedJobs; + OpenJobAction.JobParams params = (OpenJobAction.JobParams) assignedTask.getParams(); + Long jobMemoryRequirement = memoryTracker.getJobMemoryRequirement(params.getJobId()); + if (jobMemoryRequirement == null) { + allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because " + + "the memory requirement for job [{}] was not available", jobId, params.getJobId()); + } else { + assignedJobMemory += jobMemoryRequirement; + } } } } @@ -271,10 +313,62 @@ static PersistentTasksCustomMetaData.Assignment selectLeastLoadedMlNode(String j maxAvailableCount = availableCount; minLoadedNodeByCount = node; } + + String machineMemoryStr = nodeAttributes.get(MachineLearning.MACHINE_MEMORY_NODE_ATTR); + long machineMemory = -1; + // TODO: remove leniency and reject the node if the attribute is null in 7.0 + if (machineMemoryStr != null) { + try { + machineMemory = Long.parseLong(machineMemoryStr); + } catch (NumberFormatException e) { + String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + "], because " + + MachineLearning.MACHINE_MEMORY_NODE_ATTR + " attribute [" + machineMemoryStr + "] is not a long"; + logger.trace(reason); + reasons.add(reason); + continue; + } + } + + if (allocateByMemory) { + if (machineMemory > 0) { + long maxMlMemory = machineMemory * maxMachineMemoryPercent / 100; + Long estimatedMemoryFootprint = memoryTracker.getJobMemoryRequirement(jobId); + if (estimatedMemoryFootprint != null) { + long availableMemory = maxMlMemory - assignedJobMemory; + if (estimatedMemoryFootprint > availableMemory) { + String reason = "Not opening job [" + jobId + "] on node [" + nodeNameAndMlAttributes(node) + + "], because this node has insufficient available memory. Available memory for ML [" + maxMlMemory + + "], memory required by existing jobs [" + assignedJobMemory + + "], estimated memory required for this job [" + estimatedMemoryFootprint + "]"; + logger.trace(reason); + reasons.add(reason); + continue; + } + + if (maxAvailableMemory < availableMemory) { + maxAvailableMemory = availableMemory; + minLoadedNodeByMemory = node; + } + } else { + // If we cannot get the job memory requirement, + // fall back to simply allocating by job count + allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because its memory requirement was not available", + jobId); + } + } else { + // If we cannot get the available memory on any machine in + // the cluster, fall back to simply allocating by job count + allocateByMemory = false; + logger.debug("Falling back to allocating job [{}] by job counts because machine memory was not available for node [{}]", + jobId, nodeNameAndMlAttributes(node)); + } + } } - if (minLoadedNodeByCount != null) { - logger.debug("selected node [{}] for job [{}]", minLoadedNodeByCount, jobId); - return new PersistentTasksCustomMetaData.Assignment(minLoadedNodeByCount.getId(), ""); + DiscoveryNode minLoadedNode = allocateByMemory ? minLoadedNodeByMemory : minLoadedNodeByCount; + if (minLoadedNode != null) { + logger.debug("selected node [{}] for job [{}]", minLoadedNode, jobId); + return new PersistentTasksCustomMetaData.Assignment(minLoadedNode.getId(), ""); } else { String explanation = String.join("|", reasons); logger.debug("no node selected for job [{}], reasons [{}]", jobId, explanation); @@ -415,6 +509,11 @@ protected void masterOperation(OpenJobAction.Request request, ClusterState state OpenJobAction.JobParams jobParams = request.getJobParams(); if (licenseState.isMachineLearningAllowed()) { + // If the whole cluster supports the ML memory tracker then we don't need + // to worry about updating established model memory on the job objects + // TODO: remove in 7.0 as it will always be true + boolean clusterSupportsMlMemoryTracker = state.getNodes().getMinNodeVersion().onOrAfter(Version.V_6_6_0); + // Clear job finished time once the job is started and respond ActionListener clearJobFinishTime = ActionListener.wrap( response -> { @@ -446,15 +545,19 @@ public void onFailure(Exception e) { }; // Start job task - ActionListener jobUpateListener = ActionListener.wrap( - response -> { - persistentTasksService.sendStartRequest(MlTasks.jobTaskId(jobParams.getJobId()), - MlTasks.JOB_TASK_NAME, jobParams, waitForJobToStart); - }, - listener::onFailure + ActionListener memoryRequirementRefreshListener = ActionListener.wrap( + mem -> persistentTasksService.sendStartRequest(MlTasks.jobTaskId(jobParams.getJobId()), MlTasks.JOB_TASK_NAME, jobParams, + waitForJobToStart), + listener::onFailure + ); + + // Tell the job tracker to refresh the memory requirement for this job and all other jobs that have persistent tasks + ActionListener jobUpdateListener = ActionListener.wrap( + response -> memoryTracker.refreshJobMemoryAndAllOthers(jobParams.getJobId(), memoryRequirementRefreshListener), + listener::onFailure ); - // Update established model memory for pre-6.1 jobs that haven't had it set + // Update established model memory for pre-6.1 jobs that haven't had it set (TODO: remove in 7.0) // and increase the model memory limit for 6.1 - 6.3 jobs ActionListener missingMappingsListener = ActionListener.wrap( response -> { @@ -462,8 +565,9 @@ public void onFailure(Exception e) { if (job != null) { Version jobVersion = job.getJobVersion(); Long jobEstablishedModelMemory = job.getEstablishedModelMemory(); - if ((jobVersion == null || jobVersion.before(Version.V_6_1_0)) + if (clusterSupportsMlMemoryTracker == false && (jobVersion == null || jobVersion.before(Version.V_6_1_0)) && (jobEstablishedModelMemory == null || jobEstablishedModelMemory == 0)) { + // TODO: remove in 7.0 - established model memory no longer needs to be set on the job object // Set the established memory usage for pre 6.1 jobs jobResultsProvider.getEstablishedMemoryUsage(job.getId(), null, null, establishedModelMemory -> { if (establishedModelMemory != null && establishedModelMemory > 0) { @@ -472,9 +576,9 @@ public void onFailure(Exception e) { UpdateJobAction.Request updateRequest = UpdateJobAction.Request.internal(job.getId(), update); executeAsyncWithOrigin(client, ML_ORIGIN, UpdateJobAction.INSTANCE, updateRequest, - jobUpateListener); + jobUpdateListener); } else { - jobUpateListener.onResponse(null); + jobUpdateListener.onResponse(null); } }, listener::onFailure); } else if (jobVersion != null && @@ -491,16 +595,16 @@ public void onFailure(Exception e) { .setAnalysisLimits(limits).build(); UpdateJobAction.Request updateRequest = UpdateJobAction.Request.internal(job.getId(), update); executeAsyncWithOrigin(client, ML_ORIGIN, UpdateJobAction.INSTANCE, updateRequest, - jobUpateListener); + jobUpdateListener); } else { - jobUpateListener.onResponse(null); + jobUpdateListener.onResponse(null); } } else { - jobUpateListener.onResponse(null); + jobUpdateListener.onResponse(null); } } else { - jobUpateListener.onResponse(null); + jobUpdateListener.onResponse(null); } }, listener::onFailure ); @@ -644,6 +748,7 @@ private void addDocMappingIfMissing(String alias, CheckedSupplier { private final AutodetectProcessManager autodetectProcessManager; + private final MlMemoryTracker memoryTracker; /** * The maximum number of open jobs can be different on each node. However, nodes on older versions @@ -657,9 +762,10 @@ public static class OpenJobPersistentTasksExecutor extends PersistentTasksExecut private volatile int maxLazyMLNodes; public OpenJobPersistentTasksExecutor(Settings settings, ClusterService clusterService, - AutodetectProcessManager autodetectProcessManager) { + AutodetectProcessManager autodetectProcessManager, MlMemoryTracker memoryTracker) { super(MlTasks.JOB_TASK_NAME, MachineLearning.UTILITY_THREAD_POOL_NAME); this.autodetectProcessManager = autodetectProcessManager; + this.memoryTracker = memoryTracker; this.fallbackMaxNumberOfOpenJobs = AutodetectProcessManager.MAX_OPEN_JOBS_PER_NODE.get(settings); this.maxConcurrentJobAllocations = MachineLearning.CONCURRENT_JOB_ALLOCATIONS.get(settings); this.maxMachineMemoryPercent = MachineLearning.MAX_MACHINE_MEMORY_PERCENT.get(settings); @@ -679,10 +785,11 @@ public PersistentTasksCustomMetaData.Assignment getAssignment(OpenJobAction.JobP maxConcurrentJobAllocations, fallbackMaxNumberOfOpenJobs, maxMachineMemoryPercent, + memoryTracker, logger); if (assignment.getExecutorNode() == null) { int numMlNodes = 0; - for(DiscoveryNode node : clusterState.getNodes()) { + for (DiscoveryNode node : clusterState.getNodes()) { if (Boolean.valueOf(node.getAttributes().get(MachineLearning.ML_ENABLED_NODE_ATTR))) { numMlNodes++; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutoDetectResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutoDetectResultProcessor.java index 9ae247cca73b6..727732540e72b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutoDetectResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/AutoDetectResultProcessor.java @@ -111,6 +111,7 @@ public class AutoDetectResultProcessor { * New model size stats are read as the process is running */ private volatile ModelSizeStats latestModelSizeStats; + // TODO: remove in 7.0, along with all established model memory functionality in this class private volatile Date latestDateForEstablishedModelMemoryCalc; private volatile long latestEstablishedModelMemory; private volatile boolean haveNewLatestModelSizeStats; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java new file mode 100644 index 0000000000000..2ea92d98391eb --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/process/MlMemoryTracker.java @@ -0,0 +1,331 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.process; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceNotFoundException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.LocalNodeMasterListener; +import org.elasticsearch.cluster.ack.AckedRequest; +import org.elasticsearch.cluster.metadata.MetaData; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.MlMetadata; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.OpenJobAction; +import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.job.JobManager; +import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; + +/** + * This class keeps track of the memory requirement of ML jobs. + * It only functions on the master node - for this reason it should only be used by master node actions. + * The memory requirement for ML jobs can be updated in 3 ways: + * 1. For all open ML jobs (via {@link #asyncRefresh}) + * 2. For all open ML jobs, plus one named ML job that is not open (via {@link #refreshJobMemoryAndAllOthers}) + * 3. For one named ML job (via {@link #refreshJobMemory}) + * In all cases a listener informs the caller when the requested updates are complete. + */ +public class MlMemoryTracker implements LocalNodeMasterListener { + + private static final AckedRequest ACKED_REQUEST = new AckedRequest() { + @Override + public TimeValue ackTimeout() { + return AcknowledgedRequest.DEFAULT_ACK_TIMEOUT; + } + + @Override + public TimeValue masterNodeTimeout() { + return AcknowledgedRequest.DEFAULT_ACK_TIMEOUT; + } + }; + + private static final Duration RECENT_UPDATE_THRESHOLD = Duration.ofMinutes(1); + + private final Logger logger = LogManager.getLogger(MlMemoryTracker.class); + private final ConcurrentHashMap memoryRequirementByJob = new ConcurrentHashMap<>(); + private final List> fullRefreshCompletionListeners = new ArrayList<>(); + + private final ThreadPool threadPool; + private final ClusterService clusterService; + private final JobManager jobManager; + private final JobResultsProvider jobResultsProvider; + private volatile boolean isMaster; + private volatile Instant lastUpdateTime; + + public MlMemoryTracker(ClusterService clusterService, ThreadPool threadPool, JobManager jobManager, + JobResultsProvider jobResultsProvider) { + this.threadPool = threadPool; + this.clusterService = clusterService; + this.jobManager = jobManager; + this.jobResultsProvider = jobResultsProvider; + clusterService.addLocalNodeMasterListener(this); + } + + @Override + public void onMaster() { + isMaster = true; + logger.trace("ML memory tracker on master"); + } + + @Override + public void offMaster() { + isMaster = false; + logger.trace("ML memory tracker off master"); + memoryRequirementByJob.clear(); + lastUpdateTime = null; + } + + @Override + public String executorName() { + return MachineLearning.UTILITY_THREAD_POOL_NAME; + } + + /** + * Is the information in this object sufficiently up to date + * for valid allocation decisions to be made using it? + */ + public boolean isRecentlyRefreshed() { + Instant localLastUpdateTime = lastUpdateTime; + return localLastUpdateTime != null && localLastUpdateTime.plus(RECENT_UPDATE_THRESHOLD).isAfter(Instant.now()); + } + + /** + * Get the memory requirement for a job. + * This method only works on the master node. + * @param jobId The job ID. + * @return The memory requirement of the job specified by {@code jobId}, + * or null if it cannot be calculated. + */ + public Long getJobMemoryRequirement(String jobId) { + + if (isMaster == false) { + return null; + } + + Long memoryRequirement = memoryRequirementByJob.get(jobId); + if (memoryRequirement != null) { + return memoryRequirement; + } + + // Fallback for mixed version 6.6+/pre-6.6 cluster - TODO: remove in 7.0 + Job job = MlMetadata.getMlMetadata(clusterService.state()).getJobs().get(jobId); + if (job != null) { + return job.estimateMemoryFootprint(); + } + + return null; + } + + /** + * Remove any memory requirement that is stored for the specified job. + * It doesn't matter if this method is called for a job that doesn't have + * a stored memory requirement. + */ + public void removeJob(String jobId) { + memoryRequirementByJob.remove(jobId); + } + + /** + * Uses a separate thread to refresh the memory requirement for every ML job that has + * a corresponding persistent task. This method only works on the master node. + * @param listener Will be called when the async refresh completes or fails. The + * boolean value indicates whether the cluster state was updated + * with the refresh completion time. (If it was then this will in + * cause the persistent tasks framework to check if any persistent + * tasks are awaiting allocation.) + * @return true if the async refresh is scheduled, and false + * if this is not possible for some reason. + */ + public boolean asyncRefresh(ActionListener listener) { + + if (isMaster) { + try { + ActionListener mlMetaUpdateListener = ActionListener.wrap( + aVoid -> recordUpdateTimeInClusterState(listener), + listener::onFailure + ); + threadPool.executor(executorName()).execute( + () -> refresh(clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE), mlMetaUpdateListener)); + return true; + } catch (EsRejectedExecutionException e) { + logger.debug("Couldn't schedule ML memory update - node might be shutting down", e); + } + } + + return false; + } + + /** + * This refreshes the memory requirement for every ML job that has a corresponding + * persistent task and, in addition, one job that doesn't have a persistent task. + * This method only works on the master node. + * @param jobId The job ID of the job whose memory requirement is to be refreshed + * despite not having a corresponding persistent task. + * @param listener Receives the memory requirement of the job specified by {@code jobId}, + * or null if it cannot be calculated. + */ + public void refreshJobMemoryAndAllOthers(String jobId, ActionListener listener) { + + if (isMaster == false) { + listener.onResponse(null); + return; + } + + PersistentTasksCustomMetaData persistentTasks = clusterService.state().getMetaData().custom(PersistentTasksCustomMetaData.TYPE); + refresh(persistentTasks, ActionListener.wrap(aVoid -> refreshJobMemory(jobId, listener), listener::onFailure)); + } + + /** + * This refreshes the memory requirement for every ML job that has a corresponding persistent task. + * It does NOT remove entries for jobs that no longer have a persistent task, because that would + * lead to a race where a job was opened part way through the refresh. (Instead, entries are removed + * when jobs are deleted.) + */ + void refresh(PersistentTasksCustomMetaData persistentTasks, ActionListener onCompletion) { + + synchronized (fullRefreshCompletionListeners) { + fullRefreshCompletionListeners.add(onCompletion); + if (fullRefreshCompletionListeners.size() > 1) { + // A refresh is already in progress, so don't do another + return; + } + } + + ActionListener refreshComplete = ActionListener.wrap(aVoid -> { + lastUpdateTime = Instant.now(); + synchronized (fullRefreshCompletionListeners) { + assert fullRefreshCompletionListeners.isEmpty() == false; + for (ActionListener listener : fullRefreshCompletionListeners) { + listener.onResponse(null); + } + fullRefreshCompletionListeners.clear(); + } + }, onCompletion::onFailure); + + // persistentTasks will be null if there's never been a persistent task created in this cluster + if (persistentTasks == null) { + refreshComplete.onResponse(null); + } else { + List> mlJobTasks = persistentTasks.tasks().stream() + .filter(task -> MlTasks.JOB_TASK_NAME.equals(task.getTaskName())).collect(Collectors.toList()); + iterateMlJobTasks(mlJobTasks.iterator(), refreshComplete); + } + } + + private void recordUpdateTimeInClusterState(ActionListener listener) { + + clusterService.submitStateUpdateTask("ml-memory-last-update-time", + new AckedClusterStateUpdateTask(ACKED_REQUEST, listener) { + @Override + protected Boolean newResponse(boolean acknowledged) { + return acknowledged; + } + + @Override + public ClusterState execute(ClusterState currentState) { + MlMetadata currentMlMetadata = MlMetadata.getMlMetadata(currentState); + MlMetadata.Builder builder = new MlMetadata.Builder(currentMlMetadata); + builder.setLastMemoryRefreshVersion(currentState.getVersion() + 1); + MlMetadata newMlMetadata = builder.build(); + if (newMlMetadata.equals(currentMlMetadata)) { + // Return same reference if nothing has changed + return currentState; + } else { + ClusterState.Builder newState = ClusterState.builder(currentState); + newState.metaData(MetaData.builder(currentState.getMetaData()).putCustom(MlMetadata.TYPE, newMlMetadata).build()); + return newState.build(); + } + } + }); + } + + private void iterateMlJobTasks(Iterator> iterator, + ActionListener refreshComplete) { + if (iterator.hasNext()) { + OpenJobAction.JobParams jobParams = (OpenJobAction.JobParams) iterator.next().getParams(); + refreshJobMemory(jobParams.getJobId(), + ActionListener.wrap(mem -> iterateMlJobTasks(iterator, refreshComplete), refreshComplete::onFailure)); + } else { + refreshComplete.onResponse(null); + } + } + + /** + * Refresh the memory requirement for a single job. + * This method only works on the master node. + * @param jobId The ID of the job to refresh the memory requirement for. + * @param listener Receives the job's memory requirement, or null + * if it cannot be calculated. + */ + public void refreshJobMemory(String jobId, ActionListener listener) { + if (isMaster == false) { + listener.onResponse(null); + return; + } + + try { + jobResultsProvider.getEstablishedMemoryUsage(jobId, null, null, + establishedModelMemoryBytes -> { + if (establishedModelMemoryBytes <= 0L) { + setJobMemoryToLimit(jobId, listener); + } else { + Long memoryRequirementBytes = establishedModelMemoryBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes(); + memoryRequirementByJob.put(jobId, memoryRequirementBytes); + listener.onResponse(memoryRequirementBytes); + } + }, + e -> { + logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e); + setJobMemoryToLimit(jobId, listener); + } + ); + } catch (Exception e) { + logger.error("[" + jobId + "] failed to calculate job established model memory requirement", e); + setJobMemoryToLimit(jobId, listener); + } + } + + private void setJobMemoryToLimit(String jobId, ActionListener listener) { + jobManager.getJob(jobId, ActionListener.wrap(job -> { + Long memoryLimitMb = job.getAnalysisLimits().getModelMemoryLimit(); + if (memoryLimitMb != null) { + Long memoryRequirementBytes = ByteSizeUnit.MB.toBytes(memoryLimitMb) + Job.PROCESS_MEMORY_OVERHEAD.getBytes(); + memoryRequirementByJob.put(jobId, memoryRequirementBytes); + listener.onResponse(memoryRequirementBytes); + } else { + memoryRequirementByJob.remove(jobId); + listener.onResponse(null); + } + }, e -> { + if (e instanceof ResourceNotFoundException) { + // TODO: does this also happen if the .ml-config index exists but is unavailable? + logger.trace("[{}] job deleted during ML memory update", jobId); + } else { + logger.error("[" + jobId + "] failed to get job during ML memory update", e); + } + memoryRequirementByJob.remove(jobId); + listener.onResponse(null); + })); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java index c7ca2ff805eba..eb58221bf5f35 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetadataTests.java @@ -69,6 +69,9 @@ protected MlMetadata createTestInstance() { builder.putJob(job, false); } } + if (randomBoolean()) { + builder.setLastMemoryRefreshVersion(randomNonNegativeLong()); + } return builder.build(); } @@ -438,8 +441,9 @@ protected MlMetadata mutateInstance(MlMetadata instance) { for (Map.Entry entry : datafeeds.entrySet()) { metadataBuilder.putDatafeed(entry.getValue(), Collections.emptyMap()); } + metadataBuilder.setLastMemoryRefreshVersion(instance.getLastMemoryRefreshVersion()); - switch (between(0, 1)) { + switch (between(0, 2)) { case 0: metadataBuilder.putJob(JobTests.createRandomizedJob(), true); break; @@ -459,6 +463,13 @@ protected MlMetadata mutateInstance(MlMetadata instance) { metadataBuilder.putJob(randomJob, false); metadataBuilder.putDatafeed(datafeedConfig, Collections.emptyMap()); break; + case 2: + if (instance.getLastMemoryRefreshVersion() == null) { + metadataBuilder.setLastMemoryRefreshVersion(randomNonNegativeLong()); + } else { + metadataBuilder.setLastMemoryRefreshVersion(null); + } + break; default: throw new AssertionError("Illegal randomisation branch"); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java index 4a98b380b0929..393fc492f5d63 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportOpenJobActionTests.java @@ -50,7 +50,9 @@ import org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings; import org.elasticsearch.xpack.core.ml.notifications.AuditorField; import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.process.MlMemoryTracker; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; +import org.junit.Before; import java.io.IOException; import java.net.InetAddress; @@ -71,6 +73,14 @@ public class TransportOpenJobActionTests extends ESTestCase { + private MlMemoryTracker memoryTracker; + + @Before + public void setup() { + memoryTracker = mock(MlMemoryTracker.class); + when(memoryTracker.isRecentlyRefreshed()).thenReturn(true); + } + public void testValidate_jobMissing() { expectThrows(ResourceNotFoundException.class, () -> TransportOpenJobAction.validate("job_id2", null)); } @@ -125,7 +135,7 @@ public void testSelectLeastLoadedMlNode_byCount() { jobBuilder.setJobVersion(Version.CURRENT); Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id4", jobBuilder.build(), - cs.build(), 2, 10, 30, logger); + cs.build(), 2, 10, 30, memoryTracker, logger); assertEquals("", result.getExplanation()); assertEquals("_node_id3", result.getExecutorNode()); } @@ -161,7 +171,7 @@ public void testSelectLeastLoadedMlNode_maxCapacity() { Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id0", new ByteSizeValue(150, ByteSizeUnit.MB)).build(new Date()); Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id0", job, cs.build(), 2, - maxRunningJobsPerNode, 30, logger); + maxRunningJobsPerNode, 30, memoryTracker, logger); assertNull(result.getExecutorNode()); assertTrue(result.getExplanation().contains("because this node is full. Number of opened jobs [" + maxRunningJobsPerNode + "], xpack.ml.max_open_jobs [" + maxRunningJobsPerNode + "]")); @@ -187,7 +197,7 @@ public void testSelectLeastLoadedMlNode_noMlNodes() { Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id2", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id2", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id2", job, cs.build(), 2, 10, 30, memoryTracker, logger); assertTrue(result.getExplanation().contains("because this node isn't a ml node")); assertNull(result.getExecutorNode()); } @@ -221,7 +231,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id6", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); ClusterState cs = csBuilder.build(); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id6", job, cs, 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id6", job, cs, 2, 10, 30, memoryTracker, logger); assertEquals("_node_id3", result.getExecutorNode()); tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); @@ -231,7 +241,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { csBuilder = ClusterState.builder(cs); csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, logger); + result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, memoryTracker, logger); assertNull("no node selected, because OPENING state", result.getExecutorNode()); assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); @@ -242,7 +252,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { csBuilder = ClusterState.builder(cs); csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, logger); + result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, memoryTracker, logger); assertNull("no node selected, because stale task", result.getExecutorNode()); assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); @@ -253,7 +263,7 @@ public void testSelectLeastLoadedMlNode_maxConcurrentOpeningJobs() { csBuilder = ClusterState.builder(cs); csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, logger); + result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, memoryTracker, logger); assertNull("no node selected, because null state", result.getExecutorNode()); assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); } @@ -291,7 +301,7 @@ public void testSelectLeastLoadedMlNode_concurrentOpeningJobsAndStaleFailedJob() Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id7", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); // Allocation won't be possible if the stale failed job is treated as opening - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id7", job, cs, 2, 10, 30, memoryTracker, logger); assertEquals("_node_id1", result.getExecutorNode()); tasksBuilder = PersistentTasksCustomMetaData.builder(tasks); @@ -301,7 +311,7 @@ public void testSelectLeastLoadedMlNode_concurrentOpeningJobsAndStaleFailedJob() csBuilder = ClusterState.builder(cs); csBuilder.metaData(MetaData.builder(cs.metaData()).putCustom(PersistentTasksCustomMetaData.TYPE, tasks)); cs = csBuilder.build(); - result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id8", job, cs, 2, 10, 30, logger); + result = TransportOpenJobAction.selectLeastLoadedMlNode("job_id8", job, cs, 2, 10, 30, memoryTracker, logger); assertNull("no node selected, because OPENING state", result.getExecutorNode()); assertTrue(result.getExplanation().contains("because node exceeds [2] the maximum number of jobs [2] in opening state")); } @@ -332,7 +342,8 @@ public void testSelectLeastLoadedMlNode_noCompatibleJobTypeNodes() { cs.nodes(nodes); metaData.putCustom(PersistentTasksCustomMetaData.TYPE, tasks); cs.metaData(metaData); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("incompatible_type_job", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("incompatible_type_job", job, cs.build(), 2, 10, 30, + memoryTracker, logger); assertThat(result.getExplanation(), containsString("because this node does not support jobs of type [incompatible_type]")); assertNull(result.getExecutorNode()); } @@ -359,7 +370,8 @@ public void testSelectLeastLoadedMlNode_noNodesPriorTo_V_5_5() { Job job = BaseMlIntegTestCase.createFareQuoteJob("job_id7", new ByteSizeValue(2, ByteSizeUnit.MB)).build(new Date()); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("incompatible_type_job", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("incompatible_type_job", job, cs.build(), 2, 10, 30, + memoryTracker, logger); assertThat(result.getExplanation(), containsString("because this node does not support machine learning jobs")); assertNull(result.getExecutorNode()); } @@ -385,7 +397,8 @@ public void testSelectLeastLoadedMlNode_jobWithRulesButNoNodeMeetsRequiredVersio cs.metaData(metaData); Job job = jobWithRules("job_with_rules"); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 2, 10, 30, memoryTracker, + logger); assertThat(result.getExplanation(), containsString( "because jobs using custom_rules require a node of version [6.4.0] or higher")); assertNull(result.getExecutorNode()); @@ -412,7 +425,8 @@ public void testSelectLeastLoadedMlNode_jobWithRulesAndNodeMeetsRequiredVersion( cs.metaData(metaData); Job job = jobWithRules("job_with_rules"); - Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 2, 10, 30, logger); + Assignment result = TransportOpenJobAction.selectLeastLoadedMlNode("job_with_rules", job, cs.build(), 2, 10, 30, memoryTracker, + logger); assertNotNull(result.getExecutorNode()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java index 2e14289da705e..5e4d8fd06030c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/MlDistributedFailureIT.java @@ -8,10 +8,13 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.CheckedRunnable; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentHelper; @@ -31,21 +34,32 @@ import org.elasticsearch.xpack.core.ml.action.PutJobAction; import org.elasticsearch.xpack.core.ml.action.StartDatafeedAction; import org.elasticsearch.xpack.core.ml.action.StopDatafeedAction; +import org.elasticsearch.xpack.core.ml.action.util.QueryPage; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig; import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState; import org.elasticsearch.xpack.core.ml.job.config.Job; import org.elasticsearch.xpack.core.ml.job.config.JobState; import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.DataCounts; +import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase; import java.io.IOException; import java.util.Collections; +import java.util.List; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import static org.elasticsearch.persistent.PersistentTasksClusterService.needsReassignment; public class MlDistributedFailureIT extends BaseMlIntegTestCase { + @Override + protected Settings nodeSettings(int nodeOrdinal) { + return Settings.builder().put(super.nodeSettings(nodeOrdinal)) + .put(MachineLearning.CONCURRENT_JOB_ALLOCATIONS.getKey(), 4) + .build(); + } + public void testFailOver() throws Exception { internalCluster().ensureAtLeastNumDataNodes(3); ensureStableClusterOnAllNodes(3); @@ -58,8 +72,6 @@ public void testFailOver() throws Exception { }); } - @TestLogging("org.elasticsearch.xpack.ml.action:DEBUG,org.elasticsearch.xpack.persistent:TRACE," + - "org.elasticsearch.xpack.ml.datafeed:TRACE") public void testLoseDedicatedMasterNode() throws Exception { internalCluster().ensureAtMostNumDataNodes(0); logger.info("Starting dedicated master node..."); @@ -136,12 +148,12 @@ public void testCloseUnassignedJobAndDatafeed() throws Exception { // Job state is opened but the job is not assigned to a node (because we just killed the only ML node) GetJobsStatsAction.Request jobStatsRequest = new GetJobsStatsAction.Request(jobId); GetJobsStatsAction.Response jobStatsResponse = client().execute(GetJobsStatsAction.INSTANCE, jobStatsRequest).actionGet(); - assertEquals(jobStatsResponse.getResponse().results().get(0).getState(), JobState.OPENED); + assertEquals(JobState.OPENED, jobStatsResponse.getResponse().results().get(0).getState()); GetDatafeedsStatsAction.Request datafeedStatsRequest = new GetDatafeedsStatsAction.Request(datafeedId); GetDatafeedsStatsAction.Response datafeedStatsResponse = client().execute(GetDatafeedsStatsAction.INSTANCE, datafeedStatsRequest).actionGet(); - assertEquals(datafeedStatsResponse.getResponse().results().get(0).getDatafeedState(), DatafeedState.STARTED); + assertEquals(DatafeedState.STARTED, datafeedStatsResponse.getResponse().results().get(0).getDatafeedState()); // Can't normal stop an unassigned datafeed StopDatafeedAction.Request stopDatafeedRequest = new StopDatafeedAction.Request(datafeedId); @@ -170,6 +182,73 @@ public void testCloseUnassignedJobAndDatafeed() throws Exception { assertTrue(closeJobResponse.isClosed()); } + @TestLogging("org.elasticsearch.xpack.ml.action:TRACE,org.elasticsearch.xpack.ml.process:TRACE") + public void testJobRelocationIsMemoryAware() throws Exception { + + internalCluster().ensureAtLeastNumDataNodes(1); + ensureStableClusterOnAllNodes(1); + + // Open 4 small jobs. Since there is only 1 node in the cluster they'll have to go on that node. + + setupJobWithoutDatafeed("small1", new ByteSizeValue(2, ByteSizeUnit.MB)); + setupJobWithoutDatafeed("small2", new ByteSizeValue(2, ByteSizeUnit.MB)); + setupJobWithoutDatafeed("small3", new ByteSizeValue(2, ByteSizeUnit.MB)); + setupJobWithoutDatafeed("small4", new ByteSizeValue(2, ByteSizeUnit.MB)); + + // Expand the cluster to 3 nodes. The 4 small jobs will stay on the + // same node because we don't rebalance jobs that are happily running. + + internalCluster().ensureAtLeastNumDataNodes(3); + ensureStableClusterOnAllNodes(3); + + // Open a big job. This should go on a different node to the 4 small ones. + + setupJobWithoutDatafeed("big1", new ByteSizeValue(500, ByteSizeUnit.MB)); + + // Stop the current master node - this should be the one with the 4 small jobs on. + + internalCluster().stopCurrentMasterNode(); + ensureStableClusterOnAllNodes(2); + + // If memory requirements are used to reallocate the 4 small jobs (as we expect) then they should + // all reallocate to the same node, that being the one that doesn't have the big job on. If job counts + // are used to reallocate the small jobs then this implies the fallback allocation mechanism has been + // used in a situation we don't want it to be used in, and at least one of the small jobs will be on + // the same node as the big job. (This all relies on xpack.ml.node_concurrent_job_allocations being set + // to at least 4, which we do in the nodeSettings() method.) + + assertBusy(() -> { + GetJobsStatsAction.Response statsResponse = + client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(MetaData.ALL)).actionGet(); + QueryPage jobStats = statsResponse.getResponse(); + assertNotNull(jobStats); + List smallJobNodes = jobStats.results().stream().filter(s -> s.getJobId().startsWith("small") && s.getNode() != null) + .map(s -> s.getNode().getName()).collect(Collectors.toList()); + List bigJobNodes = jobStats.results().stream().filter(s -> s.getJobId().startsWith("big") && s.getNode() != null) + .map(s -> s.getNode().getName()).collect(Collectors.toList()); + logger.info("small job nodes: " + smallJobNodes + ", big job nodes: " + bigJobNodes); + assertEquals(5, jobStats.count()); + assertEquals(4, smallJobNodes.size()); + assertEquals(1, bigJobNodes.size()); + assertEquals(1L, smallJobNodes.stream().distinct().count()); + assertEquals(1L, bigJobNodes.stream().distinct().count()); + assertNotEquals(smallJobNodes, bigJobNodes); + }); + } + + private void setupJobWithoutDatafeed(String jobId, ByteSizeValue modelMemoryLimit) throws Exception { + Job.Builder job = createFareQuoteJob(jobId, modelMemoryLimit); + PutJobAction.Request putJobRequest = new PutJobAction.Request(job); + client().execute(PutJobAction.INSTANCE, putJobRequest).actionGet(); + + client().execute(OpenJobAction.INSTANCE, new OpenJobAction.Request(job.getId())).actionGet(); + assertBusy(() -> { + GetJobsStatsAction.Response statsResponse = + client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(job.getId())).actionGet(); + assertEquals(JobState.OPENED, statsResponse.getResponse().results().get(0).getState()); + }); + } + private void setupJobAndDatafeed(String jobId, String datafeedId) throws Exception { Job.Builder job = createScheduledJob(jobId); PutJobAction.Request putJobRequest = new PutJobAction.Request(job); @@ -183,7 +262,7 @@ private void setupJobAndDatafeed(String jobId, String datafeedId) throws Excepti assertBusy(() -> { GetJobsStatsAction.Response statsResponse = client().execute(GetJobsStatsAction.INSTANCE, new GetJobsStatsAction.Request(job.getId())).actionGet(); - assertEquals(statsResponse.getResponse().results().get(0).getState(), JobState.OPENED); + assertEquals(JobState.OPENED, statsResponse.getResponse().results().get(0).getState()); }); StartDatafeedAction.Request startDatafeedRequest = new StartDatafeedAction.Request(config.getId(), 0L); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java index 87aa3c5b926e3..c4150d633a8f0 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TooManyJobsIT.java @@ -123,12 +123,10 @@ public void testLazyNodeValidation() throws Exception { }); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/34084") public void testSingleNode() throws Exception { verifyMaxNumberOfJobsLimit(1, randomIntBetween(1, 100)); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/34084") public void testMultipleNodes() throws Exception { verifyMaxNumberOfJobsLimit(3, randomIntBetween(1, 100)); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java new file mode 100644 index 0000000000000..cbba7ffa04972 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/process/MlMemoryTrackerTests.java @@ -0,0 +1,195 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.ml.process; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.AckedClusterStateUpdateTask; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.unit.ByteSizeUnit; +import org.elasticsearch.persistent.PersistentTasksCustomMetaData; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ml.MlMetadata; +import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.OpenJobAction; +import org.elasticsearch.xpack.core.ml.job.config.AnalysisLimits; +import org.elasticsearch.xpack.core.ml.job.config.Job; +import org.elasticsearch.xpack.ml.job.JobManager; +import org.elasticsearch.xpack.ml.job.persistence.JobResultsProvider; +import org.junit.Before; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class MlMemoryTrackerTests extends ESTestCase { + + private ClusterService clusterService; + private ThreadPool threadPool; + private JobManager jobManager; + private JobResultsProvider jobResultsProvider; + private MlMemoryTracker memoryTracker; + + @Before + public void setup() { + + clusterService = mock(ClusterService.class); + threadPool = mock(ThreadPool.class); + ExecutorService executorService = mock(ExecutorService.class); + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + Runnable r = (Runnable) invocation.getArguments()[0]; + r.run(); + return null; + }).when(executorService).execute(any(Runnable.class)); + when(threadPool.executor(anyString())).thenReturn(executorService); + jobManager = mock(JobManager.class); + jobResultsProvider = mock(JobResultsProvider.class); + memoryTracker = new MlMemoryTracker(clusterService, threadPool, jobManager, jobResultsProvider); + } + + public void testRefreshAll() { + + boolean isMaster = randomBoolean(); + if (isMaster) { + memoryTracker.onMaster(); + } else { + memoryTracker.offMaster(); + } + + int numMlJobTasks = randomIntBetween(2, 5); + Map> tasks = new HashMap<>(); + for (int i = 1; i <= numMlJobTasks; ++i) { + String jobId = "job" + i; + PersistentTasksCustomMetaData.PersistentTask task = makeTestTask(jobId); + tasks.put(task.getId(), task); + } + PersistentTasksCustomMetaData persistentTasks = new PersistentTasksCustomMetaData(numMlJobTasks, tasks); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + Consumer listener = (Consumer) invocation.getArguments()[3]; + listener.accept(randomLongBetween(1000, 1000000)); + return null; + }).when(jobResultsProvider).getEstablishedMemoryUsage(anyString(), any(), any(), any(Consumer.class), any()); + + memoryTracker.refresh(persistentTasks, ActionListener.wrap(aVoid -> {}, ESTestCase::assertNull)); + + if (isMaster) { + for (int i = 1; i <= numMlJobTasks; ++i) { + String jobId = "job" + i; + verify(jobResultsProvider, times(1)).getEstablishedMemoryUsage(eq(jobId), any(), any(), any(), any()); + } + } else { + verify(jobResultsProvider, never()).getEstablishedMemoryUsage(anyString(), any(), any(), any(), any()); + } + } + + public void testRefreshOne() { + + boolean isMaster = randomBoolean(); + if (isMaster) { + memoryTracker.onMaster(); + } else { + memoryTracker.offMaster(); + } + + String jobId = "job"; + boolean haveEstablishedModelMemory = randomBoolean(); + + long modelBytes = 1024 * 1024; + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + Consumer listener = (Consumer) invocation.getArguments()[3]; + listener.accept(haveEstablishedModelMemory ? modelBytes : 0L); + return null; + }).when(jobResultsProvider).getEstablishedMemoryUsage(eq(jobId), any(), any(), any(Consumer.class), any()); + + long modelMemoryLimitMb = 2; + Job job = mock(Job.class); + when(job.getAnalysisLimits()).thenReturn(new AnalysisLimits(modelMemoryLimitMb, 4L)); + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onResponse(job); + return null; + }).when(jobManager).getJob(eq(jobId), any(ActionListener.class)); + + AtomicReference refreshedMemoryRequirement = new AtomicReference<>(); + memoryTracker.refreshJobMemory(jobId, ActionListener.wrap(refreshedMemoryRequirement::set, ESTestCase::assertNull)); + + if (isMaster) { + if (haveEstablishedModelMemory) { + assertEquals(Long.valueOf(modelBytes + Job.PROCESS_MEMORY_OVERHEAD.getBytes()), + memoryTracker.getJobMemoryRequirement(jobId)); + } else { + assertEquals(Long.valueOf(ByteSizeUnit.MB.toBytes(modelMemoryLimitMb) + Job.PROCESS_MEMORY_OVERHEAD.getBytes()), + memoryTracker.getJobMemoryRequirement(jobId)); + } + } else { + assertNull(memoryTracker.getJobMemoryRequirement(jobId)); + } + + assertEquals(memoryTracker.getJobMemoryRequirement(jobId), refreshedMemoryRequirement.get()); + + memoryTracker.removeJob(jobId); + assertNull(memoryTracker.getJobMemoryRequirement(jobId)); + } + + @SuppressWarnings("unchecked") + public void testRecordUpdateTimeInClusterState() { + + boolean isMaster = randomBoolean(); + if (isMaster) { + memoryTracker.onMaster(); + } else { + memoryTracker.offMaster(); + } + + when(clusterService.state()).thenReturn(ClusterState.EMPTY_STATE); + + AtomicReference updateVersion = new AtomicReference<>(); + + doAnswer(invocation -> { + AckedClusterStateUpdateTask task = (AckedClusterStateUpdateTask) invocation.getArguments()[1]; + ClusterState currentClusterState = ClusterState.EMPTY_STATE; + ClusterState newClusterState = task.execute(currentClusterState); + assertThat(currentClusterState, not(equalTo(newClusterState))); + MlMetadata newMlMetadata = MlMetadata.getMlMetadata(newClusterState); + updateVersion.set(newMlMetadata.getLastMemoryRefreshVersion()); + task.onAllNodesAcked(null); + return null; + }).when(clusterService).submitStateUpdateTask(anyString(), any(AckedClusterStateUpdateTask.class)); + + memoryTracker.asyncRefresh(ActionListener.wrap(ESTestCase::assertTrue, ESTestCase::assertNull)); + + if (isMaster) { + assertNotNull(updateVersion.get()); + } else { + assertNull(updateVersion.get()); + } + } + + private PersistentTasksCustomMetaData.PersistentTask makeTestTask(String jobId) { + return new PersistentTasksCustomMetaData.PersistentTask<>("job-" + jobId, MlTasks.JOB_TASK_NAME, new OpenJobAction.JobParams(jobId), + 0, PersistentTasksCustomMetaData.INITIAL_ASSIGNMENT); + } +}