Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Guardrails for remote model input and output #2209

Merged
merged 8 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,10 @@ public class CommonValue {
+ MLModel.CONNECTOR_FIELD
+ "\": {" + ML_CONNECTOR_INDEX_FIELDS + " }\n},"
+ USER_FIELD_MAPPING
+ " }\n"
+ " },\n"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You added a comma here.

+ " \""
ylwu-amzn marked this conversation as resolved.
Show resolved Hide resolved
+ MLModel.GUARDRAILS_FIELD
+ "\" : {\"type\": \"flat_object\"},\n"
jngz-es marked this conversation as resolved.
Show resolved Hide resolved
+ "}";

public static final String ML_TASK_INDEX_MAPPING = "{\n"
Expand Down
24 changes: 23 additions & 1 deletion common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.model.Guardrails;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.controller.MLRateLimiter;
import org.opensearch.ml.common.model.MLModelFormat;
Expand Down Expand Up @@ -84,6 +85,7 @@ public class MLModel implements ToXContentObject {
public static final String IS_HIDDEN_FIELD = "is_hidden";
public static final String CONNECTOR_FIELD = "connector";
public static final String CONNECTOR_ID_FIELD = "connector_id";
public static final String GUARDRAILS_FIELD = "guardrails";

private String name;
private String modelGroupId;
Expand Down Expand Up @@ -127,6 +129,7 @@ public class MLModel implements ToXContentObject {
@Setter
private Connector connector;
private String connectorId;
private Guardrails guardrails;

@Builder(toBuilder = true)
public MLModel(String name,
Expand Down Expand Up @@ -158,7 +161,8 @@ public MLModel(String name,
boolean deployToAllNodes,
Boolean isHidden,
Connector connector,
String connectorId) {
String connectorId,
Guardrails guardrails) {
this.name = name;
this.modelGroupId = modelGroupId;
this.algorithm = algorithm;
Expand Down Expand Up @@ -190,6 +194,7 @@ public MLModel(String name,
this.isHidden = isHidden;
this.connector = connector;
this.connectorId = connectorId;
this.guardrails = guardrails;
}

public MLModel(StreamInput input) throws IOException {
Expand Down Expand Up @@ -243,6 +248,9 @@ public MLModel(StreamInput input) throws IOException {
connector = Connector.fromStream(input);
}
connectorId = input.readOptionalString();
if (input.readBoolean()) {
jngz-es marked this conversation as resolved.
Show resolved Hide resolved
this.guardrails = new Guardrails(input);
}
}
}

Expand Down Expand Up @@ -308,6 +316,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeBoolean(false);
}
out.writeOptionalString(connectorId);
if (guardrails != null) {
out.writeBoolean(true);
guardrails.writeTo(out);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down Expand Up @@ -406,6 +420,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
if (connectorId != null) {
builder.field(CONNECTOR_ID_FIELD, connectorId);
}
if (guardrails != null) {
builder.field(GUARDRAILS_FIELD, guardrails);
}
builder.endObject();
return builder;
}
Expand Down Expand Up @@ -448,6 +465,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
boolean isHidden = false;
Connector connector = null;
String connectorId = null;
Guardrails guardrails = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand Down Expand Up @@ -571,6 +589,9 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
case LAST_UNDEPLOYED_TIME_FIELD:
lastUndeployedTime = Instant.ofEpochMilli(parser.longValue());
break;
case GUARDRAILS_FIELD:
guardrails = Guardrails.parse(parser);
break;
default:
parser.skipChildren();
break;
Expand Down Expand Up @@ -608,6 +629,7 @@ public static MLModel parse(XContentParser parser, String algorithmName) throws
.isHidden(isHidden)
.connector(connector)
.connectorId(connectorId)
.guardrails(guardrails)
.build();
}

Expand Down
105 changes: 105 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/model/Guardrail.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.model;
ylwu-amzn marked this conversation as resolved.
Show resolved Hide resolved
ylwu-amzn marked this conversation as resolved.
Show resolved Hide resolved

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@EqualsAndHashCode
@Getter
public class Guardrail implements ToXContentObject {
public static final String STOP_WORDS_FIELD = "stop_words";
public static final String REGEX_FIELD = "regex";
Comment on lines +25 to +27
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Guardrail should be a superclass, with stop words and regex each implemented as a subclass representing one type of guardrail. That would make it easier to integrate new types of guardrails and to build a framework around them.


private List<StopWords> stopWords;
private String[] regex;

/**
 * Creates a guardrail from its stop-word rules and regex strings.
 * Both arguments are stored as given, without copying or validation.
 *
 * @param stopWords stop-word rules; may be null
 * @param regex     regex strings; may be null
 */
@Builder(toBuilder = true)
public Guardrail(List<StopWords> stopWords, String[] regex) {
    this.regex = regex;
    this.stopWords = stopWords;
}

/**
 * Reads a guardrail from a transport stream. The read order must mirror
 * {@link #writeTo(StreamOutput)}: a presence flag, then the stop-word entries
 * (count followed by each entry), then the optional regex array.
 *
 * @param input stream positioned at the start of a serialized guardrail
 * @throws IOException if the stream cannot be read
 */
public Guardrail(StreamInput input) throws IOException {
    if (input.readBoolean()) {
        stopWords = new ArrayList<>();
        int size = input.readInt();
        for (int i = 0; i < size; i++) {
            stopWords.add(new StopWords(input));
        }
    }
    // Optional variant: regex is null when the guardrail only defines stop words.
    regex = input.readOptionalStringArray();
}

/**
 * Writes this guardrail to a transport stream.
 * A null or empty stop-word list is written as "absent", so it is read back as null.
 *
 * @param out stream to write to
 * @throws IOException if the stream cannot be written
 */
public void writeTo(StreamOutput out) throws IOException {
    if (stopWords != null && !stopWords.isEmpty()) {
        out.writeBoolean(true);
        out.writeInt(stopWords.size());
        for (StopWords e : stopWords) {
            e.writeTo(out);
        }
    } else {
        out.writeBoolean(false);
    }
    // writeStringArray dereferences the array and would throw a NullPointerException
    // for a guardrail without regex; the optional variant is null-safe.
    out.writeOptionalStringArray(regex);
}

/**
 * Renders this guardrail as an object. The stop-words field is emitted only
 * when the list is non-null and non-empty; the regex field only when non-null.
 *
 * @param builder builder to write into
 * @param params  serialization parameters (not consulted here)
 * @return the same builder, after the object has been closed
 * @throws IOException if writing to the builder fails
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    final boolean hasStopWords = stopWords != null && !stopWords.isEmpty();
    final boolean hasRegex = regex != null;

    builder.startObject();
    if (hasStopWords) {
        builder.field(STOP_WORDS_FIELD, stopWords);
    }
    if (hasRegex) {
        builder.field(REGEX_FIELD, regex);
    }
    return builder.endObject();
}

/**
 * Parses a guardrail from the object the parser is currently positioned on.
 * Recognizes the stop-words array and the regex array; any other field is skipped.
 * A field that does not appear leaves the corresponding value null.
 *
 * @param parser parser whose current token is the object's START_OBJECT
 * @return the parsed guardrail
 * @throws IOException if the underlying content cannot be read
 */
public static Guardrail parse(XContentParser parser) throws IOException {
    List<StopWords> parsedStopWords = null;
    String[] parsedRegex = null;

    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    for (XContentParser.Token token = parser.nextToken();
         token != XContentParser.Token.END_OBJECT;
         token = parser.nextToken()) {
        final String field = parser.currentName();
        // Advance from the field name to the field's value.
        parser.nextToken();

        if (field.equals(STOP_WORDS_FIELD)) {
            parsedStopWords = new ArrayList<>();
            ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
            while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                parsedStopWords.add(StopWords.parse(parser));
            }
        } else if (field.equals(REGEX_FIELD)) {
            parsedRegex = parser.list().toArray(new String[0]);
        } else {
            parser.skipChildren();
        }
    }

    return Guardrail.builder()
        .stopWords(parsedStopWords)
        .regex(parsedRegex)
        .build();
}
}
125 changes: 125 additions & 0 deletions common/src/main/java/org/opensearch/ml/common/model/Guardrails.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.model;

import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;

import java.io.IOException;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;

@EqualsAndHashCode
@Getter
public class Guardrails implements ToXContentObject {
public static final String TYPE_FIELD = "type";
public static final String ENGLISH_DETECTION_ENABLED_FIELD = "english_detection_enabled";
jngz-es marked this conversation as resolved.
Show resolved Hide resolved
public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail";
public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail";

private String type;
jngz-es marked this conversation as resolved.
Show resolved Hide resolved
private Boolean engDetectionEnabled;
private Guardrail inputGuardrail;
private Guardrail outputGuardrail;

/**
 * Creates the guardrails configuration. All arguments are stored as given.
 *
 * @param type                value of the "type" field; may be null
 * @param engDetectionEnabled value of the "english_detection_enabled" field; may be null
 * @param inputGuardrail      guardrail stored under "input_guardrail"; may be null
 * @param outputGuardrail     guardrail stored under "output_guardrail"; may be null
 */
@Builder(toBuilder = true)
public Guardrails(String type, Boolean engDetectionEnabled, Guardrail inputGuardrail, Guardrail outputGuardrail) {
    this.outputGuardrail = outputGuardrail;
    this.inputGuardrail = inputGuardrail;
    this.engDetectionEnabled = engDetectionEnabled;
    this.type = type;
}

/**
 * Reads the guardrails configuration from a transport stream. The read order
 * must mirror {@link #writeTo(StreamOutput)}: optional type, optional
 * English-detection flag, then each guardrail preceded by a presence flag.
 *
 * @param input stream positioned at the start of a serialized configuration
 * @throws IOException if the stream cannot be read
 */
public Guardrails(StreamInput input) throws IOException {
    // Both values may legitimately be null, so they use the optional wire encoding.
    type = input.readOptionalString();
    engDetectionEnabled = input.readOptionalBoolean();
    if (input.readBoolean()) {
        inputGuardrail = new Guardrail(input);
    }
    if (input.readBoolean()) {
        outputGuardrail = new Guardrail(input);
    }
}

/**
 * Writes the guardrails configuration to a transport stream.
 *
 * @param out stream to write to
 * @throws IOException if the stream cannot be written
 */
public void writeTo(StreamOutput out) throws IOException {
    // writeString(null) fails and writeBoolean(Boolean) would auto-unbox a null;
    // the optional variants are null-safe.
    out.writeOptionalString(type);
    out.writeOptionalBoolean(engDetectionEnabled);
    if (inputGuardrail != null) {
        out.writeBoolean(true);
        inputGuardrail.writeTo(out);
    } else {
        out.writeBoolean(false);
    }
    if (outputGuardrail != null) {
        out.writeBoolean(true);
        outputGuardrail.writeTo(out);
    } else {
        out.writeBoolean(false);
    }
}

/**
 * Renders the guardrails configuration as an object. Each of the four fields
 * is emitted only when its value is non-null, in the order type,
 * English-detection flag, input guardrail, output guardrail.
 *
 * @param builder builder to write into
 * @param params  serialization parameters (not consulted here)
 * @return the same builder, after the object has been closed
 * @throws IOException if writing to the builder fails
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    builder.startObject();

    if (null != type) {
        builder.field(TYPE_FIELD, type);
    }

    if (null != engDetectionEnabled) {
        builder.field(ENGLISH_DETECTION_ENABLED_FIELD, engDetectionEnabled);
    }

    if (null != inputGuardrail) {
        builder.field(INPUT_GUARDRAIL_FIELD, inputGuardrail);
    }

    if (null != outputGuardrail) {
        builder.field(OUTPUT_GUARDRAIL_FIELD, outputGuardrail);
    }

    return builder.endObject();
}

/**
 * Parses the guardrails configuration from the object the parser is currently
 * positioned on. Recognizes the type, English-detection, input-guardrail and
 * output-guardrail fields; any other field is skipped. A field that does not
 * appear leaves the corresponding value null.
 *
 * @param parser parser whose current token is the object's START_OBJECT
 * @return the parsed configuration
 * @throws IOException if the underlying content cannot be read
 */
public static Guardrails parse(XContentParser parser) throws IOException {
    String parsedType = null;
    Boolean parsedEngDetection = null;
    Guardrail parsedInput = null;
    Guardrail parsedOutput = null;

    ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
    for (XContentParser.Token token = parser.nextToken();
         token != XContentParser.Token.END_OBJECT;
         token = parser.nextToken()) {
        final String field = parser.currentName();
        // Advance from the field name to the field's value.
        parser.nextToken();

        if (field.equals(TYPE_FIELD)) {
            parsedType = parser.text();
        } else if (field.equals(ENGLISH_DETECTION_ENABLED_FIELD)) {
            parsedEngDetection = parser.booleanValue();
        } else if (field.equals(INPUT_GUARDRAIL_FIELD)) {
            parsedInput = Guardrail.parse(parser);
        } else if (field.equals(OUTPUT_GUARDRAIL_FIELD)) {
            parsedOutput = Guardrail.parse(parser);
        } else {
            parser.skipChildren();
        }
    }

    return Guardrails.builder()
        .type(parsedType)
        .engDetectionEnabled(parsedEngDetection)
        .inputGuardrail(parsedInput)
        .outputGuardrail(parsedOutput)
        .build();
}
}
Loading
Loading