Add lang model deps to map-reduce.
Mark Hale committed Oct 2, 2024
1 parent ce8f2f4 commit fc5018e
Showing 12 changed files with 387 additions and 397 deletions.
5 changes: 3 additions & 2 deletions sail/pom.xml
@@ -61,13 +61,14 @@

<dependency>
<groupId>dev.langchain4j</groupId>
-<artifactId>langchain4j-ollama</artifactId>
+<artifactId>langchain4j-core</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
-<artifactId>langchain4j-local-ai</artifactId>
+<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<version>${langchain4j.version}</version>
+<scope>test</scope>
</dependency>

<dependency>
@@ -111,7 +111,7 @@ public final QueryPreparer newQueryPreparer() {
public final <T> T getQueryHelper(Class<T> qhType) {
Object qh = queryHelpers.get(qhType);
if (qh == null) {
-throw new QueryEvaluationException(String.format("%s is not available", qhType.getName()));
+throw new QueryEvaluationException(String.format("Query helper %s is not registered", qhType.getName()));
}
return qhType.cast(qh);
}
@@ -5,16 +5,15 @@

import org.junit.jupiter.api.Test;

-import dev.langchain4j.model.localai.LocalAiEmbeddingModel;
+import dev.langchain4j.model.embedding.EmbeddingModel;

public class EmbeddingModelQueryHelperProviderTest {
@Test
-public void testLocalAI() throws Exception {
+public void testProvider() throws Exception {
EmbeddingModelQueryHelperProvider provider = new EmbeddingModelQueryHelperProvider();
Map<String, String> config = new HashMap<>();
config.put("model.class", LocalAiEmbeddingModel.class.getName());
config.put("baseUrl", "http://localhost");
config.put("modelName", "llama3");
provider.createQueryHelper(config);
config.put("model.class", dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel.class.getName());
EmbeddingModel model = provider.createQueryHelper(config);
model.embed("foobar");
}
}
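Note: the rewritten test swaps the LocalAI-backed model, which needed a running server, for the quantized all-MiniLM-L6-v2 model, which runs entirely in-process via ONNX Runtime. A minimal standalone sketch of the same model, assuming langchain4j-embeddings-all-minilm-l6-v2-q is on the classpath (class and package names as in the diff above; the printed dimension is illustrative):

    import dev.langchain4j.data.embedding.Embedding;
    import dev.langchain4j.model.embedding.EmbeddingModel;
    import dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel;

    public class EmbeddingSmokeTest {
        public static void main(String[] args) {
            // Loads the ONNX model bundled in the jar; no external service required.
            EmbeddingModel model = new AllMiniLmL6V2QuantizedEmbeddingModel();
            // embed() wraps the result in a Response; content() unwraps the Embedding.
            Embedding embedding = model.embed("foobar").content();
            System.out.println(embedding.dimension()); // all-MiniLM-L6-v2 yields 384-dimensional vectors
        }
    }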
45 changes: 28 additions & 17 deletions sdk/pom.xml
@@ -1,5 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<artifactId>halyard-sdk</artifactId>
<packaging>pom</packaging>
@@ -28,21 +30,22 @@
</exclusion>
</exclusions>
</dependency>
-<dependency>
-<groupId>org.eclipse.rdf4j</groupId>
-<artifactId>rdf4j-console</artifactId>
-<version>${rdf4j.version}</version>
-<exclusions>
+
+<dependency>
+<groupId>org.eclipse.rdf4j</groupId>
+<artifactId>rdf4j-console</artifactId>
+<version>${rdf4j.version}</version>
+<exclusions>
<exclusion>
<groupId>org.eclipse.rdf4j</groupId>
<artifactId>rdf4j-spin</artifactId>
</exclusion>
-<exclusion>
-<groupId>ch.qos.logback</groupId>
-<artifactId>logback-classic</artifactId>
-</exclusion>
-</exclusions>
-</dependency>
+<exclusion>
+<groupId>ch.qos.logback</groupId>
+<artifactId>logback-classic</artifactId>
+</exclusion>
+</exclusions>
+</dependency>
<dependency>
<!-- for spark -->
<groupId>org.slf4j</groupId>
@@ -87,15 +90,23 @@
<phase>process-classes</phase>
<configuration>
<target>
-<zip destfile="${project.build.directory}/rdf4j-repository-api-${rdf4j.version}.jar" update="true">
+<zip
+destfile="${project.build.directory}/rdf4j-repository-api-${rdf4j.version}.jar"
+update="true">
<fileset dir="src/main/patches/rdf4j-client" />
</zip>
-<get src="https://github.com/pulquero/hbase/releases/download/rel%2F${hbase.version}%2B${hadoop.version}/hbase-${hbase.version}-bin.tar.gz" dest="${project.build.directory}/hbase.tar.gz" skipexisting="true"/>
-<untar src="${project.build.directory}/hbase.tar.gz" dest="${project.build.directory}/hbase-libs" compression="gzip">
+<get
+src="https://github.com/pulquero/hbase/releases/download/rel%2F${hbase.version}%2B${hadoop.version}/hbase-${hbase.version}-bin.tar.gz"
+dest="${project.build.directory}/hbase.tar.gz"
+skipexisting="true" />
+<untar src="${project.build.directory}/hbase.tar.gz"
+dest="${project.build.directory}/hbase-libs"
+compression="gzip">
<patternset>
-<include name="hbase-${hbase.version}/lib/hbase-*-${hbase.version}.jar"/>
+<include
+name="hbase-${hbase.version}/lib/hbase-*-${hbase.version}.jar" />
</patternset>
-<mapper type="flatten"/>
+<mapper type="flatten" />
</untar>
</target>
</configuration>
17 changes: 17 additions & 0 deletions tools/pom.xml
@@ -86,6 +86,23 @@
<version>3.1</version>
<scope>runtime</scope>
</dependency>

+<dependency>
+<groupId>dev.langchain4j</groupId>
+<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
+<version>${langchain4j.version}</version>
+</dependency>
+<dependency>
+<groupId>dev.langchain4j</groupId>
+<artifactId>langchain4j-ollama</artifactId>
+<version>${langchain4j.version}</version>
+</dependency>
+<dependency>
+<groupId>dev.langchain4j</groupId>
+<artifactId>langchain4j-local-ai</artifactId>
+<version>${langchain4j.version}</version>
+</dependency>

<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>halyard-common</artifactId>
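Note: with the Ollama and LocalAI artifacts now on the tools classpath alongside the in-process model, the provider's "model.class" setting can presumably select any of the three backends at runtime. A sketch reusing the config keys from the deleted LocalAI test above; whether the provider forwards baseUrl/modelName to the Ollama builder is an assumption carried over from that test, not verified:

    import java.util.HashMap;
    import java.util.Map;

    import dev.langchain4j.model.embedding.EmbeddingModel;
    import dev.langchain4j.model.ollama.OllamaEmbeddingModel;

    public class OllamaBackendSketch {
        public static void main(String[] args) throws Exception {
            // EmbeddingModelQueryHelperProvider is the sail-module class exercised in
            // the test above; its package (and hence import) is not shown in this diff.
            EmbeddingModelQueryHelperProvider provider = new EmbeddingModelQueryHelperProvider();
            Map<String, String> config = new HashMap<>();
            config.put("model.class", OllamaEmbeddingModel.class.getName());
            config.put("baseUrl", "http://localhost:11434"); // assumes a local Ollama server
            config.put("modelName", "llama3");
            EmbeddingModel model = provider.createQueryHelper(config);
            model.embed("foobar");
        }
    }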
@@ -22,6 +22,7 @@
import com.msd.gin.halyard.common.KeyspaceConnection;
import com.msd.gin.halyard.common.RDFFactory;
import com.msd.gin.halyard.common.StatementIndices;
+import com.msd.gin.halyard.rio.HRDFParser;
import com.msd.gin.halyard.util.Version;

import java.io.IOException;
@@ -51,6 +52,7 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hbase.TableName;
+import org.apache.hadoop.hbase.mapreduce.TableMapReduceUtil;
import org.apache.hadoop.hbase.mapreduce.TableMapper;
import org.apache.hadoop.hbase.regionserver.BloomType;
import org.apache.hadoop.hbase.tool.BulkLoadHFiles;
@@ -64,9 +66,18 @@
import org.eclipse.rdf4j.query.BindingSet;
import org.eclipse.rdf4j.query.algebra.evaluation.QueryBindingSet;
import org.eclipse.rdf4j.rio.helpers.NTriplesUtil;
+import org.eclipse.rdf4j.rio.nquads.NQuadsParserFactory;
+import org.eclipse.rdf4j.rio.ntriples.NTriplesParserFactory;
+import org.eclipse.rdf4j.rio.rdfjson.RDFJSONParserFactory;
+import org.eclipse.rdf4j.rio.rdfxml.RDFXMLParserFactory;
+import org.eclipse.rdf4j.rio.trig.TriGParserFactory;
+import org.eclipse.rdf4j.rio.trix.TriXParserFactory;
+import org.eclipse.rdf4j.rio.turtle.TurtleParserFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

+import dev.langchain4j.model.embedding.EmbeddingModel;

/**
*
* @author Adam Sotona (MSD)
@@ -276,11 +287,36 @@ protected final List<Option> getRequiredOptions() {
return requiredOptions;
}

-protected static final boolean isDryRun(Configuration conf) {
+protected static boolean isDryRun(Configuration conf) {
return conf.getBoolean(DRY_RUN_PROPERTY, false);
}

-protected static final void bulkLoad(Job job, TableName tableName, Path workDir) throws IOException {
+protected static void addRioDependencies(Configuration conf) throws IOException {
+TableMapReduceUtil.addDependencyJarsForClasses(conf,
+TurtleParserFactory.class,
+TriXParserFactory.class,
+TriGParserFactory.class,
+NTriplesParserFactory.class,
+NQuadsParserFactory.class,
+RDFXMLParserFactory.class,
+RDFJSONParserFactory.class,
+HRDFParser.Factory.class
+);
+}
+
+protected static void addLangModelDependencies(Configuration conf) throws IOException {
+TableMapReduceUtil.addDependencyJarsForClasses(conf,
+EmbeddingModel.class,
+dev.langchain4j.model.embedding.onnx.AbstractInProcessEmbeddingModel.class,
+dev.langchain4j.model.embedding.onnx.allminilml6v2q.AllMiniLmL6V2QuantizedEmbeddingModel.class,
+ai.onnxruntime.OrtEnvironment.class,
+dev.langchain4j.model.localai.LocalAiEmbeddingModel.class,
+dev.ai4j.openai4j.OpenAiClient.class,
+dev.langchain4j.model.ollama.OllamaEmbeddingModel.class
+);
+}
+
+protected static void bulkLoad(Job job, TableName tableName, Path workDir) throws IOException {
// ensure job configuration is used
Configuration conf = job.getConfiguration();
if (isDryRun(conf)) {
@@ -292,7 +328,7 @@ protected static final void bulkLoad(Job job, TableName tableName, Path workDir)
}
}

-protected static final void addBloomFilterConfig(Configuration conf, TableName tableName) {
+protected static void addBloomFilterConfig(Configuration conf, TableName tableName) {
byte[] tableAndFamily = HalyardTableUtils.getTableNameSuffixedWithFamily(tableName.toBytes());
Map<byte[], String> bloomTypeMap = createFamilyConfValueMap(conf, "hbase.hfileoutputformat.families.bloomtype");
String bloomType = bloomTypeMap.get(tableAndFamily);
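Note: both new helpers funnel into TableMapReduceUtil.addDependencyJarsForClasses, which resolves the jar containing each listed class and appends it to the configuration's "tmpjars" entry, so those jars are shipped to the cluster with any job submitted under that configuration; listing one representative class per artifact (e.g. OrtEnvironment for the ONNX runtime) pulls in that artifact's whole jar. A minimal sketch of the mechanism, assuming the HBase mapreduce and langchain4j jars are on the local classpath:

    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.hbase.mapreduce.TableMapReduceUtil;

    public class TmpJarsSketch {
        public static void main(String[] args) throws Exception {
            Configuration conf = new Configuration();
            // Each class is mapped to its containing jar; the jar paths accumulate in
            // "tmpjars" and are distributed with the job at submission time.
            TableMapReduceUtil.addDependencyJarsForClasses(conf,
                    dev.langchain4j.model.embedding.EmbeddingModel.class,
                    ai.onnxruntime.OrtEnvironment.class);
            System.out.println(conf.get("tmpjars"));
        }
    }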
@@ -52,8 +52,6 @@
import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
import org.apache.hadoop.hbase.mapreduce.HFileOutputFormat2;
-import org.apache.hadoop.hbase.mapreduce.TableMapReduceUtil;
-import org.apache.hadoop.hbase.protobuf.generated.AuthenticationProtos;
import org.apache.hadoop.hbase.util.CommonFSUtils;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.mapreduce.Job;
@@ -65,10 +63,6 @@
import org.eclipse.rdf4j.model.Value;
import org.eclipse.rdf4j.model.ValueFactory;
import org.eclipse.rdf4j.model.impl.SimpleValueFactory;
-import org.eclipse.rdf4j.rio.RDFFormat;
-import org.eclipse.rdf4j.rio.RDFParser;
-import org.eclipse.rdf4j.rio.Rio;
-import org.eclipse.rdf4j.rio.helpers.AbstractRDFHandler;
import org.eclipse.rdf4j.rio.helpers.NTriplesUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -234,15 +228,6 @@ public int run(CommandLine cmd) throws Exception {
configureBoolean(cmd, "dry-run");
String snapshotPath = getConf().get(SNAPSHOT_PATH);

-TableMapReduceUtil.addDependencyJarsForClasses(getConf(),
-NTriplesUtil.class,
-Rio.class,
-AbstractRDFHandler.class,
-RDFFormat.class,
-RDFParser.class,
-Table.class,
-HBaseConfiguration.class,
-AuthenticationProtos.class);
HBaseConfiguration.addHbaseResources(getConf());

RDFFactory rdfFactory;
@@ -43,9 +43,7 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hbase.HBaseConfiguration;
-import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.mapreduce.TableMapReduceUtil;
-import org.apache.hadoop.hbase.protobuf.generated.AuthenticationProtos;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
@@ -64,10 +62,6 @@
import org.eclipse.rdf4j.query.algebra.TupleExpr;
import org.eclipse.rdf4j.query.algebra.evaluation.QueryBindingSet;
import org.eclipse.rdf4j.query.algebra.evaluation.QueryEvaluationStep;
-import org.eclipse.rdf4j.rio.RDFFormat;
-import org.eclipse.rdf4j.rio.RDFParser;
-import org.eclipse.rdf4j.rio.Rio;
-import org.eclipse.rdf4j.rio.helpers.AbstractRDFHandler;
import org.eclipse.rdf4j.rio.helpers.NTriplesUtil;
import org.elasticsearch.hadoop.mr.EsOutputFormat;
import org.slf4j.Logger;
@@ -343,15 +337,8 @@ private static List<JsonInfo> run(Configuration conf, String queryFiles, String
conf.set("es.input.json", "yes");
}

-TableMapReduceUtil.addDependencyJarsForClasses(conf,
-NTriplesUtil.class,
-Rio.class,
-AbstractRDFHandler.class,
-RDFFormat.class,
-RDFParser.class,
-Table.class,
-HBaseConfiguration.class,
-AuthenticationProtos.class);
+addRioDependencies(conf);
+addLangModelDependencies(conf);
if (System.getProperty("exclude.es-hadoop") == null) {
TableMapReduceUtil.addDependencyJarsForClasses(conf, EsOutputFormat.class);
}
@@ -717,12 +717,7 @@ protected int run(CommandLine cmd) throws Exception {
tableDesc = hTable.getDescriptor();
RegionLocator regionLocator = conn.getRegionLocator(tableDesc.getTableName());
HFileOutputFormat2.configureIncrementalLoad(job, tableDesc, regionLocator);
-TableMapReduceUtil.addDependencyJarsForClasses(job.getConfiguration(),
-NTriplesUtil.class,
-Rio.class,
-AbstractRDFHandler.class,
-RDFFormat.class,
-RDFParser.class);
+addRioDependencies(job.getConfiguration());
}
try (Keyspace keyspace = HalyardTableUtils.getKeyspace(getConf(), conn, tableDesc.getTableName(), null, null)) {
try (KeyspaceConnection ksConn = keyspace.getConnection()) {
@@ -52,8 +52,6 @@
import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
import org.apache.hadoop.hbase.mapreduce.HFileOutputFormat2;
-import org.apache.hadoop.hbase.mapreduce.TableMapReduceUtil;
-import org.apache.hadoop.hbase.protobuf.generated.AuthenticationProtos;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
@@ -76,10 +74,6 @@
import org.eclipse.rdf4j.query.parser.ParsedUpdate;
import org.eclipse.rdf4j.query.parser.QueryParserUtil;
import org.eclipse.rdf4j.repository.sail.SailRepositoryConnection;
-import org.eclipse.rdf4j.rio.RDFFormat;
-import org.eclipse.rdf4j.rio.RDFParser;
-import org.eclipse.rdf4j.rio.Rio;
-import org.eclipse.rdf4j.rio.helpers.AbstractRDFHandler;
import org.eclipse.rdf4j.rio.helpers.NTriplesUtil;
import org.eclipse.rdf4j.sail.SailException;
import org.slf4j.Logger;
@@ -353,15 +347,6 @@ static List<JsonInfo> executeUpdate(Configuration conf, String source, String qu

private static List<JsonInfo> run(Configuration conf, String queryFiles, String query, String workdir) throws IOException, InterruptedException, ClassNotFoundException {
String source = conf.get(TABLE_NAME_PROPERTY);
-TableMapReduceUtil.addDependencyJarsForClasses(conf,
-NTriplesUtil.class,
-Rio.class,
-AbstractRDFHandler.class,
-RDFFormat.class,
-RDFParser.class,
-Table.class,
-HBaseConfiguration.class,
-AuthenticationProtos.class);
HBaseConfiguration.addHbaseResources(conf);
// get bindings from merged configs
BindingSet bindings = AbstractHalyardTool.getBindings(conf, SimpleValueFactory.getInstance());
@@ -58,10 +58,8 @@
import org.apache.hadoop.hbase.client.ConnectionFactory;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.Scan;
-import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.io.ImmutableBytesWritable;
import org.apache.hadoop.hbase.mapreduce.TableMapReduceUtil;
-import org.apache.hadoop.hbase.protobuf.generated.AuthenticationProtos;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.hadoop.hbase.util.CommonFSUtils;
import org.apache.hadoop.io.NullWritable;
@@ -76,11 +74,6 @@
import org.eclipse.rdf4j.model.base.CoreDatatype.GEO;
import org.eclipse.rdf4j.model.base.CoreDatatype.XSD;
import org.eclipse.rdf4j.model.impl.SimpleValueFactory;
-import org.eclipse.rdf4j.rio.RDFFormat;
-import org.eclipse.rdf4j.rio.RDFParser;
-import org.eclipse.rdf4j.rio.Rio;
-import org.eclipse.rdf4j.rio.helpers.AbstractRDFHandler;
-import org.eclipse.rdf4j.rio.helpers.NTriplesUtil;
import org.elasticsearch.hadoop.mr.EsOutputFormat;
import org.json.JSONArray;
import org.json.JSONObject;
@@ -272,15 +265,6 @@ public int run(CommandLine cmd) throws Exception {
getConf().setBoolean(confProperty(TOOL_NAME, "fields."+field), true);
}

-TableMapReduceUtil.addDependencyJarsForClasses(getConf(),
-NTriplesUtil.class,
-Rio.class,
-AbstractRDFHandler.class,
-RDFFormat.class,
-RDFParser.class,
-Table.class,
-HBaseConfiguration.class,
-AuthenticationProtos.class);
if (System.getProperty("exclude.es-hadoop") == null) {
TableMapReduceUtil.addDependencyJarsForClasses(getConf(), EsOutputFormat.class);
}