From 29b657cfb439eb1610a4b24704a202a7cbb73df2 Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Tue, 26 Nov 2024 11:41:50 -0500 Subject: [PATCH] WIP Code interpreter tool call. Signed-off-by: Adam Treat --- gpt4all-chat/CMakeLists.txt | 8 +- gpt4all-chat/qml/ChatCollapsibleItem.qml | 160 +++++++ gpt4all-chat/qml/ChatItemView.qml | 169 ++------ gpt4all-chat/qml/ChatTextItem.qml | 139 ++++++ gpt4all-chat/qml/ChatView.qml | 2 + gpt4all-chat/src/chat.cpp | 71 +++- gpt4all-chat/src/chat.h | 4 +- gpt4all-chat/src/chatllm.cpp | 102 +++-- gpt4all-chat/src/chatllm.h | 12 +- gpt4all-chat/src/chatmodel.cpp | 17 + gpt4all-chat/src/chatmodel.h | 516 ++++++++++++++++++----- gpt4all-chat/src/codeinterpreter.cpp | 125 ++++++ gpt4all-chat/src/codeinterpreter.h | 83 ++++ gpt4all-chat/src/jinja_helpers.cpp | 24 +- gpt4all-chat/src/jinja_helpers.h | 6 +- gpt4all-chat/src/main.cpp | 3 + gpt4all-chat/src/server.cpp | 16 +- gpt4all-chat/src/tool.cpp | 41 ++ gpt4all-chat/src/tool.h | 122 ++++++ gpt4all-chat/src/toolcallparser.cpp | 107 +++++ gpt4all-chat/src/toolcallparser.h | 45 ++ gpt4all-chat/src/toolmodel.cpp | 31 ++ gpt4all-chat/src/toolmodel.h | 104 +++++ 23 files changed, 1603 insertions(+), 304 deletions(-) create mode 100644 gpt4all-chat/qml/ChatCollapsibleItem.qml create mode 100644 gpt4all-chat/qml/ChatTextItem.qml create mode 100644 gpt4all-chat/src/chatmodel.cpp create mode 100644 gpt4all-chat/src/codeinterpreter.cpp create mode 100644 gpt4all-chat/src/codeinterpreter.h create mode 100644 gpt4all-chat/src/tool.cpp create mode 100644 gpt4all-chat/src/tool.h create mode 100644 gpt4all-chat/src/toolcallparser.cpp create mode 100644 gpt4all-chat/src/toolcallparser.h create mode 100644 gpt4all-chat/src/toolmodel.cpp create mode 100644 gpt4all-chat/src/toolmodel.h diff --git a/gpt4all-chat/CMakeLists.txt b/gpt4all-chat/CMakeLists.txt index 257338b9510c..f5c678dd6a19 100644 --- a/gpt4all-chat/CMakeLists.txt +++ b/gpt4all-chat/CMakeLists.txt @@ -185,8 +185,9 @@ qt_add_executable(chat 
src/chatapi.cpp src/chatapi.h src/chatlistmodel.cpp src/chatlistmodel.h src/chatllm.cpp src/chatllm.h - src/chatmodel.h + src/chatmodel.cpp src/chatmodel.h src/chatviewtextprocessor.cpp src/chatviewtextprocessor.h + src/codeinterpreter.cpp src/codeinterpreter.h src/database.cpp src/database.h src/download.cpp src/download.h src/embllm.cpp src/embllm.h @@ -199,6 +200,9 @@ qt_add_executable(chat src/mysettings.cpp src/mysettings.h src/network.cpp src/network.h src/server.cpp src/server.h + src/tool.cpp src/tool.h + src/toolcallparser.cpp src/toolcallparser.h + src/toolmodel.cpp src/toolmodel.h src/xlsxtomd.cpp src/xlsxtomd.h ${CHAT_EXE_RESOURCES} ${MACOS_SOURCES} @@ -215,8 +219,10 @@ qt_add_qml_module(chat qml/AddModelView.qml qml/ApplicationSettings.qml qml/ChatDrawer.qml + qml/ChatCollapsibleItem.qml qml/ChatItemView.qml qml/ChatMessageButton.qml + qml/ChatTextItem.qml qml/ChatView.qml qml/CollectionsDrawer.qml qml/HomeView.qml diff --git a/gpt4all-chat/qml/ChatCollapsibleItem.qml b/gpt4all-chat/qml/ChatCollapsibleItem.qml new file mode 100644 index 000000000000..4ff01511bf9b --- /dev/null +++ b/gpt4all-chat/qml/ChatCollapsibleItem.qml @@ -0,0 +1,160 @@ +import Qt5Compat.GraphicalEffects +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts + +import gpt4all +import mysettings +import toolenums + +ColumnLayout { + property alias textContent: innerTextItem.textContent + property bool isCurrent: false + property bool isError: false + + Layout.topMargin: 10 + Layout.bottomMargin: 10 + + Item { + Layout.preferredWidth: childrenRect.width + Layout.preferredHeight: 38 + RowLayout { + anchors.left: parent.left + anchors.top: parent.top + anchors.bottom: parent.bottom + + Item { + width: myTextArea.width + height: myTextArea.height + TextArea { + id: myTextArea + text: { + if (isError) + return qsTr("Analysis encountered error"); + if (isCurrent) + return qsTr("Analyzing"); + return qsTr("Analyzed"); + } + padding: 0 
+ font.pixelSize: theme.fontSizeLarger + enabled: false + focus: false + readOnly: true + color: headerMA.containsMouse ? theme.mutedDarkTextColorHovered : theme.mutedTextColor + hoverEnabled: false + } + + Item { + id: textColorOverlay + anchors.fill: parent + clip: true + visible: false + Rectangle { + id: animationRec + width: myTextArea.width * 0.3 + anchors.top: parent.top + anchors.bottom: parent.bottom + color: theme.textColor + + SequentialAnimation { + running: isCurrent + loops: Animation.Infinite + NumberAnimation { + target: animationRec; + property: "x"; + from: -animationRec.width; + to: myTextArea.width * 3; + duration: 2000 + } + } + } + } + OpacityMask { + visible: isCurrent + anchors.fill: parent + maskSource: myTextArea + source: textColorOverlay + } + } + + Item { + id: caret + Layout.preferredWidth: contentCaret.width + Layout.preferredHeight: contentCaret.height + Image { + id: contentCaret + anchors.centerIn: parent + visible: false + sourceSize.width: theme.fontSizeLarge + sourceSize.height: theme.fontSizeLarge + mipmap: true + source: { + if (contentLayout.state === "collapsed") + return "qrc:/gpt4all/icons/caret_right.svg"; + else + return "qrc:/gpt4all/icons/caret_down.svg"; + } + } + + ColorOverlay { + anchors.fill: contentCaret + source: contentCaret + color: headerMA.containsMouse ? 
theme.mutedDarkTextColorHovered : theme.mutedTextColor + } + } + } + + MouseArea { + id: headerMA + hoverEnabled: true + anchors.fill: parent + onClicked: { + if (contentLayout.state === "collapsed") + contentLayout.state = "expanded"; + else + contentLayout.state = "collapsed"; + } + } + } + + ColumnLayout { + id: contentLayout + spacing: 0 + state: "collapsed" + clip: true + + states: [ + State { + name: "expanded" + PropertyChanges { target: contentLayout; Layout.preferredHeight: innerContentLayout.height } + }, + State { + name: "collapsed" + PropertyChanges { target: contentLayout; Layout.preferredHeight: 0 } + } + ] + + transitions: [ + Transition { + SequentialAnimation { + PropertyAnimation { + target: contentLayout + property: "Layout.preferredHeight" + duration: 300 + easing.type: Easing.InOutQuad + } + } + } + ] + + ColumnLayout { + id: innerContentLayout + Layout.leftMargin: 30 + ChatTextItem { + id: innerTextItem + } + } + } +} \ No newline at end of file diff --git a/gpt4all-chat/qml/ChatItemView.qml b/gpt4all-chat/qml/ChatItemView.qml index ed7476149ddc..e9c5f7fee83c 100644 --- a/gpt4all-chat/qml/ChatItemView.qml +++ b/gpt4all-chat/qml/ChatItemView.qml @@ -4,9 +4,11 @@ import QtQuick import QtQuick.Controls import QtQuick.Controls.Basic import QtQuick.Layouts +import Qt.labs.qmlmodels import gpt4all import mysettings +import toolenums ColumnLayout { @@ -33,6 +35,9 @@ GridLayout { Layout.alignment: Qt.AlignVCenter | Qt.AlignRight Layout.preferredWidth: 32 Layout.preferredHeight: 32 + Layout.topMargin: model.index > 0 ? 25 : 0 + visible: content !== "" || childItems.length > 0 + Image { id: logo sourceSize: Qt.size(32, 32) @@ -65,6 +70,9 @@ GridLayout { Layout.column: 1 Layout.fillWidth: true Layout.preferredHeight: 38 + Layout.topMargin: model.index > 0 ? 
25 : 0 + visible: content !== "" || childItems.length > 0 + RowLayout { spacing: 5 anchors.left: parent.left @@ -72,7 +80,11 @@ GridLayout { anchors.bottom: parent.bottom TextArea { - text: name === "Response: " ? qsTr("GPT4All") : qsTr("You") + text: { + if (name === "Response: ") + return qsTr("GPT4All"); + return qsTr("You"); + } padding: 0 font.pixelSize: theme.fontSizeLarger font.bold: true @@ -88,7 +100,7 @@ GridLayout { color: theme.mutedTextColor } RowLayout { - visible: isCurrentResponse && (value === "" && currentChat.responseInProgress) + visible: isCurrentResponse && (content === "" && currentChat.responseInProgress) Text { color: theme.mutedTextColor font.pixelSize: theme.fontSizeLarger @@ -156,131 +168,36 @@ GridLayout { } } - TextArea { - id: myTextArea - Layout.fillWidth: true - padding: 0 - color: { - if (!currentChat.isServer) - return theme.textColor - return theme.white - } - wrapMode: Text.WordWrap - textFormat: TextEdit.PlainText - focus: false - readOnly: true - font.pixelSize: theme.fontSizeLarge - cursorVisible: isCurrentResponse ? 
currentChat.responseInProgress : false - cursorPosition: text.length - TapHandler { - id: tapHandler - onTapped: function(eventPoint, button) { - var clickedPos = myTextArea.positionAt(eventPoint.position.x, eventPoint.position.y); - var success = textProcessor.tryCopyAtPosition(clickedPos); - if (success) - copyCodeMessage.open(); - } - } - - MouseArea { - id: conversationMouseArea - anchors.fill: parent - acceptedButtons: Qt.RightButton - - onClicked: (mouse) => { - if (mouse.button === Qt.RightButton) { - conversationContextMenu.x = conversationMouseArea.mouseX - conversationContextMenu.y = conversationMouseArea.mouseY - conversationContextMenu.open() - } - } - } - - onLinkActivated: function(link) { - if (!isCurrentResponse || !currentChat.responseInProgress) - Qt.openUrlExternally(link) - } - - onLinkHovered: function (link) { - if (!isCurrentResponse || !currentChat.responseInProgress) - statusBar.externalHoveredLink = link - } - - MyMenu { - id: conversationContextMenu - MyMenuItem { - text: qsTr("Copy") - enabled: myTextArea.selectedText !== "" - height: enabled ? implicitHeight : 0 - onTriggered: myTextArea.copy() - } - MyMenuItem { - text: qsTr("Copy Message") - enabled: myTextArea.selectedText === "" - height: enabled ? implicitHeight : 0 - onTriggered: { - myTextArea.selectAll() - myTextArea.copy() - myTextArea.deselect() + Repeater { + model: childItems + + DelegateChooser { + id: chooser + role: "name" + DelegateChoice { + roleValue: "Text: "; + ChatTextItem { + Layout.fillWidth: true + textContent: modelData.content } } - MyMenuItem { - text: textProcessor.shouldProcessText ? qsTr("Disable markdown") : qsTr("Enable markdown") - height: enabled ? 
implicitHeight : 0 - onTriggered: { - textProcessor.shouldProcessText = !textProcessor.shouldProcessText; - textProcessor.setValue(value); + DelegateChoice { + roleValue: "ToolCall: "; + ChatCollapsibleItem { + Layout.fillWidth: true + textContent: modelData.content + isCurrent: modelData.isCurrentResponse + isError: modelData.isToolCallError } } } - ChatViewTextProcessor { - id: textProcessor - } - - function resetChatViewTextProcessor() { - textProcessor.fontPixelSize = myTextArea.font.pixelSize - textProcessor.codeColors.defaultColor = theme.codeDefaultColor - textProcessor.codeColors.keywordColor = theme.codeKeywordColor - textProcessor.codeColors.functionColor = theme.codeFunctionColor - textProcessor.codeColors.functionCallColor = theme.codeFunctionCallColor - textProcessor.codeColors.commentColor = theme.codeCommentColor - textProcessor.codeColors.stringColor = theme.codeStringColor - textProcessor.codeColors.numberColor = theme.codeNumberColor - textProcessor.codeColors.headerColor = theme.codeHeaderColor - textProcessor.codeColors.backgroundColor = theme.codeBackgroundColor - textProcessor.textDocument = textDocument - textProcessor.setValue(value); - } - - property bool textProcessorReady: false - - Component.onCompleted: { - resetChatViewTextProcessor(); - textProcessorReady = true; - } - - Connections { - target: chatModel - function onValueChanged(i, value) { - if (myTextArea.textProcessorReady && index === i) - textProcessor.setValue(value); - } - } - - Connections { - target: MySettings - function onFontSizeChanged() { - myTextArea.resetChatViewTextProcessor(); - } - function onChatThemeChanged() { - myTextArea.resetChatViewTextProcessor(); - } - } + delegate: chooser + } - Accessible.role: Accessible.Paragraph - Accessible.name: text - Accessible.description: name === "Response: " ? 
"The response by the model" : "The prompt by the user" + ChatTextItem { + Layout.fillWidth: true + textContent: content } ThumbsDownDialog { @@ -289,16 +206,16 @@ GridLayout { y: Math.round((parent.height - height) / 2) width: 640 height: 300 - property string text: value + property string text: content response: newResponse === undefined || newResponse === "" ? text : newResponse onAccepted: { var responseHasChanged = response !== text && response !== newResponse if (thumbsDownState && !thumbsUpState && !responseHasChanged) return - chatModel.updateNewResponse(index, response) - chatModel.updateThumbsUpState(index, false) - chatModel.updateThumbsDownState(index, true) + chatModel.updateNewResponse(model.index, response) + chatModel.updateThumbsUpState(model.index, false) + chatModel.updateThumbsDownState(model.index, true) Network.sendConversation(currentChat.id, getConversationJson()); } } @@ -416,7 +333,7 @@ GridLayout { states: [ State { name: "expanded" - PropertyChanges { target: sourcesLayout; Layout.preferredHeight: flow.height } + PropertyChanges { target: sourcesLayout; Layout.preferredHeight: sourcesFlow.height } }, State { name: "collapsed" @@ -438,7 +355,7 @@ GridLayout { ] Flow { - id: flow + id: sourcesFlow Layout.fillWidth: true spacing: 10 visible: consolidatedSources.length !== 0 diff --git a/gpt4all-chat/qml/ChatTextItem.qml b/gpt4all-chat/qml/ChatTextItem.qml new file mode 100644 index 000000000000..e316bf1ce7bc --- /dev/null +++ b/gpt4all-chat/qml/ChatTextItem.qml @@ -0,0 +1,139 @@ +import Qt5Compat.GraphicalEffects +import QtCore +import QtQuick +import QtQuick.Controls +import QtQuick.Controls.Basic +import QtQuick.Layouts + +import gpt4all +import mysettings +import toolenums + +TextArea { + id: myTextArea + property string textContent: "" + visible: textContent != "" + Layout.fillWidth: true + padding: 0 + color: { + if (!currentChat.isServer) + return theme.textColor + return theme.white + } + wrapMode: Text.WordWrap + textFormat: 
TextEdit.PlainText + focus: false + readOnly: true + font.pixelSize: theme.fontSizeLarge + cursorVisible: isCurrentResponse ? currentChat.responseInProgress : false + cursorPosition: text.length + TapHandler { + id: tapHandler + onTapped: function(eventPoint, button) { + var clickedPos = myTextArea.positionAt(eventPoint.position.x, eventPoint.position.y); + var success = textProcessor.tryCopyAtPosition(clickedPos); + if (success) + copyCodeMessage.open(); + } + } + + MouseArea { + id: conversationMouseArea + anchors.fill: parent + acceptedButtons: Qt.RightButton + + onClicked: (mouse) => { + if (mouse.button === Qt.RightButton) { + conversationContextMenu.x = conversationMouseArea.mouseX + conversationContextMenu.y = conversationMouseArea.mouseY + conversationContextMenu.open() + } + } + } + + onLinkActivated: function(link) { + if (!isCurrentResponse || !currentChat.responseInProgress) + Qt.openUrlExternally(link) + } + + onLinkHovered: function (link) { + if (!isCurrentResponse || !currentChat.responseInProgress) + statusBar.externalHoveredLink = link + } + + MyMenu { + id: conversationContextMenu + MyMenuItem { + text: qsTr("Copy") + enabled: myTextArea.selectedText !== "" + height: enabled ? implicitHeight : 0 + onTriggered: myTextArea.copy() + } + MyMenuItem { + text: qsTr("Copy Message") + enabled: myTextArea.selectedText === "" + height: enabled ? implicitHeight : 0 + onTriggered: { + myTextArea.selectAll() + myTextArea.copy() + myTextArea.deselect() + } + } + MyMenuItem { + text: textProcessor.shouldProcessText ? qsTr("Disable markdown") : qsTr("Enable markdown") + height: enabled ? 
implicitHeight : 0 + onTriggered: { + textProcessor.shouldProcessText = !textProcessor.shouldProcessText; + textProcessor.setValue(textContent); + } + } + } + + ChatViewTextProcessor { + id: textProcessor + } + + function resetChatViewTextProcessor() { + textProcessor.fontPixelSize = myTextArea.font.pixelSize + textProcessor.codeColors.defaultColor = theme.codeDefaultColor + textProcessor.codeColors.keywordColor = theme.codeKeywordColor + textProcessor.codeColors.functionColor = theme.codeFunctionColor + textProcessor.codeColors.functionCallColor = theme.codeFunctionCallColor + textProcessor.codeColors.commentColor = theme.codeCommentColor + textProcessor.codeColors.stringColor = theme.codeStringColor + textProcessor.codeColors.numberColor = theme.codeNumberColor + textProcessor.codeColors.headerColor = theme.codeHeaderColor + textProcessor.codeColors.backgroundColor = theme.codeBackgroundColor + textProcessor.textDocument = textDocument + textProcessor.setValue(textContent); + } + + property bool textProcessorReady: false + + Component.onCompleted: { + resetChatViewTextProcessor(); + textProcessorReady = true; + } + + Connections { + target: myTextArea + function onTextContentChanged() { + if (myTextArea.textProcessorReady) + textProcessor.setValue(textContent); + } + } + + Connections { + target: MySettings + function onFontSizeChanged() { + myTextArea.resetChatViewTextProcessor(); + } + function onChatThemeChanged() { + myTextArea.resetChatViewTextProcessor(); + } + } + + Accessible.role: Accessible.Paragraph + Accessible.name: text + Accessible.description: name === "Response: " ? 
"The response by the model" : "The prompt by the user" +} \ No newline at end of file diff --git a/gpt4all-chat/qml/ChatView.qml b/gpt4all-chat/qml/ChatView.qml index b8a5b27b2d6c..f28eef39c4eb 100644 --- a/gpt4all-chat/qml/ChatView.qml +++ b/gpt4all-chat/qml/ChatView.qml @@ -823,6 +823,8 @@ Rectangle { textInput.forceActiveFocus(); textInput.cursorPosition = text.length; } + height: visible ? implicitHeight : 0 + visible: name !== "ToolResponse: " } function scrollToEnd() { diff --git a/gpt4all-chat/src/chat.cpp b/gpt4all-chat/src/chat.cpp index c40bb96ed34b..51ecb9e966b6 100644 --- a/gpt4all-chat/src/chat.cpp +++ b/gpt4all-chat/src/chat.cpp @@ -3,10 +3,16 @@ #include "chatlistmodel.h" #include "network.h" #include "server.h" +#include "tool.h" +#include "toolcallparser.h" +#include "toolmodel.h" #include #include #include +#include +#include +#include #include #include #include @@ -16,6 +22,8 @@ #include +using namespace ToolEnums; + Chat::Chat(QObject *parent) : QObject(parent) , m_id(Network::globalInstance()->generateUniqueId()) @@ -54,7 +62,6 @@ void Chat::connectLLM() // Should be in different threads connect(m_llmodel, &ChatLLM::modelLoadingPercentageChanged, this, &Chat::handleModelLoadingPercentageChanged, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection); - connect(m_llmodel, &ChatLLM::responseFailed, this, &Chat::handleResponseFailed, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection); connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection); @@ -182,23 +189,12 @@ Chat::ResponseState Chat::responseState() const return m_responseState; } -void Chat::handleResponseChanged(const QString &response) +void Chat::handleResponseChanged() { if (m_responseState != 
Chat::ResponseGeneration) { m_responseState = Chat::ResponseGeneration; emit responseStateChanged(); } - - const int index = m_chatModel->count() - 1; - m_chatModel->updateValue(index, response); -} - -void Chat::handleResponseFailed(const QString &error) -{ - const int index = m_chatModel->count() - 1; - m_chatModel->updateValue(index, error); - m_chatModel->setError(); - responseStopped(0); } void Chat::handleModelLoadingPercentageChanged(float loadingPercentage) @@ -243,9 +239,58 @@ void Chat::responseStopped(qint64 promptResponseMs) m_responseState = Chat::ResponseStopped; emit responseInProgressChanged(); emit responseStateChanged(); + + const int index = m_chatModel->count() - 1; + ChatItem *item = m_chatModel->get(index); + + // FIXME + const QString possibleToolcall = item->toolCallValue(); + + ToolCallParser parser; + parser.update(possibleToolcall); + + if (item->type() == ChatItem::Type::Response && parser.state() == ToolEnums::ParseState::Complete) { + const QString toolCall = parser.toolCall(); + + // Regex to remove the formatting around the code + static const QRegularExpression regex("^\\s*```javascript\\s*|\\s*```\\s*$"); + QString code = toolCall; + code.remove(regex); + code = code.trimmed(); + + // Right now the code interpreter is the only available tool + Tool *toolInstance = ToolModel::globalInstance()->get(ToolCallConstants::CodeInterpreterFunction); + Q_ASSERT(toolInstance); + + // The param is the code + const ToolParam param = { "code", ToolEnums::ParamType::String, code }; + const QString result = toolInstance->run({param}, 10000 /*msecs to timeout*/); + const ToolEnums::Error error = toolInstance->error(); + const QString errorString = toolInstance->errorString(); + + // Update the current response with meta information about toolcall and re-parent + m_chatModel->updateToolCall({ + ToolCallConstants::CodeInterpreterFunction, + { param }, + result, + error, + errorString + }); + + ++m_consecutiveToolCalls; + + // We limit the number of 
consecutive toolcalls otherwise we get into a potentially endless loop + if (m_consecutiveToolCalls < 3 || error == ToolEnums::Error::NoError) { + resetResponseState(); + emit promptRequested(m_collections); // triggers a new response + return; + } + } + if (m_generatedName.isEmpty()) emit generateNameRequested(); + m_consecutiveToolCalls = 0; Network::globalInstance()->trackChatEvent("response_complete", { {"first", m_firstResponse}, {"message_count", chatModel()->count()}, diff --git a/gpt4all-chat/src/chat.h b/gpt4all-chat/src/chat.h index 57e413e5873d..dc8f3e180b35 100644 --- a/gpt4all-chat/src/chat.h +++ b/gpt4all-chat/src/chat.h @@ -161,8 +161,7 @@ public Q_SLOTS: void generatedQuestionsChanged(); private Q_SLOTS: - void handleResponseChanged(const QString &response); - void handleResponseFailed(const QString &error); + void handleResponseChanged(); void handleModelLoadingPercentageChanged(float); void promptProcessing(); void generatingQuestions(); @@ -205,6 +204,7 @@ private Q_SLOTS: // - The chat was freshly created during this launch. // - The chat was changed after loading it from disk. 
bool m_needsSave = true; + int m_consecutiveToolCalls = 0; }; #endif // CHAT_H diff --git a/gpt4all-chat/src/chatllm.cpp b/gpt4all-chat/src/chatllm.cpp index 7841b9460e99..c8bb57d9c789 100644 --- a/gpt4all-chat/src/chatllm.cpp +++ b/gpt4all-chat/src/chatllm.cpp @@ -7,6 +7,9 @@ #include "localdocs.h" #include "mysettings.h" #include "network.h" +#include "tool.h" +#include "toolmodel.h" +#include "toolcallparser.h" #include @@ -54,6 +57,7 @@ #include using namespace Qt::Literals::StringLiterals; +using namespace ToolEnums; namespace ranges = std::ranges; //#define DEBUG @@ -638,6 +642,7 @@ bool isAllSpace(R &&r) void ChatLLM::regenerateResponse(int index) { Q_ASSERT(m_chatModel); +#if 0 // FIXME int promptIdx; { auto items = m_chatModel->chatItems(); // holds lock @@ -656,7 +661,7 @@ void ChatLLM::regenerateResponse(int index) m_chatModel->setError(false); if (promptIdx >= 0) m_chatModel->updateSources(promptIdx, {}); - +#endif prompt(m_chat->collectionList()); } @@ -665,10 +670,10 @@ std::optional ChatLLM::popPrompt(int index) Q_ASSERT(m_chatModel); QString content; { - auto items = m_chatModel->chatItems(); // holds lock - if (index < 0 || index >= items.size() || items[index].type() != ChatItem::Type::Prompt) + auto items = m_chatModel->messageItems(); // holds lock + if (index < 0 || index >= items.size() || items[index].type() != MessageItem::Type::Prompt) return std::nullopt; - content = items[index].value; + content = items[index].content(); } m_chatModel->truncate(index); return content; @@ -732,7 +737,8 @@ void ChatLLM::prompt(const QStringList &enabledCollections) promptInternalChat(enabledCollections, promptContextFromSettings(m_modelInfo)); } catch (const std::exception &e) { // FIXME(jared): this is neither translated nor serialized - emit responseFailed(u"Error: %1"_s.arg(QString::fromUtf8(e.what()))); + m_chatModel->setResponseValue(u"Error: %1"_s.arg(QString::fromUtf8(e.what()))); + m_chatModel->setError(); emit responseStopped(0); } } @@ -740,20 
+746,20 @@ void ChatLLM::prompt(const QStringList &enabledCollections) // FIXME(jared): We can avoid this potentially expensive copy if we use ChatItem pointers, but this is only safe if we // hold the lock while generating. We can't do that now because Chat is actually in charge of updating the response, not // ChatLLM. -std::vector ChatLLM::forkConversation(const QString &prompt) const +std::vector ChatLLM::forkConversation(const QString &prompt) const { Q_ASSERT(m_chatModel); if (m_chatModel->hasError()) throw std::logic_error("cannot continue conversation with an error"); - std::vector conversation; + std::vector conversation; { - auto items = m_chatModel->chatItems(); // holds lock + auto items = m_chatModel->messageItems(); Q_ASSERT(items.size() >= 2); // should be prompt/response pairs conversation.reserve(items.size() + 1); conversation.assign(items.begin(), items.end()); } - conversation.emplace_back(ChatItem::prompt_tag, prompt); + conversation.emplace_back(MessageItem::Type::Prompt, prompt.toUtf8()); return conversation; } @@ -788,7 +794,7 @@ std::optional ChatLLM::checkJinjaTemplateError(const std::string &s return std::nullopt; } -std::string ChatLLM::applyJinjaTemplate(std::span items) const +std::string ChatLLM::applyJinjaTemplate(std::span items) const { Q_ASSERT(items.size() >= 1); @@ -815,25 +821,33 @@ std::string ChatLLM::applyJinjaTemplate(std::span items) const uint version = parseJinjaTemplateVersion(chatTemplate); - auto makeMap = [version](const ChatItem &item) { + auto makeMap = [version](const MessageItem &item) { return jinja2::GenericMap([msg = std::make_shared(version, item)] { return msg.get(); }); }; - std::unique_ptr systemItem; + std::unique_ptr systemItem; bool useSystem = !isAllSpace(systemMessage); jinja2::ValuesList messages; messages.reserve(useSystem + items.size()); if (useSystem) { - systemItem = std::make_unique(ChatItem::system_tag, systemMessage); + systemItem = std::make_unique(MessageItem::Type::System, 
systemMessage.toUtf8()); messages.emplace_back(makeMap(*systemItem)); } for (auto &item : items) messages.emplace_back(makeMap(item)); + jinja2::ValuesList toolList; + const int toolCount = ToolModel::globalInstance()->count(); + for (int i = 0; i < toolCount; ++i) { + Tool *t = ToolModel::globalInstance()->get(i); + toolList.push_back(t->jinjaValue()); + } + jinja2::ValuesMap params { { "messages", std::move(messages) }, { "add_generation_prompt", true }, + { "toolList", toolList }, }; for (auto &[name, token] : model->specialTokens()) params.emplace(std::move(name), std::move(token)); @@ -857,12 +871,14 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL if (!enabledCollections.isEmpty()) { std::optional> query; { +#if 0 // FIXME // Find the prompt that represents the query. Server chats are flexible and may not have one. auto items = m_chatModel->chatItems(); // holds lock Q_ASSERT(items); auto response = items.end() - 1; if (auto peer = m_chatModel->getPeerUnlocked(response)) query = {*peer - items.begin(), (*peer)->value}; +#endif } if (query) { auto &[promptIndex, queryStr] = *query; @@ -873,13 +889,13 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const LL } // copy messages for safety (since we can't hold the lock the whole time) - std::vector chatItems; + std::vector messageItems; { - auto items = m_chatModel->chatItems(); // holds lock + auto items = m_chatModel->messageItems(); Q_ASSERT(items.size() >= 2); // should be prompt/response pairs - chatItems.assign(items.begin(), items.end() - 1); // exclude last + messageItems.assign(items.begin(), items.end() - 1); // exclude last } - auto result = promptInternal(chatItems, ctx, !databaseResults.isEmpty()); + auto result = promptInternal(messageItems, ctx, !databaseResults.isEmpty()); return { /*PromptResult*/ { .response = std::move(result.response), @@ -891,7 +907,7 @@ auto ChatLLM::promptInternalChat(const QStringList &enabledCollections, const 
LL } auto ChatLLM::promptInternal( - const std::variant, std::string_view> &prompt, + const std::variant, std::string_view> &prompt, const LLModel::PromptContext &ctx, bool usedLocalDocs ) -> PromptResult @@ -901,14 +917,14 @@ auto ChatLLM::promptInternal( auto *mySettings = MySettings::globalInstance(); // unpack prompt argument - const std::span *chatItems = nullptr; + const std::span *messageItems = nullptr; std::string jinjaBuffer; std::string_view conversation; if (auto *nonChat = std::get_if(&prompt)) { conversation = *nonChat; // complete the string without a template } else { - chatItems = &std::get>(prompt); - jinjaBuffer = applyJinjaTemplate(*chatItems); + messageItems = &std::get>(prompt); + jinjaBuffer = applyJinjaTemplate(*messageItems); conversation = jinjaBuffer; } @@ -916,8 +932,8 @@ auto ChatLLM::promptInternal( if (!dynamic_cast(m_llModelInfo.model.get())) { auto nCtx = m_llModelInfo.model->contextLength(); std::string jinjaBuffer2; - auto lastMessageRendered = (chatItems && chatItems->size() > 1) - ? std::string_view(jinjaBuffer2 = applyJinjaTemplate({ &chatItems->back(), 1 })) + auto lastMessageRendered = (messageItems && messageItems->size() > 1) + ? 
std::string_view(jinjaBuffer2 = applyJinjaTemplate({ &messageItems->back(), 1 })) : conversation; int32_t lastMessageLength = m_llModelInfo.model->countPromptTokens(lastMessageRendered); if (auto limit = nCtx - 4; lastMessageLength > limit) { @@ -937,14 +953,34 @@ auto ChatLLM::promptInternal( return !m_stopGenerating; }; - auto handleResponse = [this, &result](LLModel::Token token, std::string_view piece) -> bool { + ToolCallParser toolCallParser; + auto handleResponse = [this, &result, &toolCallParser](LLModel::Token token, std::string_view piece) -> bool { Q_UNUSED(token) result.responseTokens++; m_timer->inc(); + + // FIXME: piece may still end in the middle of a UTF-8 sequence; handle this like below where we have a QByteArray + // piece is a string_view and is not guaranteed to be NUL-terminated, so decode it with its explicit length + toolCallParser.update(QString::fromUtf8(piece.data(), piece.size())); + + // Create a toolcall and split the response if needed + if (!toolCallParser.hasSplit() && toolCallParser.state() == ToolEnums::ParseState::Partial) { + const QPair pair = toolCallParser.split(); + m_chatModel->splitToolCall(pair); + } + result.response.append(piece.data(), piece.size()); auto respStr = QString::fromUtf8(result.response); - emit responseChanged(removeLeadingWhitespace(respStr)); - return !m_stopGenerating; + + if (toolCallParser.hasSplit()) + m_chatModel->setResponseValue(toolCallParser.buffer()); + else + m_chatModel->setResponseValue(removeLeadingWhitespace(respStr)); + + emit responseChanged(); + + const bool foundToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete; + return !foundToolCall && !m_stopGenerating; }; QElapsedTimer totalTime; @@ -955,6 +991,7 @@ auto ChatLLM::promptInternal( emit promptProcessing(); m_llModelInfo.model->setThreadCount(mySettings->threadCount()); m_stopGenerating = false; + // qDebug().noquote() << conversation; m_llModelInfo.model->prompt(conversation, handlePrompt, handleResponse, ctx); } catch (...) 
{ m_timer->stop(); @@ -964,13 +1001,20 @@ auto ChatLLM::promptInternal( m_timer->stop(); qint64 elapsed = totalTime.elapsed(); + const bool foundToolCall = toolCallParser.state() == ToolEnums::ParseState::Complete; + // trim trailing whitespace auto respStr = QString::fromUtf8(result.response); - if (!respStr.isEmpty() && std::as_const(respStr).back().isSpace()) - emit responseChanged(respStr.trimmed()); + if (!respStr.isEmpty() && (std::as_const(respStr).back().isSpace() || foundToolCall)) { + if (toolCallParser.hasSplit()) + m_chatModel->setResponseValue(toolCallParser.buffer()); + else + m_chatModel->setResponseValue(respStr.trimmed()); + emit responseChanged(); + } bool doQuestions = false; - if (!m_isServer && chatItems) { + if (!m_isServer && messageItems && !foundToolCall) { switch (mySettings->suggestionMode()) { case SuggestionMode::On: doQuestions = true; break; case SuggestionMode::LocalDocsOnly: doQuestions = usedLocalDocs; break; diff --git a/gpt4all-chat/src/chatllm.h b/gpt4all-chat/src/chatllm.h index c79ca0bd517e..c241d7005529 100644 --- a/gpt4all-chat/src/chatllm.h +++ b/gpt4all-chat/src/chatllm.h @@ -30,7 +30,7 @@ using namespace Qt::Literals::StringLiterals; class QDataStream; -struct ChatItem; +struct MessageItem; // NOTE: values serialized to disk, do not change or reuse enum class LLModelTypeV0 { // chat versions 2-5 @@ -220,8 +220,8 @@ public Q_SLOTS: void modelLoadingPercentageChanged(float); void modelLoadingError(const QString &error); void modelLoadingWarning(const QString &warning); - void responseChanged(const QString &response); - void responseFailed(const QString &error); + void responseChanged(); + void responseFailed(); void promptProcessing(); void generatingQuestions(); void responseStopped(qint64 promptResponseMs); @@ -252,18 +252,18 @@ public Q_SLOTS: ChatPromptResult promptInternalChat(const QStringList &enabledCollections, const LLModel::PromptContext &ctx); // passing a string_view directly skips templating and uses the raw 
string - PromptResult promptInternal(const std::variant, std::string_view> &prompt, + PromptResult promptInternal(const std::variant, std::string_view> &prompt, const LLModel::PromptContext &ctx, bool usedLocalDocs); private: bool loadNewModel(const ModelInfo &modelInfo, QVariantMap &modelLoadProps); - std::vector forkConversation(const QString &prompt) const; + std::vector forkConversation(const QString &prompt) const; // Applies the Jinja template. Query mode returns only the last message without special tokens. // Returns a (# of messages, rendered prompt) pair. - std::string applyJinjaTemplate(std::span items) const; + std::string applyJinjaTemplate(std::span items) const; void generateQuestions(qint64 elapsed); diff --git a/gpt4all-chat/src/chatmodel.cpp b/gpt4all-chat/src/chatmodel.cpp new file mode 100644 index 000000000000..3bd7ed181217 --- /dev/null +++ b/gpt4all-chat/src/chatmodel.cpp @@ -0,0 +1,17 @@ +#include "chatmodel.h" + +MessageItem::MessageItem(const ChatItem *item) +{ + switch (item->type()) { + case ChatItem::Type::System: m_type = MessageItem::Type::System; break; + case ChatItem::Type::Prompt: m_type = MessageItem::Type::System; break; + case ChatItem::Type::Response: m_type = MessageItem::Type::System; break; + case ChatItem::Type::ToolResponse: m_type = MessageItem::Type::System; break; + case ChatItem::Type::Text: + case ChatItem::Type::ToolCall: + Q_UNREACHABLE(); + break; + } + + m_content = item->flattenedContent().toUtf8(); +} diff --git a/gpt4all-chat/src/chatmodel.h b/gpt4all-chat/src/chatmodel.h index 7ce6b0e884ad..0033071a51c6 100644 --- a/gpt4all-chat/src/chatmodel.h +++ b/gpt4all-chat/src/chatmodel.h @@ -2,6 +2,7 @@ #define CHATMODEL_H #include "database.h" +#include "toolcallparser.h" #include "utils.h" #include "xlsxtomd.h" @@ -11,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -69,19 +71,75 @@ struct PromptAttachment { }; Q_DECLARE_METATYPE(PromptAttachment) -struct ChatItem +class ChatItem; +class 
MessageItem { Q_GADGET + Q_PROPERTY(Type type READ type CONSTANT) + Q_PROPERTY(QByteArray content READ content CONSTANT) + +public: + enum class Type { System, Prompt, Response, ToolResponse }; + + MessageItem(Type type, const QByteArray &content) + : m_type(type), m_content(content) {} + MessageItem(const ChatItem *item); + MessageItem() = delete; + + Type type() const { return m_type; } + QByteArray content() const { return m_content; } + int peerIndex() const { return m_peerIndex; } + + QList sources() const { return m_sources; } + QList promptAttachments() const { return m_promptAttachments; } + + // used with version 0 Jinja templates + QString bakedPrompt() const + { + if (type() != Type::Prompt) + throw std::logic_error("bakedPrompt() called on non-prompt item"); + QStringList parts; + if (!m_sources.isEmpty()) { + parts << u"### Context:\n"_s; + for (auto &source : std::as_const(m_sources)) + parts << u"Collection: "_s << source.collection + << u"\nPath: "_s << source.path + << u"\nExcerpt: "_s << source.text << u"\n\n"_s; + } + for (auto &attached : std::as_const(m_promptAttachments)) + parts << attached.processedContent() << u"\n\n"_s; + parts << m_content; + return parts.join(QString()); + } + +private: + Type m_type; + QByteArray m_content; + int m_peerIndex = -1; + QList m_sources; + QList m_promptAttachments; +}; +Q_DECLARE_METATYPE(MessageItem) + +class ChatItem : public QObject +{ + Q_OBJECT Q_PROPERTY(QString name MEMBER name ) Q_PROPERTY(QString value MEMBER value) + // prompts and responses + Q_PROPERTY(QString content READ content NOTIFY contentChanged) + // prompts Q_PROPERTY(QList promptAttachments MEMBER promptAttachments) - Q_PROPERTY(QString bakedPrompt READ bakedPrompt ) // responses - Q_PROPERTY(bool isCurrentResponse MEMBER isCurrentResponse) - Q_PROPERTY(bool isError MEMBER isError ) + Q_PROPERTY(bool isCurrentResponse MEMBER isCurrentResponse NOTIFY isCurrentResponseChanged) + Q_PROPERTY(bool isError MEMBER isError ) + Q_PROPERTY(QList 
childItems READ childItems ) + + // toolcall + Q_PROPERTY(bool isToolCallError READ isToolCallError NOTIFY isTooCallErrorChanged) // responses (DataLake) Q_PROPERTY(QString newResponse MEMBER newResponse ) @@ -90,7 +148,7 @@ struct ChatItem Q_PROPERTY(bool thumbsDownState MEMBER thumbsDownState) public: - enum class Type { System, Prompt, Response }; + enum class Type { System, Prompt, Response, Text, ToolCall, ToolResponse }; // tags for constructing ChatItems struct prompt_tag_t { explicit prompt_tag_t() = default; }; @@ -99,20 +157,57 @@ struct ChatItem static inline constexpr response_tag_t response_tag = response_tag_t(); struct system_tag_t { explicit system_tag_t() = default; }; static inline constexpr system_tag_t system_tag = system_tag_t(); + struct text_tag_t { explicit text_tag_t() = default; }; + static inline constexpr text_tag_t text_tag = text_tag_t(); + struct tool_call_tag_t { explicit tool_call_tag_t() = default; }; + static inline constexpr tool_call_tag_t tool_call_tag = tool_call_tag_t(); + struct tool_response_tag_t { explicit tool_response_tag_t() = default; }; + static inline constexpr tool_response_tag_t tool_response_tag = tool_response_tag_t(); - // FIXME(jared): This should not be necessary. QML should see null or undefined if it - // tries to access something invalid. 
- ChatItem() = default; + ChatItem() = delete; // NOTE: system messages are currently never stored in the model or serialized - ChatItem(system_tag_t, const QString &value) - : name(u"System: "_s), value(value) {} + ChatItem(QObject *parent, system_tag_t, const QString &value) + : QObject(nullptr), name(u"System: "_s), value(value) + { + moveToThread(parent->thread()); + setParent(parent); + } - ChatItem(prompt_tag_t, const QString &value, const QList &attachments = {}) - : name(u"Prompt: "_s), value(value), promptAttachments(attachments) {} + ChatItem(QObject *parent, prompt_tag_t, const QString &value, const QList &attachments = {}) + : QObject(nullptr), name(u"Prompt: "_s), value(value), promptAttachments(attachments) + { + moveToThread(parent->thread()); + setParent(parent); + } - ChatItem(response_tag_t, bool isCurrentResponse = true) - : name(u"Response: "_s), isCurrentResponse(isCurrentResponse) {} + ChatItem(QObject *parent, response_tag_t, bool isCurrentResponse = true) + : QObject(nullptr), name(u"Response: "_s), isCurrentResponse(isCurrentResponse) + { + moveToThread(parent->thread()); + setParent(parent); + } + + ChatItem(QObject *parent, text_tag_t, const QString &value) + : QObject(nullptr), name(u"Text: "_s), value(value) + { + moveToThread(parent->thread()); + setParent(parent); + } + + ChatItem(QObject *parent, tool_call_tag_t, const QString &value) + : QObject(nullptr), name(u"ToolCall: "_s), value(value) + { + moveToThread(parent->thread()); + setParent(parent); + } + + ChatItem(QObject *parent, tool_response_tag_t, const QString &value) + : QObject(nullptr), name(u"ToolResponse: "_s), value(value) + { + moveToThread(parent->thread()); + setParent(parent); + } Type type() const { @@ -122,28 +217,144 @@ struct ChatItem return Type::Prompt; if (name == u"Response: "_s) return Type::Response; + if (name == u"Text: "_s) + return Type::Text; + if (name == u"ToolCall: "_s) + return Type::ToolCall; + if (name == u"ToolResponse: "_s) + return 
Type::ToolResponse; throw std::invalid_argument(fmt::format("Chat item has unknown label: {:?}", name)); } - // used with version 0 Jinja templates - QString bakedPrompt() const + QString flattenedContent() const { - if (type() != Type::Prompt) - throw std::logic_error("bakedPrompt() called on non-prompt item"); - QStringList parts; - if (!sources.isEmpty()) { - parts << u"### Context:\n"_s; - for (auto &source : std::as_const(sources)) - parts << u"Collection: "_s << source.collection - << u"\nPath: "_s << source.path - << u"\nExcerpt: "_s << source.text << u"\n\n"_s; + if (subItems.empty()) + return value; + + // We only flatten one level + QString content; + for (ChatItem *item : subItems) + content += item->value; + return content; + } + + QString content() const + { + if (type() == Type::Response) { + // We parse if this contains any part of a partial toolcall + ToolCallParser parser; + parser.update(value); + + // If no tool call is detected, return the original value + if (parser.startIndex() < 0) + return value; + + // Otherwise we only return the text before and any partial code interpreter code + const QString beforeToolCall = value.left(parser.startIndex()); + const QString toolCallString = value.mid(parser.startIndex()); + return beforeToolCall + codeInterpreterContent(toolCallString); } - for (auto &attached : std::as_const(promptAttachments)) - parts << attached.processedContent() << u"\n\n"_s; - parts << value; - return parts.join(QString()); + + // For complete tool calls we only return content if it is code interpreter + if (type() == Type::ToolCall) + return codeInterpreterContent(value); + + // We don't show any of content from the tool response in the GUI + if (type() == Type::ToolResponse) + return QString(); + + return value; + } + + QString codeInterpreterContent(const QString &value) const + { + // Constants for identifying and formatting the code interpreter tool call + static const QString prefix = ToolCallConstants::CodeInterpreterTag; + 
+ // Check if the tool call is a code interpreter tool call + if (!value.startsWith(prefix)) + return QString(); + + // Regex to remove the tag and any surrounding whitespace + static const QRegularExpression regex("^" + + ToolCallConstants::CodeInterpreterTag + + "\\s*|\\s*" + + ToolCallConstants::CodeInterpreterEndTag + + "$"); + + // Extract the code + QString code = value; + code.remove(regex); + code = code.trimmed(); + + QString result; + + // If we've finished the tool call then extract the result from meta information + if (toolCallInfo.name == ToolCallConstants::CodeInterpreterFunction) + result = "```\n" + toolCallInfo.result + "```"; + + // Return the formatted code and the result if available + return code + result; + } + + QList childItems() const + { + // We currently have leaf nodes at depth 3 with nodes at depth 2 as mere containers we don't + // care about in GUI + QList items; + for (const ChatItem *item : subItems) + items << QList(item->subItems.begin(), item->subItems.end()); + return items; + } + + QString toolCallValue() const + { + if (!subItems.empty()) + return subItems.back()->toolCallValue(); + if (type() == Type::ToolCall) + return value; + else + return QString(); + } + + void setCurrentResponse(bool b) + { + if (!subItems.empty()) + subItems.back()->setCurrentResponse(b); + isCurrentResponse = b; + emit isCurrentResponseChanged(); + } + + void setValue(const QString &v) + { + if (!subItems.empty() && subItems.back()->isCurrentResponse) { + subItems.back()->setValue(v); + return; + } + + value = v; + emit contentChanged(); + } + + void setToolCallInfo(const ToolCallInfo &info) + { + toolCallInfo = info; + emit contentChanged(); + emit isTooCallErrorChanged(); } + bool isToolCallError() const + { + return toolCallInfo.error != ToolEnums::Error::NoError; + } + +Q_SIGNALS: + void contentChanged(); + void isTooCallErrorChanged(); + void isCurrentResponseChanged(); + +public: + // TODO: Maybe we should include the model name here as well as 
timestamp? QString name; QString value; @@ -156,6 +367,8 @@ struct ChatItem // responses bool isCurrentResponse = false; bool isError = false; + ToolCallInfo toolCallInfo; + std::list subItems; // responses (DataLake) QString newResponse; @@ -163,20 +376,6 @@ struct ChatItem bool thumbsUpState = false; bool thumbsDownState = false; }; -Q_DECLARE_METATYPE(ChatItem) - -class ChatModelAccessor : public ranges::subrange::const_iterator> { -private: - using Super = ranges::subrange::const_iterator>; - -public: - template - ChatModelAccessor(QMutex &mutex, T &&...args) - : Super(std::forward(args)...), m_lock(&mutex) {} - -private: - QMutexLocker m_lock; -}; class ChatModel : public QAbstractListModel { @@ -194,7 +393,7 @@ class ChatModel : public QAbstractListModel ValueRole, // prompts and responses - PeerRole, + ContentRole, // prompts PromptAttachmentsRole, @@ -205,6 +404,7 @@ class ChatModel : public QAbstractListModel ConsolidatedSourcesRole, IsCurrentResponseRole, IsErrorRole, + ChildItemsRole, // responses (DataLake) NewResponseRole, @@ -220,6 +420,7 @@ class ChatModel : public QAbstractListModel return m_chatItems.size(); } +#if 0 // FIXME /* a "peer" is a bidirectional 1:1 link between a prompt and the response that would cite its LocalDocs * sources. Return std::nullopt if there is none, which is possible for e.g. server chats. 
*/ auto getPeerUnlocked(QList::const_iterator item) const @@ -253,6 +454,7 @@ class ChatModel : public QAbstractListModel return getPeerUnlocked(m_chatItems.cbegin() + index) .transform([&](auto &&i) { return i - m_chatItems.cbegin(); } ); } +#endif QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override { @@ -260,42 +462,34 @@ class ChatModel : public QAbstractListModel if (!index.isValid() || index.row() < 0 || index.row() >= m_chatItems.size()) return QVariant(); - auto item = m_chatItems.cbegin() + index.row(); + const ChatItem *item = m_chatItems.at(index.row()); switch (role) { case NameRole: return item->name; case ValueRole: return item->value; - case PeerRole: - switch (item->type()) { - using enum ChatItem::Type; - case Prompt: - case Response: - { - auto peer = getPeerUnlocked(item); - return peer ? QVariant::fromValue(**peer) : QVariant::fromValue(nullptr); - } - default: - return QVariant(); - } case PromptAttachmentsRole: return QVariant::fromValue(item->promptAttachments); case SourcesRole: { QList data; +#if 0 // FIXME if (item->type() == ChatItem::Type::Response) { if (auto prompt = getPeerUnlocked(item)) data = (*prompt)->consolidatedSources; } +#endif return QVariant::fromValue(data); } case ConsolidatedSourcesRole: { QList data; +#if 0 // FIXME if (item->type() == ChatItem::Type::Response) { if (auto prompt = getPeerUnlocked(item)) data = (*prompt)->sources; } +#endif return QVariant::fromValue(data); } case IsCurrentResponseRole: @@ -310,6 +504,10 @@ class ChatModel : public QAbstractListModel return item->thumbsDownState; case IsErrorRole: return item->type() == ChatItem::Type::Response && item->isError; + case ContentRole: + return item->content(); + case ChildItemsRole: + return QVariant::fromValue(item->childItems()); } return QVariant(); @@ -320,7 +518,6 @@ class ChatModel : public QAbstractListModel return { { NameRole, "name" }, { ValueRole, "value" }, - { PeerRole, "peer" }, { PromptAttachmentsRole, 
"promptAttachments" }, { SourcesRole, "sources" }, { ConsolidatedSourcesRole, "consolidatedSources" }, @@ -330,6 +527,8 @@ class ChatModel : public QAbstractListModel { StoppedRole, "stopped" }, { ThumbsUpStateRole, "thumbsUpState" }, { ThumbsDownStateRole, "thumbsDownState" }, + { ContentRole, "content" }, + { ChildItemsRole, "childItems" }, }; } @@ -346,7 +545,8 @@ class ChatModel : public QAbstractListModel beginInsertRows(QModelIndex(), count, count); { QMutexLocker locker(&m_mutex); - m_chatItems.emplace_back(ChatItem::prompt_tag, value, attachments); + ChatItem *item = new ChatItem(this, ChatItem::prompt_tag, value, attachments); + m_chatItems.emplace_back(item); } endInsertRows(); emit countChanged(); @@ -368,20 +568,19 @@ class ChatModel : public QAbstractListModel if (promptIndex >= 0) { if (promptIndex >= m_chatItems.size()) throw std::out_of_range(fmt::format("index {} is out of range", promptIndex)); - auto &promptItem = m_chatItems[promptIndex]; - if (promptItem.type() != ChatItem::Type::Prompt) + ChatItem *promptItem = m_chatItems[promptIndex]; + if (promptItem->type() != ChatItem::Type::Prompt) throw std::invalid_argument(fmt::format("item at index {} is not a prompt", promptIndex)); } - m_chatItems.emplace_back(ChatItem::response_tag, promptIndex); + ChatItem *item = new ChatItem(this, ChatItem::response_tag); + m_chatItems.emplace_back(item); } endInsertRows(); emit countChanged(); - if (promptIndex >= 0) - emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole}); } // Used by Server to append a new conversation to the chat log. 
- void appendResponseWithHistory(std::span history) + void appendResponseWithHistory(std::span history) { if (history.empty()) throw std::invalid_argument("at least one message is required"); @@ -394,22 +593,24 @@ class ChatModel : public QAbstractListModel qsizetype endIndex = startIndex + nNewItems; beginInsertRows(QModelIndex(), startIndex, endIndex - 1 /*inclusive*/); bool hadError; +#if 0 // FIXME int promptIndex; +#endif { QMutexLocker locker(&m_mutex); hadError = hasErrorUnlocked(); m_chatItems.reserve(m_chatItems.count() + nNewItems); +#if 0 // FIXME for (auto &item : history) m_chatItems << item; m_chatItems.emplace_back(ChatItem::response_tag); +#endif } endInsertRows(); emit countChanged(); // Server can add messages when there is an error because each call is a new conversation if (hadError) emit hasErrorChanged(false); - if (promptIndex >= 0) - emit dataChanged(createIndex(promptIndex, 0), createIndex(promptIndex, 0), {PeerRole}); } void truncate(qsizetype size) @@ -419,7 +620,7 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (size >= (oldSize = m_chatItems.size())) return; - if (size && m_chatItems.at(size - 1).type() != ChatItem::Type::Response) + if (size && m_chatItems.at(size - 1)->type() != ChatItem::Type::Response) throw std::invalid_argument( fmt::format("chat model truncated to {} items would not end in a response", size) ); @@ -459,28 +660,24 @@ class ChatModel : public QAbstractListModel emit hasErrorChanged(false); } - Q_INVOKABLE ChatItem get(int index) + Q_INVOKABLE ChatItem *get(int index) { QMutexLocker locker(&m_mutex); - if (index < 0 || index >= m_chatItems.size()) return ChatItem(); + if (index < 0 || index >= m_chatItems.size()) return nullptr; return m_chatItems.at(index); } Q_INVOKABLE void updateCurrentResponse(int index, bool b) { - bool changed = false; { QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if 
(item.isCurrentResponse != b) { - item.isCurrentResponse = b; - changed = true; - } + ChatItem *item = m_chatItems[index]; + item->setCurrentResponse(b); } - if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsCurrentResponseRole}); + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsCurrentResponseRole}); } Q_INVOKABLE void updateStopped(int index, bool b) @@ -490,32 +687,28 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.stopped != b) { - item.stopped = b; + ChatItem *item = m_chatItems[index]; + if (item->stopped != b) { + item->stopped = b; changed = true; } } if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {StoppedRole}); } - Q_INVOKABLE void updateValue(int index, const QString &value) + Q_INVOKABLE void setResponseValue(const QString &value) { - bool changed = false; + qsizetype index; { QMutexLocker locker(&m_mutex); - if (index < 0 || index >= m_chatItems.size()) return; + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) + throw std::logic_error("we only set this on a response"); - ChatItem &item = m_chatItems[index]; - if (item.value != value) { - item.value = value; - changed = true; - } - } - if (changed) { - emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole}); - emit valueChanged(index, value); + index = m_chatItems.count() - 1; + ChatItem *item = m_chatItems.back(); + item->setValue(value); } + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ValueRole, ContentRole}); } static QList consolidateSources(const QList &sources) { @@ -538,6 +731,7 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; +#if 0 // FIXME auto promptItem = m_chatItems.begin() + index; if (promptItem->type() != 
ChatItem::Type::Prompt) throw std::invalid_argument(fmt::format("item at index {} is not a prompt", index)); @@ -545,6 +739,7 @@ class ChatModel : public QAbstractListModel responseIndex = *peer - m_chatItems.cbegin(); promptItem->sources = sources; promptItem->consolidatedSources = consolidateSources(sources); +#endif } if (responseIndex >= 0) { emit dataChanged(createIndex(responseIndex, 0), createIndex(responseIndex, 0), {SourcesRole}); @@ -559,9 +754,9 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.thumbsUpState != b) { - item.thumbsUpState = b; + ChatItem *item = m_chatItems[index]; + if (item->thumbsUpState != b) { + item->thumbsUpState = b; changed = true; } } @@ -575,9 +770,9 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.thumbsDownState != b) { - item.thumbsDownState = b; + ChatItem *item = m_chatItems[index]; + if (item->thumbsDownState != b) { + item->thumbsDownState = b; changed = true; } } @@ -591,29 +786,116 @@ class ChatModel : public QAbstractListModel QMutexLocker locker(&m_mutex); if (index < 0 || index >= m_chatItems.size()) return; - ChatItem &item = m_chatItems[index]; - if (item.newResponse != newResponse) { - item.newResponse = newResponse; + ChatItem *item = m_chatItems[index]; + if (item->newResponse != newResponse) { + item->newResponse = newResponse; changed = true; } } if (changed) emit dataChanged(createIndex(index, 0), createIndex(index, 0), {NewResponseRole}); } + Q_INVOKABLE void splitToolCall(const QPair &split) + { + qsizetype index; + { + QMutexLocker locker(&m_mutex); + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) + throw std::logic_error("can only set toolcall on a chat that ends with a response"); + + index = 
m_chatItems.count() - 1; + ChatItem *currentResponse = m_chatItems.back(); + Q_ASSERT(currentResponse->isCurrentResponse); + + // Create a new response container for any text and the tool call + ChatItem *newResponse = new ChatItem(this, ChatItem::response_tag, true /*isCurrentResponse*/); + + // Add preceding text if any + if (!split.first.isEmpty()) { + ChatItem *textItem = new ChatItem(this, ChatItem::text_tag, split.first); + newResponse->subItems.push_back(textItem); + } + + // Add the toolcall + Q_ASSERT(!split.second.isEmpty()); + ChatItem *toolCallItem = new ChatItem(this, ChatItem::tool_call_tag, split.second); + toolCallItem->isCurrentResponse = true; + // toolCallItem.toolCallInfo = toolCallInfo; + newResponse->subItems.push_back(toolCallItem); + + // Add new response and reset our value + currentResponse->subItems.push_back(newResponse); + currentResponse->value = QString(); + } + + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole}); + } + + Q_INVOKABLE void updateToolCall(const ToolCallInfo &toolCallInfo) + { + qsizetype index; + { + QMutexLocker locker(&m_mutex); + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) + throw std::logic_error("can only set toolcall on a chat that ends with a response"); + + index = m_chatItems.count() - 1; + ChatItem *currentResponse = m_chatItems.back(); + Q_ASSERT(currentResponse->isCurrentResponse); + + ChatItem *subResponse = currentResponse->subItems.back(); + Q_ASSERT(subResponse->type() == ChatItem::Type::Response); + Q_ASSERT(subResponse->isCurrentResponse); + + ChatItem *toolCallItem = subResponse->subItems.back(); + Q_ASSERT(toolCallItem->type() == ChatItem::Type::ToolCall); + toolCallItem->setToolCallInfo(toolCallInfo); + toolCallItem->setCurrentResponse(false); + + // Add tool response + ChatItem *toolResponseItem = new ChatItem(this, ChatItem::tool_response_tag, toolCallInfo.result); + 
currentResponse->subItems.push_back(toolResponseItem); + } + + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole}); + } + + Q_INVOKABLE void clearSubItems() + { + qsizetype index; + bool changed = false; + { + QMutexLocker locker(&m_mutex); + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) + throw std::logic_error("can only clear subitems on a chat that ends with a response"); + + index = m_chatItems.count() - 1; + ChatItem *item = m_chatItems.back(); + if (!item->subItems.empty()) { + item->subItems.clear(); + changed = true; + } + } + if (changed) { + qDebug() << "signaling we've cleared the subitems."; + emit dataChanged(createIndex(index, 0), createIndex(index, 0), {ChildItemsRole, ContentRole}); + } + } + Q_INVOKABLE void setError(bool value = true) { qsizetype index; { QMutexLocker locker(&m_mutex); - if (m_chatItems.isEmpty() || m_chatItems.cend()[-1].type() != ChatItem::Type::Response) + if (m_chatItems.isEmpty() || m_chatItems.cend()[-1]->type() != ChatItem::Type::Response) throw std::logic_error("can only set error on a chat that ends with a response"); index = m_chatItems.count() - 1; auto &last = m_chatItems.back(); - if (last.isError == value) + if (last->isError == value) return; // already set - last.isError = value; + last->isError = value; } emit dataChanged(createIndex(index, 0), createIndex(index, 0), {IsErrorRole}); emit hasErrorChanged(value); @@ -621,12 +903,24 @@ class ChatModel : public QAbstractListModel qsizetype count() const { QMutexLocker locker(&m_mutex); return m_chatItems.size(); } - ChatModelAccessor chatItems() const { return {m_mutex, std::as_const(m_chatItems)}; } + std::vector messageItems() const + { + // A flattened version of the chat item tree used by the backend and jinja + std::vector chatItems; + for (const ChatItem *item : m_chatItems) { + if (!item->subItems.empty()) + chatItems.insert(chatItems.end(), item->subItems.begin(), 
item->subItems.end()); + chatItems.push_back(item); + } + return chatItems; + } bool hasError() const { QMutexLocker locker(&m_mutex); return hasErrorUnlocked(); } bool serialize(QDataStream &stream, int version) const { +#if 0 + // FIXME: need to serialize the toolcall info QMutexLocker locker(&m_mutex); stream << int(m_chatItems.size()); for (auto itemIt = m_chatItems.cbegin(); itemIt < m_chatItems.cend(); ++itemIt) { @@ -724,10 +1018,14 @@ class ChatModel : public QAbstractListModel } } return stream.status() == QDataStream::Ok; +#else + return false; +#endif } bool deserialize(QDataStream &stream, int version) { +#if 0 clear(); // reset to known state int size; @@ -907,11 +1205,13 @@ class ChatModel : public QAbstractListModel if (hasError) emit hasErrorChanged(true); return stream.status() == QDataStream::Ok; +#else + return false; +#endif } Q_SIGNALS: void countChanged(); - void valueChanged(int index, const QString &value); void hasErrorChanged(bool value); private: @@ -920,12 +1220,12 @@ class ChatModel : public QAbstractListModel if (m_chatItems.isEmpty()) return false; auto &last = m_chatItems.back(); - return last.type() == ChatItem::Type::Response && last.isError; + return last->type() == ChatItem::Type::Response && last->isError; } private: mutable QMutex m_mutex; - QList m_chatItems; + QList m_chatItems; }; #endif // CHATMODEL_H diff --git a/gpt4all-chat/src/codeinterpreter.cpp b/gpt4all-chat/src/codeinterpreter.cpp new file mode 100644 index 000000000000..1ea8b949e107 --- /dev/null +++ b/gpt4all-chat/src/codeinterpreter.cpp @@ -0,0 +1,125 @@ +#include "codeinterpreter.h" + +#include +#include + +using namespace Qt::Literals::StringLiterals; + +QString CodeInterpreter::run(const QList ¶ms, qint64 timeout) +{ + m_error = ToolEnums::Error::NoError; + m_errorString = QString(); + + Q_ASSERT(params.count() == 1 + && params.first().name == "code" + && params.first().type == ToolEnums::ParamType::String); + + const QString code = 
params.first().value.toString(); + + QThread workerThread; + CodeInterpreterWorker worker; + worker.moveToThread(&workerThread); + connect(&worker, &CodeInterpreterWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection); + connect(&workerThread, &QThread::started, [&worker, code]() { + worker.request(code); + }); + workerThread.start(); + bool timedOut = !workerThread.wait(timeout); + if (timedOut) { + worker.interrupt(); // thread safe + m_error = ToolEnums::Error::TimeoutError; + } + workerThread.quit(); + workerThread.wait(); + if (!timedOut) { + m_error = worker.error(); + m_errorString = worker.errorString(); + } + return worker.response(); +} + +QList CodeInterpreter::parameters() const +{ + return {{ + "code", + ToolEnums::ParamType::String, + "javascript code to compute", + true + }}; +} + +QString CodeInterpreter::symbolicFormat() const +{ + return "{human readable plan to complete the task}\n" + ToolCallConstants::CodeInterpreterPrefix + "{code}\n" + ToolCallConstants::CodeInterpreterSuffix; +} + +QString CodeInterpreter::examplePrompt() const +{ + return R"(Write code to check if a number is prime, use that to see if the number 7 is prime)"; +} + +QString CodeInterpreter::exampleCall() const +{ + static const QString example = R"(function isPrime(n) { + if (n <= 1) { + return false; + } + for (let i = 2; i <= Math.sqrt(n); i++) { + if (n % i === 0) { + return false; + } + } + return true; +} + +const number = 7; +console.log(`The number ${number} is prime: ${isPrime(number)}`); +)"; + + return "Certainly! 
Let's compute the answer to whether the number 7 is prime.\n" + ToolCallConstants::CodeInterpreterPrefix + example + ToolCallConstants::CodeInterpreterSuffix; +} + +QString CodeInterpreter::exampleReply() const +{ + return R"("The computed result shows that 7 is a prime number.)"; +} + +CodeInterpreterWorker::CodeInterpreterWorker() + : QObject(nullptr) +{ +} + +void CodeInterpreterWorker::request(const QString &code) +{ + JavaScriptConsoleCapture consoleCapture; + QJSValue consoleObject = m_engine.newQObject(&consoleCapture); + m_engine.globalObject().setProperty("console", consoleObject); + + const QJSValue result = m_engine.evaluate(code); + QString resultString = result.isUndefined() ? QString() : result.toString(); + + // NOTE: We purposely do not set the m_error or m_errorString for the code interpreter since + // we *want* the model to see the response has an error so it can hopefully correct itself. The + // error member variables are intended for tools that have error conditions that cannot be corrected. + // For instance, a tool depending upon the network might set these error variables if the network + // is not available. + if (result.isError()) { + const QStringList lines = code.split('\n'); + const int line = result.property("lineNumber").toInt(); + const int index = line - 1; + const QString lineContent = (index >= 0 && index < lines.size()) ? 
lines.at(index) : "Line not found in code."; + resultString = QString("Uncaught exception at line %1: %2\n\t%3") + .arg(line) + .arg(result.toString()) + .arg(lineContent); + m_error = ToolEnums::Error::UnknownError; + m_errorString = resultString; + } + + if (resultString.isEmpty()) + resultString = consoleCapture.output; + else if (!consoleCapture.output.isEmpty()) + resultString += "\n" + consoleCapture.output; + m_response = resultString; + emit finished(); +} diff --git a/gpt4all-chat/src/codeinterpreter.h b/gpt4all-chat/src/codeinterpreter.h new file mode 100644 index 000000000000..4c501497d336 --- /dev/null +++ b/gpt4all-chat/src/codeinterpreter.h @@ -0,0 +1,83 @@ +#ifndef CODEINTERPRETER_H +#define CODEINTERPRETER_H + +#include "tool.h" +#include "toolcallparser.h" + +#include +#include +#include + +class JavaScriptConsoleCapture : public QObject +{ + Q_OBJECT +public: + QString output; + Q_INVOKABLE void log(const QString &message) + { + const int maxLength = 1024; + if (output.length() >= maxLength) + return; + + if (output.length() + message.length() + 1 > maxLength) { + static const QString trunc = "\noutput truncated at " + QString::number(maxLength) + " characters..."; + int remainingLength = maxLength - output.length(); + if (remainingLength > 0) + output.append(message.left(remainingLength)); + output.append(trunc); + Q_ASSERT(output.length() > maxLength); + } else { + output.append(message + "\n"); + } + } +}; + +class CodeInterpreterWorker : public QObject { + Q_OBJECT +public: + CodeInterpreterWorker(); + virtual ~CodeInterpreterWorker() {} + + QString response() const { return m_response; } + + void request(const QString &code); + void interrupt() { m_engine.setInterrupted(true); } + ToolEnums::Error error() const { return m_error; } + QString errorString() const { return m_errorString; } + +Q_SIGNALS: + void finished(); + +private: + QJSEngine m_engine; + QString m_response; + ToolEnums::Error m_error = ToolEnums::Error::NoError; + QString 
m_errorString; +}; + +class CodeInterpreter : public Tool +{ + Q_OBJECT +public: + explicit CodeInterpreter() : Tool(), m_error(ToolEnums::Error::NoError) {} + virtual ~CodeInterpreter() {} + + QString run(const QList ¶ms, qint64 timeout = 2000) override; + ToolEnums::Error error() const override { return m_error; } + QString errorString() const override { return m_errorString; } + + QString name() const override { return tr("Code Interpreter"); } + QString description() const override { return tr("compute javascript code using console.log as output"); } + QString function() const override { return ToolCallConstants::CodeInterpreterFunction; } + QList parameters() const override; + virtual QString symbolicFormat() const override; + QString examplePrompt() const override; + QString exampleCall() const override; + QString exampleReply() const override; + +private: + ToolEnums::Error m_error = ToolEnums::Error::NoError; + QString m_errorString; +}; + +#endif // CODEINTERPRETER_H diff --git a/gpt4all-chat/src/jinja_helpers.cpp b/gpt4all-chat/src/jinja_helpers.cpp index 826dfb01e812..133e58bc95ba 100644 --- a/gpt4all-chat/src/jinja_helpers.cpp +++ b/gpt4all-chat/src/jinja_helpers.cpp @@ -51,12 +51,14 @@ auto JinjaMessage::keys() const -> const std::unordered_set & static const std::unordered_set userKeys { "role", "content", "sources", "prompt_attachments" }; switch (m_item->type()) { - using enum ChatItem::Type; + using enum MessageItem::Type; case System: case Response: + case ToolResponse: return baseKeys; case Prompt: return userKeys; + break; } Q_UNREACHABLE(); } @@ -67,16 +69,18 @@ bool operator==(const JinjaMessage &a, const JinjaMessage &b) return true; const auto &[ia, ib] = std::tie(*a.m_item, *b.m_item); auto type = ia.type(); - if (type != ib.type() || ia.value != ib.value) + if (type != ib.type() || ia.content() != ib.content()) return false; switch (type) { - using enum ChatItem::Type; + using enum MessageItem::Type; case System: case Response: + case 
ToolResponse: return true; case Prompt: - return ia.sources == ib.sources && ia.promptAttachments == ib.promptAttachments; + return ia.sources() == ib.sources() && ia.promptAttachments() == ib.promptAttachments(); + break; } Q_UNREACHABLE(); } @@ -84,26 +88,28 @@ bool operator==(const JinjaMessage &a, const JinjaMessage &b) const JinjaFieldMap JinjaMessage::s_fields = { { "role", [](auto &m) { switch (m.item().type()) { - using enum ChatItem::Type; + using enum MessageItem::Type; case System: return "system"sv; case Prompt: return "user"sv; case Response: return "assistant"sv; + case ToolResponse: return "tool"sv; + break; } Q_UNREACHABLE(); } }, { "content", [](auto &m) { - if (m.version() == 0 && m.item().type() == ChatItem::Type::Prompt) + if (m.version() == 0 && m.item().type() == MessageItem::Type::Prompt) return m.item().bakedPrompt().toStdString(); - return m.item().value.toStdString(); + return m.item().content().toStdString(); } }, { "sources", [](auto &m) { - auto sources = m.item().sources | views::transform([](auto &r) { + auto sources = m.item().sources() | views::transform([](auto &r) { return jinja2::GenericMap([map = std::make_shared(r)] { return map.get(); }); }); return jinja2::ValuesList(sources.begin(), sources.end()); } }, { "prompt_attachments", [](auto &m) { - auto attachments = m.item().promptAttachments | views::transform([](auto &pa) { + auto attachments = m.item().promptAttachments() | views::transform([](auto &pa) { return jinja2::GenericMap([map = std::make_shared(pa)] { return map.get(); }); }); return jinja2::ValuesList(attachments.begin(), attachments.end()); diff --git a/gpt4all-chat/src/jinja_helpers.h b/gpt4all-chat/src/jinja_helpers.h index a196b47f8fdf..f7f4ff9b8b61 100644 --- a/gpt4all-chat/src/jinja_helpers.h +++ b/gpt4all-chat/src/jinja_helpers.h @@ -86,12 +86,12 @@ class JinjaPromptAttachment : public JinjaHelper { class JinjaMessage : public JinjaHelper { public: - explicit JinjaMessage(uint version, const ChatItem &item) 
noexcept + explicit JinjaMessage(uint version, const MessageItem &item) noexcept : m_version(version), m_item(&item) {} const JinjaMessage &value () const { return *this; } uint version() const { return m_version; } - const ChatItem &item () const { return *m_item; } + const MessageItem &item () const { return *m_item; } size_t GetSize() const override { return keys().size(); } bool HasValue(const std::string &name) const override { return keys().contains(name); } @@ -107,7 +107,7 @@ class JinjaMessage : public JinjaHelper { private: static const JinjaFieldMap s_fields; uint m_version; - const ChatItem *m_item; + const MessageItem *m_item; friend class JinjaHelper; friend bool operator==(const JinjaMessage &a, const JinjaMessage &b); diff --git a/gpt4all-chat/src/main.cpp b/gpt4all-chat/src/main.cpp index 0fc23be3c961..1050e590879d 100644 --- a/gpt4all-chat/src/main.cpp +++ b/gpt4all-chat/src/main.cpp @@ -7,6 +7,7 @@ #include "modellist.h" #include "mysettings.h" #include "network.h" +#include "toolmodel.h" #include #include @@ -116,6 +117,8 @@ int main(int argc, char *argv[]) qmlRegisterSingletonInstance("download", 1, 0, "Download", Download::globalInstance()); qmlRegisterSingletonInstance("network", 1, 0, "Network", Network::globalInstance()); qmlRegisterSingletonInstance("localdocs", 1, 0, "LocalDocs", LocalDocs::globalInstance()); + qmlRegisterSingletonInstance("toollist", 1, 0, "ToolList", ToolModel::globalInstance()); + qmlRegisterUncreatableMetaObject(ToolEnums::staticMetaObject, "toolenums", 1, 0, "ToolEnums", "Error: only enums"); qmlRegisterUncreatableMetaObject(MySettingsEnums::staticMetaObject, "mysettingsenums", 1, 0, "MySettingsEnums", "Error: only enums"); { diff --git a/gpt4all-chat/src/server.cpp b/gpt4all-chat/src/server.cpp index 2435d43b3c9e..cf97e2094a90 100644 --- a/gpt4all-chat/src/server.cpp +++ b/gpt4all-chat/src/server.cpp @@ -693,7 +693,8 @@ auto Server::handleCompletionRequest(const CompletionRequest &request) promptCtx, 
/*usedLocalDocs*/ false); } catch (const std::exception &e) { - emit responseChanged(e.what()); + m_chatModel->setResponseValue(e.what()); + m_chatModel->setError(); emit responseStopped(0); return makeError(QHttpServerResponder::StatusCode::InternalServerError); } @@ -771,16 +772,16 @@ auto Server::handleChatRequest(const ChatRequest &request) Q_ASSERT(!request.messages.isEmpty()); // adds prompt/response items to GUI - std::vector chatItems; + std::vector messageItems; for (auto &message : request.messages) { using enum ChatRequest::Message::Role; switch (message.role) { - case System: chatItems.emplace_back(ChatItem::system_tag, message.content); break; - case User: chatItems.emplace_back(ChatItem::prompt_tag, message.content); break; - case Assistant: chatItems.emplace_back(ChatItem::response_tag, /*currentResponse*/ false); break; + case System: messageItems.emplace_back(MessageItem(MessageItem::Type::System, message.content.toUtf8())); break; + case User: messageItems.emplace_back(MessageItem(MessageItem::Type::Prompt, message.content.toUtf8())); break; + case Assistant: messageItems.emplace_back(MessageItem(MessageItem::Type::Response, message.content.toUtf8())); break; } } - m_chatModel->appendResponseWithHistory(chatItems); + m_chatModel->appendResponseWithHistory(messageItems); // FIXME(jared): taking parameters from the UI inhibits reproducibility of results LLModel::PromptContext promptCtx { @@ -802,7 +803,8 @@ auto Server::handleChatRequest(const ChatRequest &request) try { result = promptInternalChat(m_collections, promptCtx); } catch (const std::exception &e) { - emit responseChanged(e.what()); + m_chatModel->setResponseValue(e.what()); + m_chatModel->setError(); emit responseStopped(0); return makeError(QHttpServerResponder::StatusCode::InternalServerError); } diff --git a/gpt4all-chat/src/tool.cpp b/gpt4all-chat/src/tool.cpp new file mode 100644 index 000000000000..55a2fed2e298 --- /dev/null +++ b/gpt4all-chat/src/tool.cpp @@ -0,0 +1,41 @@ 
+#include "tool.h" + +#include + +jinja2::Value Tool::jinjaValue() const +{ + jinja2::ValuesList paramList; + const QList p = parameters(); + for (auto &info : p) { + std::string typeStr; + switch (info.type) { + using enum ToolEnums::ParamType; + case String: typeStr = "string"; break; + case Number: typeStr = "number"; break; + case Integer: typeStr = "integer"; break; + case Object: typeStr = "object"; break; + case Array: typeStr = "array"; break; + case Boolean: typeStr = "boolean"; break; + case Null: typeStr = "null"; break; + } + jinja2::ValuesMap infoMap { + { "name", info.name.toStdString() }, + { "type", typeStr}, + { "description", info.description.toStdString() }, + { "required", info.required } + }; + paramList.push_back(infoMap); + } + + jinja2::ValuesMap params { + { "name", name().toStdString() }, + { "description", description().toStdString() }, + { "function", function().toStdString() }, + { "parameters", paramList }, + { "symbolicFormat", symbolicFormat().toStdString() }, + { "examplePrompt", examplePrompt().toStdString() }, + { "exampleCall", exampleCall().toStdString() }, + { "exampleReply", exampleReply().toStdString() } + }; + return params; +} diff --git a/gpt4all-chat/src/tool.h b/gpt4all-chat/src/tool.h new file mode 100644 index 000000000000..6a992f5caa85 --- /dev/null +++ b/gpt4all-chat/src/tool.h @@ -0,0 +1,122 @@ +#ifndef TOOL_H +#define TOOL_H + +#include +#include + +#include + +using namespace Qt::Literals::StringLiterals; + +namespace ToolEnums +{ + Q_NAMESPACE + enum class Error + { + NoError = 0, + TimeoutError = 2, + UnknownError = 499, + }; + Q_ENUM_NS(Error) + + enum class ParamType { String, Number, Integer, Object, Array, Boolean, Null }; // json schema types + Q_ENUM_NS(ParamType) + + enum class ParseState { + None, + InStart, + Partial, + Complete, + }; + Q_ENUM_NS(ParseState) +} + +struct ToolParamInfo +{ + QString name; + ToolEnums::ParamType type; + QString description; + bool required; +}; 
+Q_DECLARE_METATYPE(ToolParamInfo) + +struct ToolParam +{ + QString name; + ToolEnums::ParamType type; + QVariant value; + bool operator==(const ToolParam& other) const + { + return name == other.name && type == other.type && value == other.value; + } +}; +Q_DECLARE_METATYPE(ToolParam) + +struct ToolCallInfo +{ + QString name; + QList params; + QString result; + ToolEnums::Error error = ToolEnums::Error::NoError; + QString errorString; + bool operator==(const ToolCallInfo& other) const + { + return name == other.name && result == other.result && params == other.params + && error == other.error && errorString == other.errorString; + } +}; +Q_DECLARE_METATYPE(ToolCallInfo) + +class Tool : public QObject +{ + Q_OBJECT + Q_PROPERTY(QString name READ name CONSTANT) + Q_PROPERTY(QString description READ description CONSTANT) + Q_PROPERTY(QString function READ function CONSTANT) + Q_PROPERTY(QList parameters READ parameters CONSTANT) + Q_PROPERTY(QString examplePrompt READ examplePrompt CONSTANT) + Q_PROPERTY(QString exampleCall READ exampleCall CONSTANT) + Q_PROPERTY(QString exampleReply READ exampleReply CONSTANT) + +public: + Tool() : QObject(nullptr) {} + virtual ~Tool() {} + + virtual QString run(const QList ¶ms, qint64 timeout = 2000) = 0; + + // Tools should set these if they encounter errors. For instance, a tool depending upon the network + // might set these error variables if the network is not available. + virtual ToolEnums::Error error() const { return ToolEnums::Error::NoError; } + virtual QString errorString() const { return QString(); } + + // [Required] Human readable name of the tool. + virtual QString name() const = 0; + + // [Required] Human readable description of what the tool does. Use this tool to: {{description}} + virtual QString description() const = 0; + + // [Required] Must be unique. Name of the function to invoke. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. 
+ virtual QString function() const = 0; + + // [Optional] List describing the tool's parameters. An empty list specifies no parameters. + virtual QList parameters() const { return {}; } + + // [Optional] The symbolic format of the toolcall. + virtual QString symbolicFormat() const { return QString(); } + + // [Optional] A human generated example of a prompt that could result in this tool being called. + virtual QString examplePrompt() const { return QString(); } + + // [Optional] An example of this tool call that pairs with the example query. It should be the + // complete string that the model must generate. + virtual QString exampleCall() const { return QString(); } + + // [Optional] An example of the reply the model might generate given the result of the tool call. + virtual QString exampleReply() const { return QString(); } + + bool operator==(const Tool &other) const { return function() == other.function(); } + + jinja2::Value jinjaValue() const; +}; + +#endif // TOOL_H diff --git a/gpt4all-chat/src/toolcallparser.cpp b/gpt4all-chat/src/toolcallparser.cpp new file mode 100644 index 000000000000..af2e707ba914 --- /dev/null +++ b/gpt4all-chat/src/toolcallparser.cpp @@ -0,0 +1,107 @@ +#include "toolcallparser.h" + +#include + +static const QString ToolCallStart = ToolCallConstants::CodeInterpreterTag; +static const QString ToolCallEnd = ToolCallConstants::CodeInterpreterEndTag; + +ToolCallParser::ToolCallParser() +{ + reset(); +} + +void ToolCallParser::reset() +{ + // Resets the search state, but not the buffer or global state + resetSearchState(); + + // These are global states maintained between update calls + m_buffer.clear(); + m_hasSplit = false; +} + +void ToolCallParser::resetSearchState() +{ + m_expected = ToolCallStart.at(0); + m_expectedIndex = 0; + m_state = ToolEnums::ParseState::None; + m_toolCall.clear(); + m_endTagBuffer.clear(); + m_startIndex = -1; +} + +// This method is called with an arbitrary string and a current state. 
This method should take the +// current state into account and then parse through the update character by character to arrive at +// the new state. +void ToolCallParser::update(const QString &update) +{ + Q_ASSERT(m_state != ToolEnums::ParseState::Complete); + if (m_state == ToolEnums::ParseState::Complete) { + qWarning() << "ERROR: ToolCallParser::update already found a complete toolcall!"; + return; + } + + m_buffer.append(update); + + for (size_t i = m_buffer.size() - update.size(); i < m_buffer.size(); ++i) { + const QChar c = m_buffer[i]; + const bool foundMatch = m_expected.isNull() || c == m_expected; + if (!foundMatch) { + resetSearchState(); + continue; + } + + switch (m_state) { + case ToolEnums::ParseState::None: + { + m_expectedIndex = 1; + m_expected = ToolCallStart.at(1); + m_state = ToolEnums::ParseState::InStart; + m_startIndex = i; + break; + } + case ToolEnums::ParseState::InStart: + { + if (m_expectedIndex == ToolCallStart.size() - 1) { + m_expectedIndex = 0; + m_expected = QChar(); + m_state = ToolEnums::ParseState::Partial; + } else { + ++m_expectedIndex; + m_expected = ToolCallStart.at(m_expectedIndex); + } + break; + } + case ToolEnums::ParseState::Partial: + { + m_toolCall.append(c); + m_endTagBuffer.append(c); + if (m_endTagBuffer.size() > ToolCallEnd.size()) + m_endTagBuffer.remove(0, 1); + if (m_endTagBuffer == ToolCallEnd) { + m_toolCall.chop(ToolCallEnd.size()); + m_state = ToolEnums::ParseState::Complete; + m_endTagBuffer.clear(); + } + } + case ToolEnums::ParseState::Complete: + { + // Already complete, do nothing further + break; + } + } + } +} + +QPair ToolCallParser::split() +{ + Q_ASSERT(m_state == ToolEnums::ParseState::Partial + || m_state == ToolEnums::ParseState::Complete); + + Q_ASSERT(m_startIndex >= 0); + m_hasSplit = true; + const QString beforeToolCall = m_buffer.left(m_startIndex); + m_buffer = m_buffer.mid(m_startIndex); + m_startIndex = 0; + return { beforeToolCall, m_buffer }; +} diff --git 
a/gpt4all-chat/src/toolcallparser.h b/gpt4all-chat/src/toolcallparser.h new file mode 100644 index 000000000000..df4d8d22212b --- /dev/null +++ b/gpt4all-chat/src/toolcallparser.h @@ -0,0 +1,45 @@ +#ifndef TOOLCALLPARSER_H +#define TOOLCALLPARSER_H + +#include + +#include "tool.h" + +namespace ToolCallConstants +{ + const QString CodeInterpreterFunction = R"(javascript_interpret)"; + const QString CodeInterpreterTag = R"(<)" + CodeInterpreterFunction + R"(>)"; + const QString CodeInterpreterEndTag = R"()"; + const QString CodeInterpreterPrefix = CodeInterpreterTag + "\n```javascript\n"; + const QString CodeInterpreterSuffix = "```\n" + CodeInterpreterEndTag; +} + +class ToolCallParser +{ +public: + ToolCallParser(); + void reset(); + void update(const QString &update); + QString buffer() const { return m_buffer; } + QString toolCall() const { return m_toolCall; } + int startIndex() const { return m_startIndex; } + ToolEnums::ParseState state() const { return m_state; } + + // Splits + QPair split(); + bool hasSplit() const { return m_hasSplit; } + +private: + void resetSearchState(); + + QChar m_expected; + int m_expectedIndex; + ToolEnums::ParseState m_state; + QString m_buffer; + QString m_toolCall; + QString m_endTagBuffer; + int m_startIndex; + bool m_hasSplit; +}; + +#endif // TOOLCALLPARSER_H diff --git a/gpt4all-chat/src/toolmodel.cpp b/gpt4all-chat/src/toolmodel.cpp new file mode 100644 index 000000000000..3d9b2eab4a66 --- /dev/null +++ b/gpt4all-chat/src/toolmodel.cpp @@ -0,0 +1,31 @@ +#include "toolmodel.h" + +#include +#include +#include + +#include "codeinterpreter.h" + +class MyToolModel: public ToolModel { }; +Q_GLOBAL_STATIC(MyToolModel, toolModelInstance) +ToolModel *ToolModel::globalInstance() +{ + return toolModelInstance(); +} + +ToolModel::ToolModel() + : QAbstractListModel(nullptr) { + + QCoreApplication::instance()->installEventFilter(this); + + Tool* codeInterpreter = new CodeInterpreter; + m_tools.append(codeInterpreter); + 
m_toolMap.insert(codeInterpreter->function(), codeInterpreter); +} + +bool ToolModel::eventFilter(QObject *obj, QEvent *ev) +{ + if (obj == QCoreApplication::instance() && ev->type() == QEvent::LanguageChange) + emit dataChanged(index(0, 0), index(m_tools.size() - 1, 0)); + return false; +} diff --git a/gpt4all-chat/src/toolmodel.h b/gpt4all-chat/src/toolmodel.h new file mode 100644 index 000000000000..0a584c907929 --- /dev/null +++ b/gpt4all-chat/src/toolmodel.h @@ -0,0 +1,104 @@ +#ifndef TOOLMODEL_H +#define TOOLMODEL_H + +#include "tool.h" + +#include + +class ToolModel : public QAbstractListModel +{ + Q_OBJECT + Q_PROPERTY(int count READ count NOTIFY countChanged) + +public: + static ToolModel *globalInstance(); + + enum Roles { + NameRole = Qt::UserRole + 1, + DescriptionRole, + FunctionRole, + ParametersRole, + SymbolicFormatRole, + ExamplePromptRole, + ExampleCallRole, + ExampleReplyRole, + }; + + int rowCount(const QModelIndex &parent = QModelIndex()) const override + { + Q_UNUSED(parent) + return m_tools.size(); + } + + QVariant data(const QModelIndex &index, int role = Qt::DisplayRole) const override + { + if (!index.isValid() || index.row() < 0 || index.row() >= m_tools.size()) + return QVariant(); + + const Tool *item = m_tools.at(index.row()); + switch (role) { + case NameRole: + return item->name(); + case DescriptionRole: + return item->description(); + case FunctionRole: + return item->function(); + case ParametersRole: + return QVariant::fromValue(item->parameters()); + case SymbolicFormatRole: + return item->symbolicFormat(); + case ExamplePromptRole: + return item->examplePrompt(); + case ExampleCallRole: + return item->exampleCall(); + case ExampleReplyRole: + return item->exampleReply(); + } + + return QVariant(); + } + + QHash roleNames() const override + { + QHash roles; + roles[NameRole] = "name"; + roles[DescriptionRole] = "description"; + roles[FunctionRole] = "function"; + roles[ParametersRole] = "parameters"; + roles[SymbolicFormatRole] 
= "symbolicFormat"; + roles[ExamplePromptRole] = "examplePrompt"; + roles[ExampleCallRole] = "exampleCall"; + roles[ExampleReplyRole] = "exampleReply"; + return roles; + } + + Q_INVOKABLE Tool* get(int index) const + { + if (index < 0 || index >= m_tools.size()) return nullptr; + return m_tools.at(index); + } + + Q_INVOKABLE Tool *get(const QString &id) const + { + if (!m_toolMap.contains(id)) return nullptr; + return m_toolMap.value(id); + } + + int count() const { return m_tools.size(); } + +Q_SIGNALS: + void countChanged(); + void valueChanged(int index, const QString &value); + +protected: + bool eventFilter(QObject *obj, QEvent *ev) override; + +private: + explicit ToolModel(); + ~ToolModel() {} + friend class MyToolModel; + QList m_tools; + QHash m_toolMap; +}; + +#endif // TOOLMODEL_H