Skip to content

Commit

Permalink
Vector search working for content indexed by solr
Browse files Browse the repository at this point in the history
  • Loading branch information
mhughes2k committed Mar 19, 2024
1 parent 1978c36 commit f299b29
Show file tree
Hide file tree
Showing 9 changed files with 305 additions and 56 deletions.
3 changes: 3 additions & 0 deletions mod/xaichat/templates/conversation.mustache
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@
{{> mod_xaichat/message}}
{{/messages}}
<hr />
{{#rawmessages}}
{{{rawmessages}}}
{{/rawmessages}}
56 changes: 34 additions & 22 deletions mod/xaichat/view.php
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,12 @@
$progress->update(1, $totalsteps,'Looking for relevant context');
$vector = $aiclient->embed_query($data->userprompt);
$search = \core_search\manager::instance(true, true);
$settings = [
'similarity' => true,
'vector' => $vector,
];
$limit = 0;
// Some of these values can't be entrusted to the end user to supply (e.g.
// via a form), nor can they be entirely left to the plugin developer.
$settings = $aiprovider->get_settings_for_user($USER);
$settings['vector'] = $vector;
$settings['userquery'] = $data->userprompt;
$docs = $search->search((object)$settings);
var_dump($docs);

// Perform "R" from RAG, finding documents from within the context that are similar to the user's prompt.
// Add the retrieved documents to the context for this chat by generating some system messages with the content
Expand All @@ -110,37 +109,42 @@
foreach ($docs as $doc) {
$context[] = $doc->content;
}
$_SESSION[$aicontextkey]['messages'][] = (object)[
$prompt = (object)[
"role" => "system",
"content" => "Use the following context to answer questions:" . implode("\n",$context)
"content" => "Use the following context to answer following question:" . implode("\n",$context)
];
}

// Add the user's new prompt to the messages.
$_SESSION[$aicontextkey]['messages'][] = $prompt;
}
$progress->update(2, $totalsteps,'Attaching user prompt');
$_SESSION[$aicontextkey]['messages'][] = (object)[
// $_SESSION[$aicontextkey]['messages'][]
$prompt = (object)[
"role" => "user",
"content" => $data->userprompt
];
// Pass the whole context over the AI to summarise.
var_dump($_SESSION[$aicontextkey]['messages']);
$_SESSION[$aicontextkey]['messages'][] = $prompt;

// Pass the whole context over the AI to summarise.
$progress->update(3, $totalsteps, 'Waiting for response');

$airesults = $aiclient->chat($_SESSION[$aicontextkey]['messages']);
$_SESSION[$aicontextkey]['messages'] = array_merge($_SESSION[$aicontextkey]['messages'],$airesults);
$progress->update(4, $totalsteps, 'Got Response');

// We stash the data in the session temporarily (it should go into an activity-user
// store in the database, but this is fast and dirty), and then we redirect so that
// we don't double up the request if the user hits refresh.
redirect(new \moodle_url('/mod/xaichat/view.php', ['id' => $cm->id]));
$next = new \moodle_url('/mod/xaichat/view.php', ['id' => $cm->id]);
redirect($next);
} else if ($chatform->is_cancelled()) {
$_SESSION[$aicontextkey] = [
'messages'=>[]
];
$prompt = (object)[
"role" => "system",
"content" => "You are a helpful AI. You should only use information you know. Only use information that is relevant to the question."
];
$_SESSION[$aicontextkey]['messages'][] = $prompt;
} else {
// Clear session on first view of form.

$toform = [
'id' => $id,
'aiproviderid' => $moduleinstance->aiproviderid,
Expand All @@ -154,19 +158,27 @@

$displaymessages = [];
foreach ($_SESSION[$aicontextkey]['messages'] as $message) {
$displaymessages[] = [
"role" => $message->role == "user" ? $userpic : \html_writer::tag("strong", $aipic),
"content" => format_text($message->content, FORMAT_MARKDOWN)
];
if ($message->role != "system") {
$displaymessages[] = [
"role" => $message->role == "user" ? $userpic : \html_writer::tag("strong", $aipic),
"content" => format_text($message->content, FORMAT_MARKDOWN)
];
}
}
$displaymessages = array_reverse($displaymessages);
$tcontext = [
"userpic" => new user_picture($USER),
"messages" => $displaymessages
];
$chatform->display();

echo $OUTPUT->render_from_template("mod_xaichat/conversation", $tcontext);

$chatform->display();
if (false) {
echo html_writer::tag("pre", print_r($_SESSION[$aicontextkey]['messages'],1));
}



//echo \html_writer::tag('pre', print_r($displaymessages,1));
//echo \html_writer::tag('pre', print_r($_SESSION[$aicontextkey]['messages'],1));
Expand Down
27 changes: 16 additions & 11 deletions search/engine/solrrag/classes/ai/aiclient.php
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<?php
namespace core\ai;
require_once($CFG->libdir.'/filelib.php');
use core\ai\AiException;
/**
* Base client for AI providers that uses simple http request.
*/
Expand Down Expand Up @@ -40,16 +41,20 @@ public function chat($messages) {
$params = json_encode($params);
$rawresult = $this->post($this->get_chat_completions_url(), $params);
$jsonresult = json_decode($rawresult);
if (!isset($jsonresult->choices)) {
exit();
return [];
if (isset($jsonresult->error)) {
throw new AiException("Error: " . $jsonresult->error->message . ":". print_r($messages, true));
//return "Error: " . $jsonresult->error->message . ":". print_r($messages, true);
}
$result = $this->convert_chat_completion($jsonresult->choices);
if (isset($jsonresult->usage)) {
$this->provider->increment_prompt_usage($jsonresult->usage->prompt_tokens);
$this->provider->increment_completion_tokens($jsonresult->usage->completion_tokens);
$this->provider->increment_total_tokens($jsonresult->usage->total_tokens);
$result = [];
if (isset($jsonresult->choices)) {
$result = $this->convert_chat_completion($jsonresult->choices);
if (isset($jsonresult->usage)) {
$this->provider->increment_prompt_usage($jsonresult->usage->prompt_tokens);
$this->provider->increment_completion_tokens($jsonresult->usage->completion_tokens);
$this->provider->increment_total_tokens($jsonresult->usage->total_tokens);
}
}

return $result;
}

Expand All @@ -73,7 +78,7 @@ public function embed_query($content): array {
// Send document to back end and return the vector
$usedptokens = $this->provider->get_usage('prompt_tokens');
$totaltokens = $this->provider->get_usage('total_tokens');
mtrace("Prompt tokens: $usedptokens. Total tokens: $totaltokens");
// mtrace("Prompt tokens: $usedptokens. Total tokens: $totaltokens");
$params = [
"input" => htmlentities($content), // TODO need to do some length checking here!
"model" => $this->provider->get('embeddingmodel')
Expand All @@ -88,7 +93,7 @@ public function embed_query($content): array {
$usage = $result['usage'];
$this->provider->increment_prompt_usage($usage['prompt_tokens']);
$this->provider->increment_total_tokens($usage['total_tokens']);
mtrace("Used Prompt tokens: {$usage['prompt_tokens']}. Total tokens: {$usage['total_tokens']}");
// mtrace("Used Prompt tokens: {$usage['prompt_tokens']}. Total tokens: {$usage['total_tokens']}");
$data = $result['data'];
foreach($data as $d) {
if ($d['object'] == "embedding") {
Expand All @@ -97,7 +102,7 @@ public function embed_query($content): array {
}
$usedptokens = $this->provider->get_usage('prompt_tokens');
$totaltokens = $this->provider->get_usage('total_tokens');
mtrace("Total Used: Prompt tokens: $usedptokens. Total tokens: $totaltokens");
// mtrace("Total Used: Prompt tokens: $usedptokens. Total tokens: $totaltokens");
return [];
}
public function embed_documents(array $documents) {
Expand Down
7 changes: 7 additions & 0 deletions search/engine/solrrag/classes/ai/aiexception.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
<?php

namespace core\ai;

/**
 * Exception raised by the AI client/provider layer.
 *
 * Thrown, for example, when the remote AI service returns an error
 * payload in response to a chat-completion request (see aiclient.php).
 * Extends \moodle_exception so it is rendered/handled through Moodle's
 * standard exception machinery.
 */
class AiException extends \moodle_exception {

}
93 changes: 89 additions & 4 deletions search/engine/solrrag/classes/ai/aiprovider.php
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
<?php
// We're mocking a core Moodle "AI" Subsystem a la Oauth 2

namespace core\ai;
require_once($CFG->dirroot . "/search/engine/solrrag/lib.php");

use \core\persistent;
use core_course_category;

class AIProvider extends persistent {
// Ultimately this would extend a persistent.

const CONTEXT_ALL_MY_COURSES = -1;

protected static function define_properties()
{
Expand Down Expand Up @@ -37,7 +39,17 @@ protected static function define_properties()
],
'completionmodel' => [
'type' => PARAM_ALPHANUMEXT
]
],
// What context is this provider attached to.
// If null, it's a global provider.
// If -1 its limited to user's own courses.
'context' => [
'type' => PARAM_INT
],
// If true, only courses that the user is enrolled in will be searched.
'onlyenrolledcourses' => [
'type' => PARAM_BOOL
],
];
}

Expand Down Expand Up @@ -95,6 +107,61 @@ public function increment_total_tokens($change) {
set_config($key, $new, 'ai');
}

/**
 * Build the default search settings implied by this provider's
 * configuration.
 *
 * @return array Search settings; `userquery` and `vector` are placeholders
 *               that the caller populates at query time.
 */
public function get_settings() {
    return [
        // Populated at run time with the user's raw query text.
        'userquery' => null,
        // Populated at run time with the embedding vector for the query.
        'vector' => null,
        // Whether the engine should perform a vector similarity search.
        'similarity' => true,
        // Search area ids to restrict results to (empty = all areas).
        'areaids' => [],
        // Search area ids to exclude from the search. NOTE(review): this
        // exclusion should perhaps be under the AI provider's control
        // rather than hard coded here.
        'excludeareaids' => ["core_user-user"],
        // Course ids that results should be limited to (empty = no limit).
        'courseids' => [],
    ];
}

/**
 * Gets user specific settings.
 *
 * Takes the provider's default settings and narrows the `courseids`
 * restriction according to the provider's configured context and,
 * optionally, the user's enrolments. This takes on some of the function
 * that the manager code did (essentially manager::build_limitcourseids()).
 *
 * @param \stdClass $user The user the settings are being built for.
 * @return array Search settings with `courseids` populated.
 */
public function get_settings_for_user($user) {
    $usersettings = $this->get_settings();

    // Courses the user is enrolled in, keyed by course id.
    $mycourseids = enrol_get_my_courses(array('id', 'cacherev'), 'id', 0, [], false);
    $enrolledids = array_keys($mycourseids);
    $onlyenrolledcourses = $this->get('onlyenrolledcourses');
    $courseids = [];

    $providercontext = $this->get('context');
    if ($providercontext == self::CONTEXT_ALL_MY_COURSES) {
        $courseids = $enrolledids;
    } else if (!empty($providercontext)) {
        // BUGFIX: guard against a null/global context — previously a global
        // provider (context === null) would hit \context::instance_by_id(null).
        $context = \context::instance_by_id($providercontext);
        if ($context->contextlevel == CONTEXT_COURSE) {
            // Check that the specific course is also in the user's list of
            // courses. BUGFIX: intersect against the course *ids*; the
            // previous code intersected against the course records, which
            // always produced an empty result.
            $courseids = array_values(array_intersect([$context->instanceid], $enrolledids));
        } else if ($context->contextlevel == CONTEXT_COURSECAT) {
            // Course ids will be all courses in the category, optionally
            // limited to those the user is enrolled in.
            $category = core_course_category::get($context->instanceid);
            $categorycourseids = $category->get_courses([
                'recursive' => true,
                'idonly' => true
            ]);
            // BUGFIX: the category's course ids were previously computed but
            // never used; also honour the `onlyenrolledcourses` setting,
            // which was fetched but ignored.
            if ($onlyenrolledcourses) {
                $courseids = array_values(array_intersect($categorycourseids, $enrolledids));
            } else {
                $courseids = $categorycourseids;
            }
        } else if ($context->contextlevel == CONTEXT_SYSTEM) {
            // No restrictions anywhere.
        }
    }
    $usersettings['courseids'] = $courseids;

    return $usersettings;
}

//public function
// TODO token counting.
/**
Expand All @@ -119,7 +186,25 @@ public static function get_records($filters = array(), $sort = '', $order = 'ASC
'embeddingmodel' => 'text-embedding-3-small',
'completions' => 'chat/completions',
'completionmodel' => 'gpt-4-turbo-preview',
'apikey'=> $_ENV['OPENAIKEY']
'apikey'=> $_ENV['OPENAIKEY'],
'context' => \context_system::instance()->id,
//null, // Global AI Provider
'onlyenrolledcourses' => true
]);
array_push($records, $fake);
$fake = new static(0, (object) [
'id' => 2,
'name' => "Ollama AI Provider",
'allowembeddings' => true,
'allowquery' => true,
'baseurl' => 'http://127.0.0.1:11434/api/',
'embeddings' => 'embeddings',
'embeddingmodel' => '',
'completions' => 'chat',
'completionmodel' => 'llama2',
'context' => null, // Global AI Provider
'onlyenrolledcourses' => true
// 'apikey'=> $_ENV['OPENAIKEY']
]);
array_push($records, $fake);
return $records;
Expand Down
5 changes: 4 additions & 1 deletion search/engine/solrrag/classes/ai/api.php
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

namespace core\ai;

require_once($CFG->dirroot . "/search/engine/solrrag/lib.php");

class api {

/**
Expand All @@ -15,7 +17,8 @@ public static function get_all_providers($context = null) {
}
/**
 * Returns the AI provider with the given id.
 *
 * BUGFIX: the $id parameter was previously ignored and the first (OpenAI)
 * record was always returned. We now match on the record id, falling back
 * to the first record when nothing matches to preserve the old behaviour
 * for callers passing an unknown id.
 *
 * @param int $id Provider record id.
 * @return AIProvider
 */
public static function get_provider(int $id): AIProvider {
    $fakes = AIProvider::get_records();
    foreach ($fakes as $provider) {
        if ((int) $provider->get('id') === $id) {
            return $provider;
        }
    }
    return $fakes[0]; // Default: the OpenAI provider.
}
}
Loading

0 comments on commit f299b29

Please sign in to comment.