Skip to content

Commit

Permalink
adds Final access modifier to public methods, adds logging
Browse files Browse the repository at this point in the history
Signed-off-by: Brian Flores <[email protected]>
  • Loading branch information
brianf-aws committed Oct 22, 2024
1 parent ae35ddb commit 5a28b53
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import java.util.Set;
import java.util.StringJoiner;

import static org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor.DEFAULT_KEEP_PREVIOUS_SCORE;
import static org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor.DEFAULT_REMOVE_TARGET_FIELD;
import static org.opensearch.neuralsearch.processor.rerank.RerankProcessor.processorRequiresContext;

/**
Expand Down Expand Up @@ -72,9 +74,6 @@ public SearchResponseProcessor create(
);
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor);
case BY_FIELD:
boolean DEFAULT_REMOVE_TARGET_FIELD = false;
boolean DEFAULT_KEEP_PREVIOUS_SCORE = false;

String targetField = ConfigurationUtils.readStringProperty(
RERANK_PROCESSOR_TYPE,
tag,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.neuralsearch.processor.rerank;

import lombok.extern.log4j.Log4j2;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
Expand All @@ -16,6 +17,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

Expand Down Expand Up @@ -63,12 +65,16 @@
* information stored in a field can be used to improve the relevance of search results
* beyond the initial scoring.
*/
@Log4j2
public class ByFieldRerankProcessor extends RescoringRerankProcessor {

public static final String TARGET_FIELD = "target_field";
public static final String REMOVE_TARGET_FIELD = "remove_target_field";
public static final String KEEP_PREVIOUS_SCORE = "keep_previous_score";

public static final boolean DEFAULT_REMOVE_TARGET_FIELD = false;
public static final boolean DEFAULT_KEEP_PREVIOUS_SCORE = false;

protected final String targetField;
protected final boolean removeTargetField;
protected final boolean keepPreviousScore;
Expand All @@ -86,12 +92,12 @@ public class ByFieldRerankProcessor extends RescoringRerankProcessor {
* @param contextSourceFetchers Context from some source and puts it in a map for a reranking processor to use <b> (Unused in ByFieldRerankProcessor)</b>
*/
public ByFieldRerankProcessor(
String description,
String tag,
boolean ignoreFailure,
String targetField,
boolean removeTargetField,
boolean keepPreviousScore,
final String description,
final String tag,
final boolean ignoreFailure,
final String targetField,
final boolean removeTargetField,
final boolean keepPreviousScore,
final List<ContextSourceFetcher> contextSourceFetchers
) {
super(RerankType.BY_FIELD, description, tag, ignoreFailure, contextSourceFetchers);
Expand All @@ -101,7 +107,11 @@ public ByFieldRerankProcessor(
}

@Override
public void rescoreSearchResponse(SearchResponse response, Map<String, Object> rerankingContext, ActionListener<List<Float>> listener) {
public void rescoreSearchResponse(
final SearchResponse response,
final Map<String, Object> rerankingContext,
final ActionListener<List<Float>> listener
) {
SearchHit[] searchHits = response.getHits().getHits();

SearchHitValidator searchHitValidator = this::byFieldSearchHitValidator;
Expand Down Expand Up @@ -131,6 +141,7 @@ public void rescoreSearchResponse(SearchResponse response, Map<String, Object> r
BytesReference sourceMapAsBytes = BytesReference.bytes(builder.map(sourceAsMap));
hit.sourceRef(sourceMapAsBytes);
} catch (IOException e) {
log.error(e.getMessage());
listener.onFailure(new RuntimeException(e));
return;
}
Expand All @@ -149,20 +160,29 @@ public void rescoreSearchResponse(SearchResponse response, Map<String, Object> r
* </ul>
* @param hit A search hit to validate
*/
public void byFieldSearchHitValidator(SearchHit hit) {
public void byFieldSearchHitValidator(final SearchHit hit) {
if (!hit.hasSource()) {
throw new IllegalArgumentException("There is no source field to be able to perform rerank on hit [" + hit.docId() + "]");
log.error(String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId()));
throw new IllegalArgumentException(
String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId())
);
}

Map<String, Object> sourceMap = hit.getSourceAsMap();
if (!mappingExistsInSource(sourceMap, targetField)) {
throw new IllegalArgumentException("The field to rerank [" + targetField + "] is not found at hit [" + hit.docId() + "]");
log.error(String.format(Locale.ROOT, "The field to rerank [%s] is not found at hit [%d]", targetField, hit.docId()));

throw new IllegalArgumentException(String.format(Locale.ROOT, "The field to rerank by is not found at hit [%d]", hit.docId()));
}

Optional<Object> val = getValueFromSource(sourceMap, targetField);

if (!(val.get() instanceof Number)) {
throw new IllegalArgumentException("The field mapping to rerank [" + targetField + ": " + val.get() + "] is a not Numerical");
log.error(String.format(Locale.ROOT, "The field mapping to rerank [%s: %s] is not Numerical", targetField, val.orElse(null)));

throw new IllegalArgumentException(
String.format(Locale.ROOT, "The field mapping to rerank by [%s] is not Numerical", val.orElse(null))
);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public interface SearchHitValidator {
* @param hit The specific SearchHit were the invalidation occurred
* @throws IllegalArgumentException if the validation for the hit fails
*/
void validate(SearchHit hit) throws IllegalArgumentException;
void validate(final SearchHit hit) throws IllegalArgumentException;
}

/**
Expand All @@ -51,9 +51,9 @@ public interface SearchHitValidator {
* @return The status indicating that the SearchHits are in correct form to perform the Rerank
*/
public static boolean validateRerankCriteria(
SearchHit[] searchHits,
SearchHitValidator validator,
ActionListener<List<Float>> listener
final SearchHit[] searchHits,
final SearchHitValidator validator,
final ActionListener<List<Float>> listener
) {
for (SearchHit hit : searchHits) {
try {
Expand All @@ -75,7 +75,7 @@ public static boolean validateRerankCriteria(
* @param targetField the path to take to get the score to replace by
* @return The numerical score found using the <code>target_field</code>
*/
public static float getScoreFromSourceMap(Map<String, Object> sourceAsMap, String targetField) {
public static float getScoreFromSourceMap(final Map<String, Object> sourceAsMap, final String targetField) {
Object val = getValueFromSource(sourceAsMap, targetField).get();
return ((Number) val).floatValue();
}
Expand All @@ -98,7 +98,7 @@ public static float getScoreFromSourceMap(Map<String, Object> sourceAsMap, Strin
* @param sourceAsMap the map of maps that contains the <code>targetField</code>
* @param targetField The path to take to remove the targetField
*/
public static void removeTargetFieldFromSource(Map<String, Object> sourceAsMap, String targetField) {
public static void removeTargetFieldFromSource(final Map<String, Object> sourceAsMap, final String targetField) {
Stack<Tuple<Map<String, Object>, String>> parentMapChildrenKeyTupleStack = new Stack<>();
String[] keys = targetField.split("\\.");

Expand Down Expand Up @@ -146,7 +146,7 @@ public static void removeTargetFieldFromSource(Map<String, Object> sourceAsMap,
* @param targetField The path to take to get the desired mapping
* @return A possible result within an optional
*/
public static Optional<Object> getValueFromSource(Map<String, Object> sourceAsMap, String targetField) {
public static Optional<Object> getValueFromSource(final Map<String, Object> sourceAsMap, final String targetField) {
String[] keys = targetField.split("\\.");
Optional<Object> currentValue = Optional.of(sourceAsMap);

Expand Down Expand Up @@ -176,7 +176,7 @@ public static Optional<Object> getValueFromSource(Map<String, Object> sourceAsMa
* @param pathToValue A string of the form key[.key] indicating what keys to apply to the sourceMap
* @return Whether the mapping using the pathToValue exists
*/
public static boolean mappingExistsInSource(Map<String, Object> sourceAsMap, String pathToValue) {
public static boolean mappingExistsInSource(final Map<String, Object> sourceAsMap, final String pathToValue) {
return getValueFromSource(sourceAsMap, pathToValue).isPresent();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ public void testRerank_throwsExceptionOnMappingNotExistingInSource_WhenSearchRes
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argumentCaptor.capture());

assertEquals("The field to rerank [" + targetField + "] is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage());
assertEquals("The field to rerank by is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage());
assert (argumentCaptor.getValue() instanceof IllegalArgumentException);
}

Expand Down Expand Up @@ -947,7 +947,7 @@ public void testRerank_throwsExceptionOnHavingEmptyMapping_WhenTargetFieldHasNul
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argumentCaptor.capture());

assertEquals("The field to rerank [" + targetField + "] is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage());
assertEquals("The field to rerank by is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage());
assert (argumentCaptor.getValue() instanceof IllegalArgumentException);
}

Expand Down Expand Up @@ -985,10 +985,7 @@ public void testRerank_throwsExceptionOnHavingNonNumericValue_WhenTargetFieldHas
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener, times(1)).onFailure(argumentCaptor.capture());

assertEquals(
"The field mapping to rerank [" + targetField + ": " + "hello world" + "] is a not Numerical",
argumentCaptor.getValue().getMessage()
);
assertEquals("The field mapping to rerank by [hello world] is not Numerical", argumentCaptor.getValue().getMessage());
assert (argumentCaptor.getValue() instanceof IllegalArgumentException);

}
Expand Down

0 comments on commit 5a28b53

Please sign in to comment.