Skip to content

Commit

Permalink
Refactor InferenceProcessorInfoExtractor to avoid ConfigurationUtils (e…
Browse files Browse the repository at this point in the history
  • Loading branch information
joegallo committed Oct 23, 2024
1 parent 4506be6 commit fc48440
Showing 1 changed file with 21 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import org.apache.lucene.util.Counter;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.ingest.ConfigurationUtils;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.transport.Transports;

import java.util.HashMap;
Expand All @@ -24,6 +22,7 @@
import java.util.function.Consumer;

import static org.elasticsearch.inference.InferenceResults.MODEL_ID_RESULTS_FIELD;
import static org.elasticsearch.ingest.Pipeline.ON_FAILURE_KEY;
import static org.elasticsearch.ingest.Pipeline.PROCESSORS_KEY;

/**
Expand Down Expand Up @@ -53,16 +52,10 @@ public static int countInferenceProcessors(ClusterState state) {
Counter counter = Counter.newCounter();
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
Map<String, Object> configMap = configuration.getConfigAsMap();
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
List<Map<String, Object>> processorConfigs = (List<Map<String, Object>>) configMap.get(PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
addModelsAndPipelines(
entry.getKey(),
pipelineId,
(Map<String, Object>) entry.getValue(),
pam -> counter.addAndGet(1),
0
);
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> counter.addAndGet(1), 0);
}
}
});
Expand All @@ -73,7 +66,6 @@ public static int countInferenceProcessors(ClusterState state) {
* @param ingestMetadata The ingestMetadata of current ClusterState
* @return The set of model IDs referenced by inference processors
*/
@SuppressWarnings("unchecked")
public static Set<String> getModelIdsFromInferenceProcessors(IngestMetadata ingestMetadata) {
if (ingestMetadata == null) {
return Set.of();
Expand All @@ -82,7 +74,7 @@ public static Set<String> getModelIdsFromInferenceProcessors(IngestMetadata inge
Set<String> modelIds = new LinkedHashSet<>();
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
Map<String, Object> configMap = configuration.getConfigAsMap();
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
List<Map<String, Object>> processorConfigs = readList(configMap, PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> modelIds.add(pam.modelIdOrAlias()), 0);
Expand All @@ -96,7 +88,6 @@ public static Set<String> getModelIdsFromInferenceProcessors(IngestMetadata inge
* @param state Current cluster state
* @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them.
*/
@SuppressWarnings("unchecked")
public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state, Set<String> ids) {
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
Map<String, Set<String>> pipelineIdsByModelIds = new HashMap<>();
Expand All @@ -110,7 +101,7 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
}
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
Map<String, Object> configMap = configuration.getConfigAsMap();
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
List<Map<String, Object>> processorConfigs = readList(configMap, PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> {
Expand All @@ -128,7 +119,6 @@ public static Map<String, Set<String>> pipelineIdsByResource(ClusterState state,
* @param state Current {@link ClusterState}
* @return a map from Model or Deployment IDs or Aliases to each pipeline referencing them.
*/
@SuppressWarnings("unchecked")
public static Set<String> pipelineIdsForResource(ClusterState state, Set<String> ids) {
assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures");
Set<String> pipelineIds = new HashSet<>();
Expand All @@ -142,7 +132,7 @@ public static Set<String> pipelineIdsForResource(ClusterState state, Set<String>
}
ingestMetadata.getPipelines().forEach((pipelineId, configuration) -> {
Map<String, Object> configMap = configuration.getConfigAsMap();
List<Map<String, Object>> processorConfigs = ConfigurationUtils.readList(null, null, configMap, PROCESSORS_KEY);
List<Map<String, Object>> processorConfigs = readList(configMap, PROCESSORS_KEY);
for (Map<String, Object> processorConfigWithKey : processorConfigs) {
for (Map.Entry<String, Object> entry : processorConfigWithKey.entrySet()) {
addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), pam -> {
Expand Down Expand Up @@ -188,21 +178,16 @@ private static void addModelsAndPipelines(
addModelsAndPipelines(
innerProcessorWithName.getKey(),
pipelineId,
(Map<String, Object>) innerProcessorWithName.getValue(),
innerProcessorWithName.getValue(),
handler,
level + 1
);
}
}
return;
}
if (processorDefinition instanceof Map<?, ?> definitionMap && definitionMap.containsKey(Pipeline.ON_FAILURE_KEY)) {
List<Map<String, Object>> onFailureConfigs = ConfigurationUtils.readList(
null,
null,
(Map<String, Object>) definitionMap,
Pipeline.ON_FAILURE_KEY
);
if (processorDefinition instanceof Map<?, ?> definitionMap && definitionMap.containsKey(ON_FAILURE_KEY)) {
List<Map<String, Object>> onFailureConfigs = readList(definitionMap, ON_FAILURE_KEY);
onFailureConfigs.stream()
.flatMap(map -> map.entrySet().stream())
.forEach(entry -> addModelsAndPipelines(entry.getKey(), pipelineId, entry.getValue(), handler, level + 1));
Expand All @@ -211,4 +196,16 @@ private static void addModelsAndPipelines(

/**
 * Pairs a pipeline id with a model id or alias. The handler callbacks used in this class read it through its
 * accessors, e.g. {@code pam -> modelIds.add(pam.modelIdOrAlias())}.
 */
private record PipelineAndModel(String pipelineId, String modelIdOrAlias) {}

/**
 * A local alternative to ConfigurationUtils.readList(...) that reads list properties out of the processor configuration map,
 * but doesn't rely on mutating the configuration map.
 *
 * @param processorConfig the configuration map to read from; it is only read, never modified
 * @param key the name of the property to look up
 * @return the list stored under {@code key}; the element types are not verified (the cast is unchecked)
 * @throws IllegalArgumentException if the property is absent (or mapped to {@code null}), or if its value is not a list
 */
@SuppressWarnings("unchecked")
private static List<Map<String, Object>> readList(Map<?, ?> processorConfig, String key) {
    Object val = processorConfig.get(key);
    if (val == null) {
        throw new IllegalArgumentException("Missing required property [" + key + "]");
    }
    // Fail with a message that names the offending property instead of letting the cast below
    // surface as a bare ClassCastException.
    if (val instanceof List<?> == false) {
        throw new IllegalArgumentException("Property [" + key + "] isn't a list, but of type [" + val.getClass().getName() + "]");
    }
    return (List<Map<String, Object>>) val;
}
}

0 comments on commit fc48440

Please sign in to comment.