Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Postgres Upsert Modifications #176

Merged
merged 6 commits into from
Aug 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions FlySpring/edgechain-app/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.github.f4b6a3</groupId>
<artifactId>uuid-creator</artifactId>
<version>5.2.0</version>
</dependency>

</dependencies>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import com.edgechain.lib.embeddings.WordEmbeddings;
import com.edgechain.lib.endpoint.Endpoint;
import com.edgechain.lib.index.enums.PostgresDistanceMetric;
import com.edgechain.lib.index.responses.PostgresResponse;
import com.edgechain.lib.response.StringResponse;
import com.edgechain.lib.retrofit.PostgresService;
import com.edgechain.lib.retrofit.client.RetrofitClientInstance;
Expand All @@ -26,6 +27,7 @@ public class PostgresEndpoint extends Endpoint {
private PostgresDistanceMetric metric;
private int dimensions;
private int topK;
private String fileName;

public PostgresEndpoint() {}

Expand All @@ -44,6 +46,10 @@ public PostgresEndpoint(String tableName, String namespace, RetryPolicy retryPol
this.namespace = namespace;
}

public void setFileName(String fileName) {
this.fileName = fileName;
}

public void setTableName(String tableName) {
this.tableName = tableName;
}
Expand All @@ -62,6 +68,10 @@ public String getNamespace() {

// Getters

public String getFileName() {
return fileName;
}

public WordEmbeddings getWordEmbeddings() {
return wordEmbeddings;
}
Expand All @@ -82,7 +92,11 @@ public PostgresDistanceMetric getMetric() {
public Observable<StringResponse> upsert(WordEmbeddings wordEmbeddings, int dimension) {
this.wordEmbeddings = wordEmbeddings;
this.dimensions = dimension;
return Observable.fromSingle(postgresService.upsert(this));
if(fileName != null) {
return Observable.fromSingle(postgresService.upsertWithFilename(this));
} else {
return Observable.fromSingle(postgresService.upsert(this));
}
}

public Observable<List<WordEmbeddings>> query(
Expand All @@ -92,6 +106,13 @@ public Observable<List<WordEmbeddings>> query(
this.metric = metric;
return Observable.fromSingle(this.postgresService.query(this));
}
public Observable<List<PostgresResponse>> queryWithFilename(
WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK) {
this.wordEmbeddings = wordEmbeddings;
this.topK = topK;
this.metric = metric;
return Observable.fromSingle(this.postgresService.queryWithFilename(this));
}

public Observable<StringResponse> deleteAll() {
return Observable.fromSingle(this.postgresService.deleteAll(this));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
import com.edgechain.lib.endpoint.impl.PostgresEndpoint;
import com.edgechain.lib.index.enums.PostgresDistanceMetric;
import com.edgechain.lib.index.repositories.PostgresClientRepository;
import com.edgechain.lib.index.responses.PostgresResponse;
import com.edgechain.lib.response.StringResponse;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import io.reactivex.rxjava3.core.Observable;

import java.sql.Timestamp;
import java.time.LocalDateTime;
import java.util.*;

public class PostgresClient {
Expand Down Expand Up @@ -51,6 +55,30 @@ public EdgeChain<StringResponse> upsert(WordEmbeddings wordEmbeddings) {
}),
postgresEndpoint);
}
public EdgeChain<StringResponse> upsertWithFilename(WordEmbeddings wordEmbeddings) {

return new EdgeChain<>(
Observable.create(
emitter -> {
try {
// Create Table
this.repository.createTable(postgresEndpoint);

String input = wordEmbeddings.getId().replaceAll("'", "");

// Upsert Embeddings
this.repository.upsertEmbeddingsWithFilename(
postgresEndpoint.getTableName(), input, wordEmbeddings, this.namespace, postgresEndpoint.getFileName());

emitter.onNext(new StringResponse("Upserted"));
emitter.onComplete();

} catch (final Exception e) {
emitter.onError(e);
}
}),
postgresEndpoint);
}

public EdgeChain<List<WordEmbeddings>> query(
WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK) {
Expand All @@ -70,7 +98,7 @@ public EdgeChain<List<WordEmbeddings>> query(
List<WordEmbeddings> wordEmbeddingsList = new ArrayList<>();

for (Map row : rows) {
wordEmbeddingsList.add(new WordEmbeddings((String) row.get("id")));
wordEmbeddingsList.add(new WordEmbeddings((String) row.get("raw")));
}

emitter.onNext(wordEmbeddingsList);
Expand All @@ -82,6 +110,42 @@ public EdgeChain<List<WordEmbeddings>> query(
}),
postgresEndpoint);
}
public EdgeChain<List<PostgresResponse>> queryWithFilename(
WordEmbeddings wordEmbeddings, PostgresDistanceMetric metric, int topK) {

return new EdgeChain<>(
Observable.create(
emitter -> {
try {
List<Map<String, Object>> rows =
this.repository.queryWithFilename(
postgresEndpoint.getTableName(),
this.namespace,
metric,
wordEmbeddings,
topK);

List<PostgresResponse> wordEmbeddingsList = new ArrayList<>();

for (Map row : rows) {
wordEmbeddingsList.add(
new PostgresResponse(
(String) row.get("id"),
new WordEmbeddings((String) row.get("raw")),
(String) row.get("filename"),
(Integer) row.get("sno"),
(Timestamp) row.get("timestamp")
)
);
}
emitter.onNext(wordEmbeddingsList);
emitter.onComplete();
} catch (final Exception e) {
emitter.onError(e);
}
}), postgresEndpoint
);
}

public EdgeChain<StringResponse> deleteAll() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
import com.edgechain.lib.endpoint.impl.PostgresEndpoint;
import com.edgechain.lib.index.enums.PostgresDistanceMetric;
import com.edgechain.lib.utils.FloatUtils;
import com.github.f4b6a3.uuid.UuidCreator;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Repository;
import org.springframework.transaction.annotation.Transactional;

import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
Expand All @@ -24,23 +26,42 @@ public void createTable(PostgresEndpoint postgresEndpoint) {
jdbcTemplate.execute("CREATE EXTENSION IF NOT EXISTS vector;");
jdbcTemplate.execute(
String.format(
"CREATE TABLE IF NOT EXISTS %s (id TEXT PRIMARY KEY, embedding"
+ " vector(%s), namespace TEXT);",
"CREATE TABLE IF NOT EXISTS %s (id TEXT PRIMARY KEY, raw TEXT, sno SERIAL, embedding"
+ " vector(%s), namespace TEXT, fileName TEXT, timestamp TIMESTAMP);",
postgresEndpoint.getTableName(), postgresEndpoint.getDimensions()));
}

@Transactional
public void upsertEmbeddings(
String tableName, String input, WordEmbeddings wordEmbeddings, String namespace) {
String id = UuidCreator.getTimeOrderedEpoch().toString();
LocalDateTime timestamp = LocalDateTime.now();
jdbcTemplate.execute(
String.format(
"INSERT INTO %s (id, raw, embedding, namespace, timestamp) VALUES ('%s', '%s', '%s', '%s', '%s')\n"
+ " ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding;",
tableName,
id,
input,
Arrays.toString(FloatUtils.toFloatArray(wordEmbeddings.getValues())),
namespace, timestamp.toString()));
}

//Use this function to insert embeddings with filename
@Transactional
public void upsertEmbeddingsWithFilename(
String tableName, String input, WordEmbeddings wordEmbeddings, String namespace, String fileName) {
String id = UuidCreator.getTimeOrderedEpoch().toString();
LocalDateTime timestamp = LocalDateTime.now();
jdbcTemplate.execute(
String.format(
"INSERT INTO %s (id, embedding, namespace) VALUES ('%s', '%s', '%s')\n"
"INSERT INTO %s (id, raw, embedding, namespace, fileName, timestamp) VALUES ('%s', '%s', '%s', '%s', '%s', '%s')\n"
+ " ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding;",
tableName,
id,
input,
Arrays.toString(FloatUtils.toFloatArray(wordEmbeddings.getValues())),
namespace));
namespace, fileName, timestamp.toString()));
}

@Transactional(readOnly = true)
Expand All @@ -53,7 +74,26 @@ public List<Map<String, Object>> query(

return jdbcTemplate.queryForList(
String.format(
"SELECT id FROM %s WHERE namespace='%s' ORDER BY embedding %s '%s' LIMIT %s;",
"SELECT raw FROM %s WHERE namespace='%s' ORDER BY embedding %s '%s' LIMIT %s;",
tableName,
namespace,
PostgresDistanceMetric.getDistanceMetric(metric),
Arrays.toString(FloatUtils.toFloatArray(wordEmbeddings.getValues())),
topK));
}

//Use this function to query embeddings with filename
@Transactional(readOnly = true)
public List<Map<String, Object>> queryWithFilename(
String tableName,
String namespace,
PostgresDistanceMetric metric,
WordEmbeddings wordEmbeddings,
int topK) {

return jdbcTemplate.queryForList(
String.format(
"SELECT id, sno, raw, filename, timestamp FROM %s WHERE namespace='%s' ORDER BY embedding %s '%s' LIMIT %s;",
tableName,
namespace,
PostgresDistanceMetric.getDistanceMetric(metric),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package com.edgechain.lib.index.responses;

import com.edgechain.lib.embeddings.WordEmbeddings;

import java.sql.Timestamp;

public class PostgresResponse {
private String id;
private WordEmbeddings wordEmbeddings;
private String fileName;
private Integer sno;
private Timestamp timestamp;

public PostgresResponse() {
}

public PostgresResponse(String id, WordEmbeddings wordEmbeddings, String fileName, Integer sno, Timestamp timestamp) {
this.id = id;
this.wordEmbeddings = wordEmbeddings;
this.fileName = fileName;
this.sno = sno;
this.timestamp = timestamp;
}

public String getId() {
return id;
}

public void setId(String id) {
this.id = id;
}

public Integer getSno() {
return sno;
}

public void setSno(Integer sno) {
this.sno = sno;
}

public Timestamp getTimestamp() {
return timestamp;
}

public void setTimestamp(Timestamp timestamp) {
this.timestamp = timestamp;
}

public WordEmbeddings getWordEmbeddings() {
return wordEmbeddings;
}

public void setWordEmbeddings(WordEmbeddings wordEmbeddings) {
this.wordEmbeddings = wordEmbeddings;
}

public String getFileName() {
return fileName;
}

public void setFileName(String fileName) {
this.fileName = fileName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.edgechain.lib.embeddings.WordEmbeddings;
import com.edgechain.lib.endpoint.impl.PostgresEndpoint;
import com.edgechain.lib.index.responses.PostgresResponse;
import com.edgechain.lib.response.StringResponse;
import io.reactivex.rxjava3.core.Single;
import java.util.List;
Expand All @@ -13,11 +14,16 @@ public interface PostgresService {

@POST(value = "index/postgres/upsert")
Single<StringResponse> upsert(@Body PostgresEndpoint postgresEndpoint);
@POST(value = "index/postgres/upsert-filename")
Single<StringResponse> upsertWithFilename(@Body PostgresEndpoint postgresEndpoint);

//
@POST(value = "index/postgres/query")
Single<List<WordEmbeddings>> query(@Body PostgresEndpoint postgresEndpoint);

@POST(value = "index/postgres/query-filename")
Single<List<PostgresResponse>> queryWithFilename(@Body PostgresEndpoint postgresEndpoint);

@HTTP(method = "DELETE", path = "index/postgres/deleteAll", hasBody = true)
Single<StringResponse> deleteAll(@Body PostgresEndpoint postgresEndpoint);
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.edgechain.lib.embeddings.WordEmbeddings;
import com.edgechain.lib.endpoint.impl.PostgresEndpoint;
import com.edgechain.lib.index.client.impl.PostgresClient;
import com.edgechain.lib.index.responses.PostgresResponse;
import com.edgechain.lib.response.StringResponse;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import io.reactivex.rxjava3.core.Single;
Expand All @@ -20,6 +21,12 @@ public Single<StringResponse> upsert(@RequestBody PostgresEndpoint postgresEndpo
new PostgresClient(postgresEndpoint).upsert(postgresEndpoint.getWordEmbeddings());
return edgeChain.toSingle();
}
@PostMapping("/upsert-filename")
public Single<StringResponse> upsertWithFilename(@RequestBody PostgresEndpoint postgresEndpoint) {
EdgeChain<StringResponse> edgeChain =
new PostgresClient(postgresEndpoint).upsertWithFilename(postgresEndpoint.getWordEmbeddings());
return edgeChain.toSingle();
}

@PostMapping("/query")
public Single<List<WordEmbeddings>> query(@RequestBody PostgresEndpoint postgresEndpoint) {
Expand All @@ -31,6 +38,16 @@ public Single<List<WordEmbeddings>> query(@RequestBody PostgresEndpoint postgres
postgresEndpoint.getTopK());
return edgeChain.toSingle();
}
@PostMapping("/query-filename")
public Single<List<PostgresResponse>> queryWithFilename(@RequestBody PostgresEndpoint postgresEndpoint) {
EdgeChain<List<PostgresResponse>> edgeChain =
new PostgresClient(postgresEndpoint)
.queryWithFilename(
postgresEndpoint.getWordEmbeddings(),
postgresEndpoint.getMetric(),
postgresEndpoint.getTopK());
return edgeChain.toSingle();
}

@DeleteMapping("/deleteAll")
public Single<StringResponse> deleteAll(@RequestBody PostgresEndpoint postgresEndpoint) {
Expand Down