Skip to content

Commit

Permalink
Testing timeout from params
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-buttner committed Apr 22, 2024
1 parent 856aa84 commit dcc0b2a
Showing 1 changed file with 33 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.test.rest.RestActionTestCase;
Expand All @@ -17,19 +18,22 @@
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults;
import org.junit.Before;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;

public class RestInferenceActionTests extends RestActionTestCase {

@Before
public void setUpAction() {
controller().registerHandler(new RestInferenceAction());
}

public void test() {
public void testUsesDefaultTimeout() {
SetOnce<Boolean> executeCalled = new SetOnce<>();
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
assertThat(actionRequest, instanceOf(InferenceAction.Request.class));
Expand All @@ -38,16 +42,41 @@ public void test() {
assertThat(request.getInferenceTimeout(), is(InferenceAction.Request.DEFAULT_TIMEOUT));

executeCalled.set(true);
return new InferenceAction.Response(
new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) -1))))
);
return createResponse();
}));

RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath("_inference/test")
.withContent(new BytesArray("{}"), XContentType.JSON)
.build();
dispatchRequest(inferenceRequest);
assertThat(executeCalled.get(), equalTo(true));
}

public void testUses3SecondTimeoutFromParams() {
SetOnce<Boolean> executeCalled = new SetOnce<>();
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
assertThat(actionRequest, instanceOf(InferenceAction.Request.class));

var request = (InferenceAction.Request) actionRequest;
assertThat(request.getInferenceTimeout(), is(TimeValue.timeValueSeconds(3)));

executeCalled.set(true);
return createResponse();
}));

RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST)
.withPath("_inference/test")
.withParams(new HashMap<>(Map.of("timeout", "3s")))
.withContent(new BytesArray("{}"), XContentType.JSON)
.build();
dispatchRequest(inferenceRequest);
assertThat(executeCalled.get(), equalTo(true));
}

private static InferenceAction.Response createResponse() {
return new InferenceAction.Response(
new TextEmbeddingByteResults(List.of(new TextEmbeddingByteResults.Embedding(List.of((byte) -1))))
);
}
}

0 comments on commit dcc0b2a

Please sign in to comment.