Skip to content

Commit

Permalink
Allow users to specify similarity
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-buttner committed Mar 19, 2024
1 parent edc45f5 commit e7cb97d
Show file tree
Hide file tree
Showing 17 changed files with 403 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public CohereEmbeddingsAction(Sender sender, CohereEmbeddingsModel model) {
Objects.requireNonNull(model);
this.sender = Objects.requireNonNull(sender);
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(
model.getServiceSettings().getCommonSettings().getUri(),
model.getServiceSettings().getCommonSettings().uri(),
"Cohere embeddings"
);
requestCreator = new CohereEmbeddingsExecutableRequestCreator(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private static ResponseHandler createEmbeddingsHandler() {

public CohereEmbeddingsExecutableRequestCreator(CohereEmbeddingsModel model) {
this.model = Objects.requireNonNull(model);
account = new CohereAccount(this.model.getServiceSettings().getCommonSettings().getUri(), this.model.getSecretSettings().apiKey());
account = new CohereAccount(this.model.getServiceSettings().getCommonSettings().uri(), this.model.getSecretSettings().apiKey());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public CohereEmbeddingsRequest(CohereAccount account, List<String> input, Cohere
this.input = Objects.requireNonNull(input);
uri = buildUri(this.account.url(), "Cohere", CohereEmbeddingsRequest::buildDefaultUri);
taskSettings = embeddingsModel.getTaskSettings();
model = embeddingsModel.getServiceSettings().getCommonSettings().getModelId();
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
inferenceEntityId = embeddingsModel.getInferenceEntityId();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,16 @@ public void checkModelConfig(Model model, ActionListener<Model> listener) {
}

private CohereEmbeddingsModel updateModelWithEmbeddingDetails(CohereEmbeddingsModel model, int embeddingSize) {
var similarityFromModel = model.getServiceSettings().similarity();
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

CohereEmbeddingsServiceSettings serviceSettings = new CohereEmbeddingsServiceSettings(
new CohereServiceSettings(
model.getServiceSettings().getCommonSettings().getUri(),
SimilarityMeasure.DOT_PRODUCT,
model.getServiceSettings().getCommonSettings().uri(),
similarityToUse,
embeddingSize,
model.getServiceSettings().getCommonSettings().getMaxInputTokens(),
model.getServiceSettings().getCommonSettings().getModelId()
model.getServiceSettings().getCommonSettings().maxInputTokens(),
model.getServiceSettings().getCommonSettings().modelId()
),
model.getServiceSettings().getEmbeddingType()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ public static CohereServiceSettings fromMap(Map<String, Object> map, Configurati
throw validationException;
}

return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, getModelId(oldModelId, modelId));
return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, modelId(oldModelId, modelId));
}

private static String getModelId(@Nullable String model, @Nullable String modelId) {
private static String modelId(@Nullable String model, @Nullable String modelId) {
return modelId != null ? modelId : model;
}

Expand Down Expand Up @@ -110,23 +110,25 @@ public CohereServiceSettings(StreamInput in) throws IOException {
modelId = in.readOptionalString();
}

public URI getUri() {
public URI uri() {
return uri;
}

public SimilarityMeasure getSimilarity() {
@Override
public SimilarityMeasure similarity() {
return similarity;
}

public Integer getDimensions() {
@Override
public Integer dimensions() {
return dimensions;
}

public Integer getMaxInputTokens() {
public Integer maxInputTokens() {
return maxInputTokens;
}

public String getModelId() {
public String modelId() {
return modelId;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.SimilarityMeasure;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
Expand Down Expand Up @@ -93,6 +94,16 @@ public CohereServiceSettings getCommonSettings() {
return commonSettings;
}

/**
 * Returns the similarity measure by delegating to the wrapped common settings.
 *
 * @return whatever {@code commonSettings.similarity()} returns
 */
@Override
public SimilarityMeasure similarity() {
return commonSettings.similarity();
}

/**
 * Returns the embedding dimensions by delegating to the wrapped common settings.
 *
 * @return whatever {@code commonSettings.dimensions()} returns
 */
@Override
public Integer dimensions() {
return commonSettings.dimensions();
}

/**
 * Returns the embedding type held by these settings.
 *
 * @return the value of the {@code embeddingType} field
 */
public CohereEmbeddingType getEmbeddingType() {
return embeddingType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public void checkModelConfig(Model model, ActionListener<Model> listener) {
private static HuggingFaceEmbeddingsModel updateModelWithEmbeddingDetails(HuggingFaceEmbeddingsModel model, int embeddingSize) {
var serviceSettings = new HuggingFaceServiceSettings(
model.getServiceSettings().uri(),
null, // Similarity measure is unknown
model.getServiceSettings().similarity(), // we don't know the similarity but use whatever the user specified
embeddingSize,
model.getTokenLimit()
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,12 @@ public URI uri() {
return uri;
}

/**
 * Returns the similarity measure stored in these settings.
 *
 * @return the value of the {@code similarity} field
 */
@Override
public SimilarityMeasure similarity() {
return similarity;
}

/**
 * Returns the dimensions value stored in these settings.
 *
 * @return the value of the {@code dimensions} field
 */
@Override
public Integer dimensions() {
return dimensions;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,14 @@ private OpenAiEmbeddingsModel updateModelWithEmbeddingDetails(OpenAiEmbeddingsMo
);
}

var similarityFromModel = model.getServiceSettings().similarity();
var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel;

OpenAiEmbeddingsServiceSettings serviceSettings = new OpenAiEmbeddingsServiceSettings(
model.getServiceSettings().modelId(),
model.getServiceSettings().uri(),
model.getServiceSettings().organizationId(),
SimilarityMeasure.DOT_PRODUCT,
similarityToUse,
embeddingSize,
model.getServiceSettings().maxInputTokens(),
model.getServiceSettings().dimensionsSetByUser()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,12 @@ public String organizationId() {
return organizationId;
}

/**
 * Returns the similarity measure stored in these settings.
 *
 * @return the value of the {@code similarity} field
 */
@Override
public SimilarityMeasure similarity() {
return similarity;
}

/**
 * Returns the dimensions value stored in these settings.
 *
 * @return the value of the {@code dimensions} field
 */
@Override
public Integer dimensions() {
return dimensions;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ public void testCreateRequest_TruncateNone() throws URISyntaxException, IOExcept
}

public static CohereEmbeddingsRequest createRequest(List<String> input, CohereEmbeddingsModel model) throws URISyntaxException {
var account = new CohereAccount(model.getServiceSettings().getCommonSettings().getUri(), model.getSecretSettings().apiKey());
var account = new CohereAccount(model.getServiceSettings().getCommonSettings().uri(), model.getSecretSettings().apiKey());
return new CohereEmbeddingsRequest(account, input, model);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public void testFromMap_PrefersModelId_OverModel() {

public void testFromMap_MissingUrl_DoesNotThrowException() {
var serviceSettings = CohereServiceSettings.fromMap(new HashMap<>(Map.of()), ConfigurationParseContext.PERSISTENT);
assertNull(serviceSettings.getUri());
assertNull(serviceSettings.uri());
}

public void testFromMap_EmptyUrl_ThrowsError() {
Expand Down
Loading

0 comments on commit e7cb97d

Please sign in to comment.