Skip to content

Commit

Permalink
Add vector embedding function.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Hale committed Sep 27, 2024
1 parent 45409ef commit 289e72d
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ public final class HALYARD implements Vocabulary {
public final static IRI DATASET_IRI_FUNCTION = SVF.createIRI(NAMESPACE, "datasetIRI");
public final static IRI PARALLEL_SPLIT_FUNCTION = SVF.createIRI(NAMESPACE, "forkAndFilterBy");

public final static IRI VECTOR_EMBEDDING_FUNCTION = SVF.createIRI(NAMESPACE, "vectorEmbedding");

public final static IRI DATA_URL_FUNCTION = SVF.createIRI(NAMESPACE, "dataURL");

public final static IRI GET_FUNCTION = SVF.createIRI(NAMESPACE, "get");
Expand Down
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
<httpclient.version>4.5.14</httpclient.version>
<saxon.version>11.5</saxon.version>
<xmlresolver.version>4.6.4</xmlresolver.version>
<langchain4j.version>0.35.0</langchain4j.version>
<slf4j.version>1.7.36</slf4j.version>
<logback.version>1.2.12</logback.version>
<lz4.version>1.8.0</lz4.version>
Expand Down
7 changes: 6 additions & 1 deletion sail/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,12 @@
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-ollama</artifactId>
<version>0.35.0</version>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-local-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>

<dependency>
Expand Down
6 changes: 5 additions & 1 deletion sail/src/main/java/com/msd/gin/halyard/sail/HBaseSail.java
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,11 @@ public void init() throws SailException {
Map<String, String> qhConfig = conf.getPropsWithPrefix(EvaluationConfig.QUERY_HELPERS_PREFIX);
for (Map.Entry<String, String> qhEntry : qhConfig.entrySet()) {
QueryHelperProvider<?> qhp = queryHelperProviderRegistry.get(qhEntry.getKey()).orElseThrow(() -> new SailException(String.format("No %s registered for %s", QueryHelperProvider.class.getName(), qhEntry.getKey())));
queryHelpers.put(qhp.getQueryHelperClass(), qhp.createQueryHelper(conf.getPropsWithPrefix(qhEntry.getKey() + ".")));
try {
queryHelpers.put(qhp.getQueryHelperClass(), qhp.createQueryHelper(conf.getPropsWithPrefix(qhEntry.getKey() + ".")));
} catch (Exception e) {
throw new SailException(e);
}
}

if (esSettings != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
public interface QueryHelperProvider<T> {
Class<T> getQueryHelperClass();

T createQueryHelper(Map<String, String> config);
T createQueryHelper(Map<String, String> config) throws Exception;
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

import com.msd.gin.halyard.sail.QueryHelperProvider;

import java.lang.reflect.Method;
import java.time.Duration;
import java.util.Collections;
import java.util.Map;

import org.kohsuke.MetaInfServices;

import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.ollama.OllamaEmbeddingModel;

@MetaInfServices(QueryHelperProvider.class)
public class EmbeddingModelQueryHelperProvider implements QueryHelperProvider<EmbeddingModel> {
Expand All @@ -20,14 +19,43 @@ public Class<EmbeddingModel> getQueryHelperClass() {
}

@Override
public EmbeddingModel createQueryHelper(Map<String, String> config) {
String url = config.get("url");
String model = config.get("model");
Duration timeoutMillis = Duration.ofMillis(Long.parseLong(config.getOrDefault("timeoutMillis", "60000")));
int maxRetries = Integer.parseInt(config.getOrDefault("maxRetries", "3"));
boolean logRequests = Boolean.parseBoolean(config.getOrDefault("log.requests", "false"));
boolean logResponses = Boolean.parseBoolean(config.getOrDefault("log.responses", "false"));
return new OllamaEmbeddingModel(url, model, timeoutMillis, maxRetries, logRequests, logResponses, Collections.emptyMap());
public EmbeddingModel createQueryHelper(Map<String, String> config) throws Exception {
String modelClassName = config.get("model.class");
Class<?> modelClass = Class.forName(modelClassName);
Object builder = modelClass.getMethod("builder").invoke(null);
Class<?> builderClass = builder.getClass();
Method buildMethod = builderClass.getMethod("build");
for (Method m : builderClass.getMethods()) {
if (m.getParameterCount() == 1) {
String key = m.getName();
String value = config.get(key);
if (value != null) {
m.invoke(builder, convert(value, m.getParameterTypes()[0]));
}
}
}
return (EmbeddingModel) buildMethod.invoke(builder);
}

private static Object convert(String value, Class<?> targetType) {
if (targetType == String.class) {
return value;
} else if (targetType == Integer.class) {
return Integer.parseInt(value);
} else if (targetType == Boolean.class) {
return Boolean.parseBoolean(value);
} else if (targetType == Duration.class) {
if (value.endsWith("ms")) {
return Duration.ofMillis(Long.parseLong(value.substring(0, value.length() - "ms".length())));
} else if (value.endsWith("s")) {
return Duration.ofSeconds(Long.parseLong(value.substring(0, value.length() - "s".length())));
} else if (value.endsWith("min")) {
return Duration.ofMinutes(Long.parseLong(value.substring(0, value.length() - "min".length())));
} else {
throw new IllegalArgumentException(String.format("Unsupported duration value: %s", value));
}
} else {
throw new IllegalArgumentException(String.format("Unsupported type: %s", targetType.getName()));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.msd.gin.halyard.sail.model.embedding.function;

import com.msd.gin.halyard.model.ArrayLiteral;
import com.msd.gin.halyard.model.vocabulary.HALYARD;
import com.msd.gin.halyard.query.algebra.evaluation.ExtendedTripleSource;

import org.eclipse.rdf4j.model.Literal;
import org.eclipse.rdf4j.model.Value;
import org.eclipse.rdf4j.model.ValueFactory;
import org.eclipse.rdf4j.query.QueryEvaluationException;
import org.eclipse.rdf4j.query.algebra.evaluation.TripleSource;
import org.eclipse.rdf4j.query.algebra.evaluation.ValueExprEvaluationException;
import org.eclipse.rdf4j.query.algebra.evaluation.function.Function;
import org.kohsuke.MetaInfServices;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;

@MetaInfServices(Function.class)
public class VectorEmbedding implements Function {

@Override
public String getURI() {
return HALYARD.VECTOR_EMBEDDING_FUNCTION.stringValue();
}

@Override
public Value evaluate(ValueFactory valueFactory, Value... args) throws ValueExprEvaluationException {
throw new UnsupportedOperationException();
}

@Override
public Value evaluate(TripleSource ts, Value... args) throws ValueExprEvaluationException {
if (args.length != 1) {
throw new QueryEvaluationException("Missing arguments");
}

if (!args[0].isLiteral()) {
throw new QueryEvaluationException(String.format("Non-literal value: %s", args[0]));
}
Literal l = (Literal) args[0];

ExtendedTripleSource extTs = (ExtendedTripleSource) ts;
Response<Embedding> resp = extTs.getQueryHelper(EmbeddingModel.class).embed(l.getLabel());
float[] vec = resp.content().vector();
Float[] arr = new Float[vec.length];
for (int i = 0; i < vec.length; i++) {
arr[i] = Float.valueOf(vec[i]);
}
return new ArrayLiteral((Object[]) arr);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.msd.gin.halyard.sail.model.embedding;

import java.util.HashMap;
import java.util.Map;

import org.junit.jupiter.api.Test;

import dev.langchain4j.model.localai.LocalAiEmbeddingModel;

public class EmbeddingModelQueryHelperProviderTest {
@Test
public void testLocalAI() throws Exception {
EmbeddingModelQueryHelperProvider provider = new EmbeddingModelQueryHelperProvider();
Map<String, String> config = new HashMap<>();
config.put("model.class", LocalAiEmbeddingModel.class.getName());
config.put("baseUrl", "http://localhost");
config.put("modelName", "llama3");
provider.createQueryHelper(config);
}
}

0 comments on commit 289e72d

Please sign in to comment.