Skip to content

Commit

Permalink
Got basic Open AI Embeddings being generated and persisted
Browse files Browse the repository at this point in the history
  • Loading branch information
mhughes2k committed Mar 6, 2024
1 parent 2731810 commit bef2f3b
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 53 deletions.
92 changes: 92 additions & 0 deletions search/engine/solrrag/classes/ai/aiclient.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
<?php
namespace core\ai;
require_once($CFG->libdir.'/filelib.php');
/**
 * Base client for AI providers that uses simple HTTP requests.
 *
 * Extends the \curl helper so that every request carries the provider's bearer
 * token and a JSON content type. Endpoints and model names are read from the
 * AIProvider passed to the constructor.
 */
class AIClient extends \curl {
    /** @var \core\ai\AIProvider Provider supplying the API key, endpoint paths and model names. */
    private $provider;

    /**
     * Creates a client bound to one provider and sets the request headers.
     *
     * @param \core\ai\AIProvider $provider Provider whose 'apikey' is sent as a bearer token.
     */
    public function __construct(
        \core\ai\AIProvider $provider
    ) {
        $this->provider = $provider;
        parent::__construct([]);
        $this->setHeader('Authorization: Bearer ' . $this->provider->get('apikey'));
        $this->setHeader('Content-Type: application/json');
    }

    /**
     * Builds the embeddings endpoint from the provider's 'baseurl' and 'embeddings' values.
     *
     * @return string
     */
    public function get_embeddings_url(): string {
        return $this->provider->get('baseurl') . $this->provider->get('embeddings');
    }

    /**
     * Builds the completions endpoint from the provider's 'baseurl' and 'completions' values.
     *
     * @return string
     */
    public function get_chat_completions_url(): string {
        return $this->provider->get('baseurl') . $this->provider->get('completions');
    }

    /**
     * Sends a piece of text to the embeddings endpoint and returns its vector.
     *
     * Token usage reported by the response is passed on to the provider. Any
     * failure (request could not be encoded, response is not a JSON object, or
     * no embedding is present) is reported through mtrace() and yields an empty
     * array, so callers always receive an array.
     *
     * @param mixed $content Text to embed; it is cast to string before sending.
     * @return array The embedding vector, or an empty array when none was returned.
     */
    public function embed_query($content): array {
        $usedptokens = $this->provider->get_usage('prompt_tokens');
        $totaltokens = $this->provider->get_usage('total_tokens');
        mtrace("Prompt tokens: $usedptokens. Total tokens: $totaltokens");

        // The text is sent as-is: json_encode() already escapes it for the JSON
        // body, and HTML-entity encoding would change the text being embedded.
        // Invalid UTF-8 is substituted so that encoding cannot fail on bad bytes.
        $params = json_encode(
            [
                "input" => (string) $content, // TODO need to do some length checking here!
                "model" => $this->provider->get('embeddingmodel'),
            ],
            JSON_INVALID_UTF8_SUBSTITUTE
        );
        if ($params === false) {
            mtrace("Unable to encode embeddings request: " . json_last_error_msg());
            return [];
        }

        $rawresult = $this->post($this->get_embeddings_url(), $params);
        $result = json_decode((string) $rawresult, true);
        if (!is_array($result)) {
            mtrace("Embeddings request failed: response was not a JSON object.");
            return [];
        }
        if (isset($result['error'])) {
            // Shape of the error payload is provider specific, so log it verbatim.
            mtrace("Embeddings request returned an error: " . json_encode($result['error']));
        }

        // Record whatever usage the response reported (zero when absent).
        $usage = is_array($result['usage'] ?? null) ? $result['usage'] : [];
        $promptused = $usage['prompt_tokens'] ?? 0;
        $totalused = $usage['total_tokens'] ?? 0;
        $this->provider->increment_prompt_usage($promptused);
        $this->provider->increment_total_tokens($totalused);
        mtrace("Used Prompt tokens: {$promptused}. Total tokens: {$totalused}");

        // Take the first entry flagged as an embedding.
        $vector = [];
        $data = is_array($result['data'] ?? null) ? $result['data'] : [];
        foreach ($data as $d) {
            if (is_array($d) && ($d['object'] ?? null) == "embedding" && is_array($d['embedding'] ?? null)) {
                $vector = $d['embedding'];
                break;
            }
        }

        // Report running totals on every path, not only when nothing was found.
        $usedptokens = $this->provider->get_usage('prompt_tokens');
        $totaltokens = $this->provider->get_usage('total_tokens');
        mtrace("Total Used: Prompt tokens: $usedptokens. Total tokens: $totaltokens");

        return $vector;
    }

    /**
     * Embeds each document in turn via embed_query().
     *
     * @param array $documents Texts to embed.
     * @return array One vector per document, in the same order as the input.
     */
    public function embed_documents(array $documents) {
        $embeddings = [];
        foreach ($documents as $doc) {
            $embeddings[] = $this->embed_query($doc);
        }
        return $embeddings;
    }

    /**
     * Produces placeholder vectors without contacting any back end.
     *
     * @param array $documents Documents to produce vectors for; only their count is used.
     * @param int $length Number of components in each fake vector.
     * @return array One fake vector per document.
     */
    public function fake_embed(array $documents, int $length = 1356) {
        $vectors = [];
        foreach ($documents as $document) {
            $vectors[] = $this->fake_vector($length);
        }
        return $vectors;
    }

    /**
     * Placeholder for chat/text completion; not implemented yet and returns nothing.
     *
     * @param mixed $query
     */
    public function complete($query) {
    }

    /**
     * Generates a vector of random components, each either 0 or 1.
     *
     * @param int $length Number of components.
     * @return array
     */
    private function fake_vector($length) {
        $vector = [];
        for ($i = 0; $i < $length; $i++) {
            $vector[] = rand(0, 1);
        }
        return $vector;
    }
}
94 changes: 63 additions & 31 deletions search/engine/solrrag/classes/ai/aiprovider.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,37 @@

class AIProvider extends persistent {
// Ultimately this would extend a persistent.
/**
 * Seeds a hard-coded placeholder provider when a positive id is supplied.
 *
 * Nothing is set when $id is 0 (the default), $record is never read, and the
 * parent constructor is not called from this body.
 *
 * NOTE(review): this page is a rendered diff without +/- markers; this
 * constructor looks like the version the commit deletes - confirm against the
 * repository before relying on it.
 *
 * @param int $id Provider id; values greater than 0 trigger the placeholder data.
 * @param stdClass|null $record Unused here.
 */
public function __construct($id = 0, stdClass $record = null) {
if ($id > 0) {
// Write the placeholder values directly with raw_set().
$this->raw_set('id', $id);
$this->raw_set('name', "Fake AI Provider");
$this->raw_set('allowembeddings', true);
$this->raw_set('allowquery', true);
}
}


/**
 * Lists the provider's properties together with the PARAM_* type of each.
 *
 * @return array Map of property name to its definition array.
 */
protected static function define_properties()
{
    $properties = [];

    // Identity and credentials.
    $properties['name'] = ['type' => PARAM_TEXT];
    $properties['apikey'] = ['type' => PARAM_ALPHANUMEXT];

    // Feature switches.
    $properties['allowembeddings'] = ['type' => PARAM_BOOL];
    $properties['allowquery'] = ['type' => PARAM_BOOL];

    // Endpoint locations and the model used at each one.
    $properties['baseurl'] = ['type' => PARAM_URL];
    $properties['embeddings'] = ['type' => PARAM_URL];
    $properties['embeddingmodel'] = ['type' => PARAM_ALPHANUMEXT];
    $properties['completions'] = ['type' => PARAM_URL];
    $properties['completionmodel'] = ['type' => PARAM_ALPHANUMEXT];

    return $properties;
}
Expand All @@ -37,31 +48,43 @@ public function use_for_embeddings(): bool {
/**
 * Reports the value of the 'allowquery' property.
 *
 * @return bool
 */
public function use_for_query():bool {
    $allowed = $this->get('allowquery');
    return $allowed;
}
public function embed_documents(array $documents) {
// Go send the documents off to a back end and then return array of each document's vectors.
// But for the minute generate an array of fake vectors of a specific length.
$vectors = [];
foreach ($documents as $document) {
$vectors[] = $this->fake_vector(1356);
}
return $vectors;
/**
 * Returns the recorded token usage for this provider.
 *
 * Usage tracking is currently switched off, so this always returns the
 * placeholder "-". The lookup below is kept for when tracking is re-enabled.
 *
 * @param string $type Usage counter name, e.g. 'prompt_tokens' or 'total_tokens'.
 * @return mixed The placeholder "-" while tracking is disabled.
 */
public function get_usage($type) {
    return "-";

    // Disabled lookup. The key is built the same way increment_prompt_usage()
    // and increment_total_tokens() build theirs: "<counter>_<id>_<apikey>",
    // where the counter name carries no underscore ('prompttokens').
    $key = implode("_", [
        str_replace('_', '', $type),
        $this->get('id'),
        $this->get('apikey'),
    ]);
    $current = get_config('ai', $key);
    return $current;
}
private function fake_vector($length) {
$vector = [];
for ($i = 0; $i < $length; $i++) {
$vector[] = rand(0, 1);
}
return $vector;
/**
 * Adds to the stored prompt-token counter for this provider.
 *
 * Currently a no-op: the method returns before touching any configuration.
 * The statements after the return are the disabled implementation.
 *
 * @param int $change Number of prompt tokens to add.
 */
public function increment_prompt_usage($change) {
    return;

    // Disabled: counter lives in the 'ai' config under "prompttokens_<id>_<apikey>".
    $key = implode("_", ['prompttokens', $this->get('id'), $this->get('apikey')]);
    $updated = get_config('ai', $key) + $change;
    set_config($key, $updated, 'ai');
}

/**
* @param $document
* @return array
*/
public function embed_query($document): array {
// Send document to back end and return the vector
return $this->fake_vector(1356);
/**
 * Adds to the stored total-token counter for this provider.
 *
 * Currently a no-op: the method returns before touching any configuration.
 * The statements after the return are the disabled implementation.
 *
 * @param int $change Number of tokens to add.
 */
public function increment_total_tokens($change) {
    return;

    // Disabled: counter lives in the 'ai' config under "totaltokens_<id>_<apikey>".
    $key = implode("_", ['totaltokens', $this->get('id'), $this->get('apikey')]);
    $updated = get_config('ai', $key) + $change;
    set_config($key, $updated, 'ai');
}

//public function
// TODO token counting.
/**
* We're overriding this whilst we don't have a real DB table.
* @param $filters
Expand All @@ -75,9 +98,18 @@ public static function get_records($filters = array(), $sort = '', $order = 'ASC
$records = [];
$fake = new static(0, (object) [
'id' => 1,
'name' => "Fake AI Provider"
'name' => "Fake Open AI Provider",
'allowembeddings' => true,
'allowquery' => true,
'baseurl' => 'https://api.openai.com/v1/',
'embeddings' => 'embeddings',
'embeddingmodel' => 'text-embedding-3-small',
'completions' => 'completions',
'completionmodel' => 'gpt-4-turbo-preview',
'apikey'=> ''
]);
array_push($records, $fake);
return $records;
}

}
6 changes: 4 additions & 2 deletions search/engine/solrrag/classes/ai/api.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ public static function get_all_providers() {
return array_values(AIProvider::get_records());
}
/**
 * Returns the AI provider to use.
 *
 * Providers are taken from AIProvider::get_records() and the first one is
 * returned; $id is accepted for interface compatibility but is not used to
 * select a record yet.
 *
 * @param int $id Requested provider id (currently not used for selection).
 * @return AIProvider
 * @throws \RuntimeException When get_records() yields no providers.
 */
public static function get_provider(int $id): AIProvider {
    // The earlier `return new AIProvider($id);` made this lookup unreachable.
    $fakes = array_values(AIProvider::get_records());
    if (!isset($fakes[0])) {
        // Fail with a clear message rather than a TypeError on the return type.
        throw new \RuntimeException("No AI provider is available (requested id {$id}).");
    }
    return $fakes[0];
}
}
}
39 changes: 19 additions & 20 deletions search/engine/solrrag/classes/engine.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
// Fudge autoloading!
require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/api.php");
require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/aiprovider.php");
require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/aiclient.php");
use \core\ai\AIProvider;
use \core\ai\AIClient;
class engine extends \search_solr\engine {

/**
* @var AIProvider AI provider object to use to generate embeddings.
*/
protected ?AIProvider $embeddingprovider = null;
protected ?AIClient $aiclient = null;
protected ?AIProvider $aiprovider = null;

public function __construct(bool $alternateconfiguration = false)
{
Expand All @@ -25,9 +28,8 @@ public function __construct(bool $alternateconfiguration = false)
// So we'll fudge this for the moment and leverage an OpenAI Web Service API via a simple HTTP request.
$aiproviderid = 1;
$aiprovider = \core\ai\api::get_provider($aiproviderid);
if ($aiprovider->use_for_embeddings()) {
$this->embeddingprovider = $aiprovider;
}
$this->aiprovider = $aiprovider;
$this->aiclient = !is_null($aiprovider)? new AIClient($aiprovider) : null;
}

public function is_server_ready()
Expand Down Expand Up @@ -75,20 +77,8 @@ public function is_server_ready()
protected function add_stored_file($document, $storedfile)
{
$embeddings = [];
$filedoc = $document->export_file_for_engine($storedfile);
/**
* Should we even attempt to get vectors.
*/
if (!is_null($this->embeddingprovider)) {
// garnish $filedoc with the embedding vector. It would be nice if this could be done
// via the export_file_for_engine() call above, that has no awareness of the engine.
$embeddings = $this->embeddingprovider->embed_documents([$filedoc]);
} else {
// potentially warn that selected provider can't be used for
// generating embeddings for RAG.
}


$filedoc = $document->export_file_for_engine($storedfile);
// Used the underlying implementation

if (!$this->file_is_indexable($storedfile)) {
Expand Down Expand Up @@ -197,15 +187,24 @@ protected function add_stored_file($document, $storedfile)
}
}
}
if (count($embeddings) > 0) {
$vector = $embeddings[0];
/**
* Since solr has given us back the content, we can now send it off to the AI provider.
*/
if ($this->aiprovider->use_for_embeddings() && $this->aiclient) {
// garnish $filedoc with the embedding vector. It would be nice if this could be done
// via the export_file_for_engine() call above, that has no awareness of the engine.
// We expect $filedoc['content'] to be set.
$vector = $this->aiclient->embed_query($filedoc['content']);
$vlength = count($vector);
$vectorfield = "solr_vector_" . $vlength;
// TODO Check if a field of this length actually exists or not.
$filedoc[$vectorfield] = $vector;
debugging("Using vector field $vectorfield");
} else {
// potentially warn that selected provider can't be used for
// generating embeddings for RAG.
}
$this->add_solr_document($filedoc);
exit("Goodbye");
return;
}
} else {
Expand Down

0 comments on commit bef2f3b

Please sign in to comment.