From f299b2909620f43703502768d5e53e7fcbdaf05e Mon Sep 17 00:00:00 2001 From: Michael Hughes Date: Tue, 19 Mar 2024 23:33:51 +0000 Subject: [PATCH] Vector search working for content indexed by solr --- mod/xaichat/templates/conversation.mustache | 3 + mod/xaichat/view.php | 56 +++--- search/engine/solrrag/classes/ai/aiclient.php | 27 +-- .../engine/solrrag/classes/ai/aiexception.php | 7 + .../engine/solrrag/classes/ai/aiprovider.php | 93 +++++++++- search/engine/solrrag/classes/ai/api.php | 5 +- search/engine/solrrag/classes/engine.php | 162 ++++++++++++++++-- search/engine/solrrag/lib.php | 6 + search/engine/solrrag/settings.php | 2 +- 9 files changed, 305 insertions(+), 56 deletions(-) create mode 100644 search/engine/solrrag/classes/ai/aiexception.php create mode 100644 search/engine/solrrag/lib.php diff --git a/mod/xaichat/templates/conversation.mustache b/mod/xaichat/templates/conversation.mustache index 7dacee0a4d149..cf653d4951c93 100644 --- a/mod/xaichat/templates/conversation.mustache +++ b/mod/xaichat/templates/conversation.mustache @@ -3,3 +3,6 @@ {{> mod_xaichat/message}} {{/messages}}
+{{#rawmessages}} +{{{rawmessages}}} +{{/rawmessages}} diff --git a/mod/xaichat/view.php b/mod/xaichat/view.php index b519113eb5135..cdb94b416f092 100644 --- a/mod/xaichat/view.php +++ b/mod/xaichat/view.php @@ -94,13 +94,12 @@ $progress->update(1, $totalsteps,'Looking for relevant context'); $vector = $aiclient->embed_query($data->userprompt); $search = \core_search\manager::instance(true, true); - $settings = [ - 'similarity' => true, - 'vector' => $vector, - ]; - $limit = 0; + // Some of these values can't be "trusted" to the end user to supply, via something + // like a form, nor can they be entirely left to the plugin developer. + $settings = $aiprovider->get_settings_for_user($USER); + $settings['vector'] = $vector; + $settings['userquery'] = $data->userprompt; $docs = $search->search((object)$settings); - var_dump($docs); // Perform "R" from RAG, finding documents from within the context that are similar to the user's prompt. // Add the retrieved documents to the context for this chat by generating some system messages with the content @@ -110,22 +109,22 @@ foreach ($docs as $doc) { $context[] = $doc->content; } - $_SESSION[$aicontextkey]['messages'][] = (object)[ + $prompt = (object)[ "role" => "system", - "content" => "Use the following context to answer questions:" . implode("\n",$context) + "content" => "Use the following context to answer following question:" . implode("\n",$context) ]; - } - - // Add the user's new prompt to the messages. + $_SESSION[$aicontextkey]['messages'][] = $prompt; + } $progress->update(2, $totalsteps,'Attaching user prompt'); - $_SESSION[$aicontextkey]['messages'][] = (object)[ + // $_SESSION[$aicontextkey]['messages'][] + $prompt = (object)[ "role" => "user", "content" => $data->userprompt ]; - // Pass the whole context over the AI to summarise. - var_dump($_SESSION[$aicontextkey]['messages']); + $_SESSION[$aicontextkey]['messages'][] = $prompt; + + // Pass the whole context over the AI to summarise. $progress->update(3, $totalsteps, 'Waiting for response'); - $airesults = $aiclient->chat($_SESSION[$aicontextkey]['messages']); $_SESSION[$aicontextkey]['messages'] = array_merge($_SESSION[$aicontextkey]['messages'],$airesults); $progress->update(4, $totalsteps, 'Got Response'); @@ -133,14 +132,19 @@ // We stash the data in the session temporarily (should go into an activity-user store in database) but this // is fast and dirty, and then we do a redirect so that we don't double up the request if the user hit's // refresh. - redirect(new \moodle_url('/mod/xaichat/view.php', ['id' => $cm->id])); + $next = new \moodle_url('/mod/xaichat/view.php', ['id' => $cm->id]); + redirect($next); } else if ($chatform->is_cancelled()) { $_SESSION[$aicontextkey] = [ 'messages'=>[] ]; + $prompt = (object)[ + "role" => "system", + "content" => "You are a helpful AI. You should only use information you know. Only use information that is relevant to the question." + ]; + $_SESSION[$aicontextkey]['messages'][] = $prompt; } else { // Clear session on first view of form. - $toform = [ 'id' => $id, 'aiproviderid' => $moduleinstance->aiproviderid, @@ -154,19 +158,27 @@ $displaymessages = []; foreach ($_SESSION[$aicontextkey]['messages'] as $message) { - $displaymessages[] = [ - "role" => $message->role == "user" ? $userpic : \html_writer::tag("strong", $aipic), - "content" => format_text($message->content, FORMAT_MARKDOWN) - ]; + if ($message->role != "system") { + $displaymessages[] = [ + "role" => $message->role == "user" ? $userpic : \html_writer::tag("strong", $aipic), + "content" => format_text($message->content, FORMAT_MARKDOWN) + ]; + } } +$displaymessages = array_reverse($displaymessages); $tcontext = [ "userpic" => new user_picture($USER), "messages" => $displaymessages ]; +$chatform->display(); echo $OUTPUT->render_from_template("mod_xaichat/conversation", $tcontext); -$chatform->display(); +if (false) { + echo html_writer::tag("pre", print_r($_SESSION[$aicontextkey]['messages'],1)); +} + + //echo \html_writer::tag('pre', print_r($displaymessages,1)); //echo \html_writer::tag('pre', print_r($_SESSION[$aicontextkey]['messages'],1)); diff --git a/search/engine/solrrag/classes/ai/aiclient.php b/search/engine/solrrag/classes/ai/aiclient.php index a9445e3614a5f..086c8d3a92645 100644 --- a/search/engine/solrrag/classes/ai/aiclient.php +++ b/search/engine/solrrag/classes/ai/aiclient.php @@ -1,6 +1,7 @@ libdir.'/filelib.php'); +use core\ai\AiException; /** * Base client for AI providers that uses simple http request. */ @@ -40,16 +41,20 @@ public function chat($messages) { $params = json_encode($params); $rawresult = $this->post($this->get_chat_completions_url(), $params); $jsonresult = json_decode($rawresult); - if (!isset($jsonresult->choices)) { - exit(); - return []; + if (isset($jsonresult->error)) { + throw new AiException("Error: " . $jsonresult->error->message . ":". print_r($messages, true)); + //return "Error: " . $jsonresult->error->message . ":". print_r($messages, true); } - $result = $this->convert_chat_completion($jsonresult->choices); - if (isset($jsonresult->usage)) { - $this->provider->increment_prompt_usage($jsonresult->usage->prompt_tokens); - $this->provider->increment_completion_tokens($jsonresult->usage->completion_tokens); - $this->provider->increment_total_tokens($jsonresult->usage->total_tokens); + $result = []; + if (isset($jsonresult->choices)) { + $result = $this->convert_chat_completion($jsonresult->choices); + if (isset($jsonresult->usage)) { + $this->provider->increment_prompt_usage($jsonresult->usage->prompt_tokens); + $this->provider->increment_completion_tokens($jsonresult->usage->completion_tokens); + $this->provider->increment_total_tokens($jsonresult->usage->total_tokens); + } } + return $result; } @@ -73,7 +78,7 @@ public function embed_query($content): array { // Send document to back end and return the vector $usedptokens = $this->provider->get_usage('prompt_tokens'); $totaltokens = $this->provider->get_usage('total_tokens'); - mtrace("Prompt tokens: $usedptokens. Total tokens: $totaltokens"); + // mtrace("Prompt tokens: $usedptokens. Total tokens: $totaltokens"); $params = [ "input" => htmlentities($content), // TODO need to do some length checking here! "model" => $this->provider->get('embeddingmodel') @@ -88,7 +93,7 @@ public function embed_query($content): array { $usage = $result['usage']; $this->provider->increment_prompt_usage($usage['prompt_tokens']); $this->provider->increment_total_tokens($usage['total_tokens']); - mtrace("Used Prompt tokens: {$usage['prompt_tokens']}. Total tokens: {$usage['total_tokens']}"); + // mtrace("Used Prompt tokens: {$usage['prompt_tokens']}. Total tokens: {$usage['total_tokens']}"); $data = $result['data']; foreach($data as $d) { if ($d['object'] == "embedding") { @@ -97,7 +102,7 @@ public function embed_query($content): array { } $usedptokens = $this->provider->get_usage('prompt_tokens'); $totaltokens = $this->provider->get_usage('total_tokens'); - mtrace("Total Used: Prompt tokens: $usedptokens. Total tokens: $totaltokens"); + // mtrace("Total Used: Prompt tokens: $usedptokens. Total tokens: $totaltokens"); return []; } public function embed_documents(array $documents) { diff --git a/search/engine/solrrag/classes/ai/aiexception.php b/search/engine/solrrag/classes/ai/aiexception.php new file mode 100644 index 0000000000000..ab006c952e12b --- /dev/null +++ b/search/engine/solrrag/classes/ai/aiexception.php @@ -0,0 +1,7 @@ +dirroot . "/search/engine/solrrag/lib.php"); + use \core\persistent; +use core_course_category; class AIProvider extends persistent { // Ultimately this would extend a persistent. - + const CONTEXT_ALL_MY_COURSES = -1; protected static function define_properties() { @@ -37,7 +39,17 @@ protected static function define_properties() ], 'completionmodel' => [ 'type' => PARAM_ALPHANUMEXT - ] + ], + // What context is this provider attached to. + // If null, it's a global provider. + // If -1 its limited to user's own courses. + 'context' => [ + 'type' => PARAM_INT + ], + // If true, only courses that the user is enrolled in will be searched. + 'onlyenrolledcourses' => [ + 'type' => PARAM_BOOL + ], ]; } @@ -95,6 +107,61 @@ public function increment_total_tokens($change) { set_config($key, $new, 'ai'); } + /** + * Returns appropriate search settings based on + * provider configuration. + */ + public function get_settings() { + // `userquery` and `vector` will be filled at run time. + $settings = [ + 'userquery'=> null, + 'vector' => null, + // `similarity` is a boolean that determines if the query should use vector similarity search. + 'similarity' => true, + 'areaids' => [], + // `excludeareaids` is an array of areaids that should be excluded from the search. + 'excludeareaids'=> ["core_user-user"], // <-- This may be should be in control of the AI Provider. + 'courseids' => [], // This of course IDs that result should be limited to. + ]; + return $settings; + } + + /** + * Gets user specific settings. + * + * This takes on some of the function that the manager code did. + */ + public function get_settings_for_user($user) { + $usersettings = $this->get_settings(); + + // This is basically manager::build_limitcourseids(). + $mycourseids = enrol_get_my_courses(array('id', 'cacherev'), 'id', 0, [], false); + $onlyenrolledcourses = $this->get('onlyenrolledcourses'); + $courseids = []; + if ($this->get('context') == self::CONTEXT_ALL_MY_COURSES) { + $courseids = array_keys($mycourseids); + } else { + $context = \context::instance_by_id($this->get('context')); + if ($context->contextlevel == CONTEXT_COURSE) { + // Check that the specific course is also in the user's list of courses. + $courseids = array_intersect([$context->instanceid], $mycourseids); + } else if ($context->contextlevel == CONTEXT_COURSECAT) { + // CourseIDs will be all courses in the category, + // optionally that the user is enrolled in + $category = core_course_category::get($context->instanceid); + $categorycourseids = $category->get_courses([ + 'recursive'=>true, + 'idonly' => true + ]); + } else if ($context->contextlevel == CONTEXT_SYSTEM) { + // No restrictions anywhere. + } + } + $usersettings['courseids'] = $courseids; + + return $usersettings; + } + //public function // TODO token counting. /** @@ -119,7 +186,25 @@ public static function get_records($filters = array(), $sort = '', $order = 'ASC 'embeddingmodel' => 'text-embedding-3-small', 'completions' => 'chat/completions', 'completionmodel' => 'gpt-4-turbo-preview', - 'apikey'=> $_ENV['OPENAIKEY'] + 'apikey'=> $_ENV['OPENAIKEY'], + 'context' => \context_system::instance()->id, + //null, // Global AI Provider + 'onlyenrolledcourses' => true + ]); + array_push($records, $fake); + $fake = new static(0, (object) [ + 'id' => 2, + 'name' => "Ollama AI Provider", + 'allowembeddings' => true, + 'allowquery' => true, + 'baseurl' => 'http://127.0.0.1:11434/api/', + 'embeddings' => 'embeddings', + 'embeddingmodel' => '', + 'completions' => 'chat', + 'completionmodel' => 'llama2', + 'context' => null, // Global AI Provider + 'onlyenrolledcourses' => true + // 'apikey'=> $_ENV['OPENAIKEY'] ]); array_push($records, $fake); return $records; diff --git a/search/engine/solrrag/classes/ai/api.php b/search/engine/solrrag/classes/ai/api.php index ad5939815746e..be8c4d627b68c 100644 --- a/search/engine/solrrag/classes/ai/api.php +++ b/search/engine/solrrag/classes/ai/api.php @@ -3,6 +3,8 @@ namespace core\ai; +require_once($CFG->dirroot . "/search/engine/solrrag/lib.php"); + class api { /** @@ -15,7 +17,8 @@ public static function get_all_providers($context = null) { } public static function get_provider(int $id): AIProvider { $fakes = AIProvider::get_records(); - return $fakes[0]; + return $fakes[0]; // Open AI + // return $fakes[1]; // Ollama } } diff --git a/search/engine/solrrag/classes/engine.php b/search/engine/solrrag/classes/engine.php index e7ece7e996c40..10a5cbcbf2405 100644 --- a/search/engine/solrrag/classes/engine.php +++ b/search/engine/solrrag/classes/engine.php @@ -4,12 +4,14 @@ use search_solrrag\document; use search_solrrag\schema; -// Fudge autoloading! -require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/api.php"); -require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/aiprovider.php"); -require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/aiclient.php"); +require_once($CFG->dirroot . "/search/engine/solrrag/lib.php"); +// // Fudge autoloading! +// require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/api.php"); +// require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/aiprovider.php"); +// require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/aiclient.php"); use \core\ai\AIProvider; use \core\ai\AIClient; +use \core\ai\AiException; class engine extends \search_solr\engine { /** @@ -79,7 +81,7 @@ public function add_document($document, $fileindexing = false) { $vlength = count($vector); $vectorfield = "solr_vector_" . $vlength; $docdata[$vectorfield] = $vector; - var_dump($docdata); + // var_dump($docdata); } else { debugging("Err didn't do any vector stuff!"); } @@ -107,7 +109,7 @@ public function add_document_batch(array $documents, bool $fileindexing = false) $vlength = count($vector); $vectorfield = "solr_vector_" . $vlength; $doc[$vectorfield] = $vector; - var_dump($doc); + // var_dump($doc); } else { debugging("Err didn't do any vector stuff!"); } @@ -338,37 +340,163 @@ public function execute_query($filters, $accessinfo, $limit = 0) { $filters->similarity ) { // Do a vector similarity search. - debugging("Running similarity search", DEBUG_DEVELOPER); - return $this->execute_solr_knn_query($filters, $accessinfo, $limit); + // debugging("Running similarity search", DEBUG_DEVELOPER); + // We may get accessinfo, but we actually should determine our own ones to apply too + // But we can't access the "manager" class' get_areas_user_accesses function, and + // that's already been called based on the configuration / data from the user + return $this->execute_similarity_query($filters, $accessinfo, $limit); } else { - debugging("Running regular search", DEBUG_DEVELOPER); - print_r($filters); - print_r($accessinfo); + // debugging("Running regular search", DEBUG_DEVELOPER); + // print_r($filters); + // print_r($accessinfo); return parent::execute_query($filters, $accessinfo, $limit); } } + /** + * A logging function just to allow us to output all the things + * that the process is doing for verification / validation. + * + * Probably not the most efficient way to do this, but Moodle's lacking + * a good generic logging framework. + * + * @param mixed $message An object/array/string that will be turned into a string. + */ + protected function log($message) { + $logfiledir = make_temp_directory('search_solrrag'); + $file = $logfiledir . '/solr_knn_query.log'; + $log = fopen($file, 'a'); + if (is_object($message)) { + $message = print_r($message, true); + } else if (is_array($message)) { + $message = print_r($message, true); + } + fwrite($log, date('Y-m-d H:i:s') . " " . $message . "\n"); + fclose($log); + } - public function execute_solr_knn_query($filters, $accessinfo, $limit) { + /** + * Perform a similarity search against the backend. + * + * This should be an optional method that can be implemented if the engine supports + * a vector search capability. + * + * This function will broadly replicate the same functionality as execute_query, but optimised + * for similarity + * + * @param \stdClass filters The filters object that contains the query and any other parameters. Basically from the search form. + * @param \stdClass accessinfo The access information for the user. + * @param int limit The maximum number of results to return. + */ + public function execute_similarity_query(\stdClass $filters, \stdClass $accessinfo, int $limit = null) { + $data = clone($filters); + $this->log("Executing SOLR KNN QUery"); + $this->log("Filters"); + $this->log($filters); $vector = $filters->vector; - $topK = 3; // Nearest neighbours to retrieve. - $field = "solr_vector_" . count($vector); - $requestbody = "{!knn f={$field} topK={$topK}}[" . implode(",", $vector) . "]"; + $topK = $limit > 0 ? $limit: 1; // We'll make the number of neighbours the same as search result limit. - $filters->mainquery = $requestbody; if (empty($limit)) { $limit = \core_search\manager::MAX_RESULTS; + $topK = \core_search\manager::MAX_RESULTS; // Nearest neighbours to retrieve. } + $field = "solr_vector_" . count($vector); + $requestbody = "{!knn f={$field} topK={$topK}}[" . implode(",", $vector) . "]"; + $this->log($requestbody); + $filters->mainquery = $requestbody; + // Build filter restrictions. + $filterqueries = []; + if(!empty($data->areaids)) { + $filterqueries[] = '{!cache=false}areaid:(' . implode(' OR ', $data->areaids) . ')'; + } + + if(!empty($data->excludeareaids)) { + $filterqueries[] = '{!cache=false}-areaid:(' . implode(' OR ', $data->excludeareaids) . ')'; + } + // Build access restrictions. + + // And finally restrict it to the context where the user can access, we want this one cached. + // If the user can access all contexts $usercontexts value is just true, we don't need to filter + // in that case. + if (!$accessinfo->everything && is_array($accessinfo->usercontexts)) { + // Join all area contexts into a single array and implode. + $allcontexts = array(); + foreach ($accessinfo->usercontexts as $areaid => $areacontexts) { + if (!empty($data->areaids) && !in_array($areaid, $data->areaids)) { + // Skip unused areas. + continue; + } + foreach ($areacontexts as $contextid) { + // Ensure they are unique. + $allcontexts[$contextid] = $contextid; + } + } + if (empty($allcontexts)) { + // This means there are no valid contexts for them, so they get no results. + return null; + } + $filterqueries[] = 'contextid:(' . implode(' OR ', $allcontexts) . ')'; + } + + if (!$accessinfo->everything && $accessinfo->separategroupscontexts) { + // Add another restriction to handle group ids. If there are any contexts using separate + // groups, then results in that context will not show unless you belong to the group. + // (Note: Access all groups is taken care of earlier, when computing these arrays.) + + // This special exceptions list allows for particularly pig-headed developers to create + // multiple search areas within the same module, where one of them uses separate + // groups and the other uses visible groups. It is a little inefficient, but this should + // be rare. + $exceptions = ''; + if ($accessinfo->visiblegroupscontextsareas) { + foreach ($accessinfo->visiblegroupscontextsareas as $contextid => $areaids) { + $exceptions .= ' OR (contextid:' . $contextid . ' AND areaid:(' . + implode(' OR ', $areaids) . '))'; + } + } + + if ($accessinfo->usergroups) { + // Either the document has no groupid, or the groupid is one that the user + // belongs to, or the context is not one of the separate groups contexts. + $filterqueries[] = '(*:* -groupid:[* TO *]) OR ' . + 'groupid:(' . implode(' OR ', $accessinfo->usergroups) . ') OR ' . + '(*:* -contextid:(' . implode(' OR ', $accessinfo->separategroupscontexts) . '))' . + $exceptions; + } else { + // Either the document has no groupid, or the context is not a restricted one. + $filterqueries[] = '(*:* -groupid:[* TO *]) OR ' . + '(*:* -contextid:(' . implode(' OR ', $accessinfo->separategroupscontexts) . '))' . + $exceptions; + } + } + + if ($this->file_indexing_enabled()) { + // Now group records by solr_filegroupingid. Limit to 3 results per group. + // TODO work out how to convert the following into query / filter parameters. + // $query->setGroup(true); + // $query->setGroupLimit(3); + // $query->setGroupNGroups(true); + // $query->addGroupField('solr_filegroupingid'); + } else { + // Make sure we only get text files, in case the index has pre-existing files. + $filterqueries[] = 'type:'.\core_search\manager::TYPE_TEXT; + } + + // Finally perform the actaul search + $curl = $this->get_curl_object(); $requesturl = $this->get_connection_url('/select'); $requesturl->param('fl', 'id,areaid,score,content'); $requesturl->param('wt', 'xml'); - // $requesturl->param('query', $requestbody) + $requesturl->param('fq', implode("&", $filterqueries)); + $params = [ "query" => $requestbody, ]; + $curl->setHeader('Content-type: application/json'); $result = $curl->post($requesturl->out(false), json_encode($params)); + $this->log($result); // Probably have to duplicate error handling code from the add_stored_file() function. $code = $curl->get_errno(); diff --git a/search/engine/solrrag/lib.php b/search/engine/solrrag/lib.php new file mode 100644 index 0000000000000..46f51c2d8e667 --- /dev/null +++ b/search/engine/solrrag/lib.php @@ -0,0 +1,6 @@ +dirroot ."/search/engine/solrrag/classes/ai/api.php"); +require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/aiprovider.php"); +require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/aiclient.php"); +require_once($CFG->dirroot ."/search/engine/solrrag/classes/ai/aiexception.php"); \ No newline at end of file diff --git a/search/engine/solrrag/settings.php b/search/engine/solrrag/settings.php index 85b427e5c75be..b4a837c5e0e44 100644 --- a/search/engine/solrrag/settings.php +++ b/search/engine/solrrag/settings.php @@ -23,7 +23,7 @@ */ defined('MOODLE_INTERNAL') || die(); - +require_once($CFG->dirroot . "/search/engine/solrrag/lib.php"); if ($ADMIN->fulltree) { if (!during_initial_install()) {