Skip to content

Commit

Permalink
Add streaming for model response
Browse files Browse the repository at this point in the history
  • Loading branch information
vietanhdev committed Sep 29, 2024
1 parent 74ddec6 commit 2f993c1
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 112 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ This assistant can run offline on your local machine, and it respects your priva
- [x] Custom models: Add support for custom models.
- [x] 📚 Support 5 other text models.
- [x] 🖼️ Support 5 other multimodal models.
- [x] ⚡ Streaming support for response.
- [ ] 🎙️ Add offline STT support: WhisperCPP. [Experimental Code](llama_assistant/speech_recognition_whisper_experimental.py).
- [ ] 🧠 Knowledge database: Langchain or LlamaIndex?.
- [ ] 🔌 Plugin system for extensibility.
Expand Down Expand Up @@ -158,7 +159,6 @@ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file

## Acknowledgements

- [Radio icons created by Freepik - Flaticon](https://www.flaticon.com/free-icons/radio)
- [Llama 3.2](https://github.com/facebookresearch/llama) by Meta AI Research

## Star History
Expand Down
9 changes: 9 additions & 0 deletions llama_assistant/icons.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
</svg>
"""

microphone_icon_svg = """
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none" stroke="#fff" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M12 1a3 3 0 0 0-3 3v8a3 3 0 0 0 6 0V4a3 3 0 0 0-3-3z"></path>
<path d="M19 10v2a7 7 0 0 1-14 0v-2"></path>
<line x1="12" y1="19" x2="12" y2="23"></line>
<line x1="8" y1="23" x2="16" y2="23"></line>
</svg>
"""


def create_icon_from_svg(svg_string):
svg_bytes = QByteArray(svg_string.encode("utf-8"))
Expand Down
107 changes: 71 additions & 36 deletions llama_assistant/llama_assistant.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import json
import markdown
from pathlib import Path
from importlib import resources

from PyQt6.QtWidgets import (
QApplication,
QMainWindow,
Expand All @@ -21,6 +19,8 @@
QPoint,
QSize,
QTimer,
QThread,
pyqtSignal,
)
from PyQt6.QtGui import (
QIcon,
Expand All @@ -35,12 +35,11 @@
QDropEvent,
QFont,
QBitmap,
QTextCursor,
)
from llama_assistant.wake_word_detector import WakeWordDetector

from llama_assistant.custom_plaintext_editor import CustomPlainTextEdit
from llama_assistant.global_hotkey import GlobalHotkey
from llama_assistant.loading_animation import LoadingAnimation
from llama_assistant.setting_dialog import SettingsDialog
from llama_assistant.speech_recognition import SpeechRecognitionThread
from llama_assistant.utils import image_to_base64_data_uri
Expand All @@ -49,9 +48,34 @@
create_icon_from_svg,
copy_icon_svg,
clear_icon_svg,
microphone_icon_svg,
)


class ProcessingThread(QThread):
update_signal = pyqtSignal(str)
finished_signal = pyqtSignal()

def __init__(self, model, prompt, image=None):
super().__init__()
self.model = model
self.prompt = prompt
self.image = image

def run(self):
output = model_handler.chat_completion(
self.model, self.prompt, image=self.image, stream=True
)
for chunk in output:
delta = chunk["choices"][0]["delta"]
if "role" in delta:
print(delta["role"], end=": ")
elif "content" in delta:
print(delta["content"], end="")
self.update_signal.emit(delta["content"])
self.finished_signal.emit()


class LlamaAssistant(QMainWindow):
def __init__(self):
super().__init__()
Expand All @@ -67,6 +91,8 @@ def __init__(self):
self.image_label = None
self.current_text_model = self.settings.get("text_model")
self.current_multimodal_model = self.settings.get("multimodal_model")
self.processing_thread = None
self.response_start_position = 0

def init_wake_word_detector(self):
if self.wake_word_detector is not None:
Expand Down Expand Up @@ -180,23 +206,19 @@ def init_ui(self):
)
top_layout.addWidget(self.input_field)

# Load the mic icon from resources
with resources.path("llama_assistant.resources", "mic_icon.png") as path:
mic_icon = QIcon(str(path))

self.mic_button = QPushButton(self)
self.mic_button.setIcon(mic_icon)
self.mic_button.setIcon(create_icon_from_svg(microphone_icon_svg))
self.mic_button.setIconSize(QSize(24, 24))
self.mic_button.setFixedSize(40, 40)
self.mic_button.setStyleSheet(
"""
QPushButton {
background-color: rgba(255, 255, 255, 0.1);
background-color: rgba(255, 255, 255, 0.3);
border: none;
border-radius: 20px;
}
QPushButton:hover {
background-color: rgba(255, 255, 255, 0.2);
background-color: rgba(255, 255, 255, 0.5);
}
"""
)
Expand Down Expand Up @@ -290,6 +312,7 @@ def init_ui(self):
QScrollArea {
border: none;
background-color: transparent;
border-radius: 10px;
}
QScrollBar:vertical {
border: none;
Expand All @@ -315,10 +338,6 @@ def init_ui(self):
self.scroll_area.hide()
main_layout.addWidget(self.scroll_area)

self.loading_animation = LoadingAnimation(self)
self.loading_animation.setFixedSize(50, 50)
self.loading_animation.hide()

self.oldPos = self.pos()

self.center_on_screen()
Expand Down Expand Up @@ -354,7 +373,7 @@ def update_styles(self):
self.chat_box.setStyleSheet(
f"""QTextBrowser {{ {base_style}
background-color: rgba{QColor(self.settings["color"]).lighter(120).getRgb()[:3] + (opacity,)};
border-radius: 5px;
border-radius: 10px;
}}"""
)
button_style = f"""
Expand Down Expand Up @@ -441,8 +460,6 @@ def toggle_visibility(self):
def on_submit(self):
message = self.input_field.toPlainText()
self.input_field.clear()
self.loading_animation.move(self.width() // 2 - 25, self.height() // 2 - 25)
self.loading_animation.start_animation()

if self.dropped_image:
self.process_image_with_prompt(self.dropped_image, message)
Expand All @@ -452,6 +469,7 @@ def on_submit(self):
QTimer.singleShot(100, lambda: self.process_text(message))

def process_text(self, message, task="chat"):
self.show_chat_box()
if task == "chat":
prompt = message + " \n" + "Generate a short and simple response."
elif task == "summarize":
Expand All @@ -465,32 +483,49 @@ def process_text(self, message, task="chat"):
elif task == "write email":
prompt = f"Write an email about: {message}"

response = model_handler.chat_completion(self.current_text_model, prompt)
self.last_response = response

self.chat_box.append(f"<b>You:</b> {message}")
self.chat_box.append(f"<b>AI ({task}):</b> {markdown.markdown(response)}")
self.loading_animation.stop_animation()
self.show_chat_box()
self.chat_box.append(f"<b>AI ({task}):</b> ")

self.processing_thread = ProcessingThread(self.current_text_model, prompt)
self.processing_thread.update_signal.connect(self.update_chat_box)
self.processing_thread.finished_signal.connect(self.on_processing_finished)
self.processing_thread.start()

def process_image_with_prompt(self, image_path, prompt):
response = model_handler.chat_completion(
self.current_multimodal_model, prompt, image=image_to_base64_data_uri(image_path)
)
self.show_chat_box()
self.chat_box.append(f"<b>You:</b> [Uploaded an image: {image_path}]")
self.chat_box.append(f"<b>You:</b> {prompt}")
self.chat_box.append(
f"<b>AI:</b> {markdown.markdown(response)}" if response else "No response"
self.chat_box.append("<b>AI:</b> ")

image = image_to_base64_data_uri(image_path)
self.processing_thread = ProcessingThread(
self.current_multimodal_model, prompt, image=image
)
self.loading_animation.stop_animation()
self.show_chat_box()
self.processing_thread.update_signal.connect(self.update_chat_box)
self.processing_thread.finished_signal.connect(self.on_processing_finished)
self.processing_thread.start()

def update_chat_box(self, text):
self.chat_box.textCursor().insertText(text)
self.chat_box.verticalScrollBar().setValue(self.chat_box.verticalScrollBar().maximum())
self.last_response += text

def on_processing_finished(self):
# Clear the last_response for the next interaction
self.last_response = ""

# Reset the response start position
self.response_start_position = 0

# New line for the next interaction
self.chat_box.append("")

def show_chat_box(self):
if self.scroll_area.isHidden():
self.scroll_area.show()
self.copy_button.show()
self.clear_button.show()
self.setFixedHeight(600) # Increase this value if needed
self.setFixedHeight(500) # Increase this value if needed
self.chat_box.verticalScrollBar().setValue(self.chat_box.verticalScrollBar().maximum())

def copy_result(self):
Expand Down Expand Up @@ -617,12 +652,12 @@ def start_voice_input(self):
self.mic_button.setStyleSheet(
"""
QPushButton {
background-color: rgba(255, 0, 0, 0.3);
background-color: rgba(255, 0, 0, 0.5);
border: none;
border-radius: 20px;
}
QPushButton:hover {
background-color: rgba(255, 0, 0, 0.5);
background-color: rgba(255, 0, 0, 0.6);
}
"""
)
Expand All @@ -639,12 +674,12 @@ def stop_voice_input(self):
self.mic_button.setStyleSheet(
"""
QPushButton {
background-color: rgba(255, 255, 255, 0.1);
background-color: rgba(255, 255, 255, 0.5);
border: none;
border-radius: 20px;
}
QPushButton:hover {
background-color: rgba(255, 255, 255, 0.2);
background-color: rgba(255, 255, 255, 0.6);
}
"""
)
Expand Down
68 changes: 0 additions & 68 deletions llama_assistant/loading_animation.py

This file was deleted.

11 changes: 7 additions & 4 deletions llama_assistant/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def chat_completion(
model_id: str,
message: str,
image: Optional[str] = None,
n_ctx: int = 2048,
stream: bool = False,
) -> str:
model_data = self.load_model(model_id)
if not model_data:
Expand All @@ -168,12 +168,15 @@ def chat_completion(
{"type": "image_url", "image_url": {"url": image}},
],
}
]
],
stream=stream,
)
else:
response = model.create_chat_completion(messages=[{"role": "user", "content": message}])
response = model.create_chat_completion(
messages=[{"role": "user", "content": message}], stream=stream
)

return response["choices"][0]["message"]["content"]
return response

def _schedule_unload(self):
if self.unload_timer:
Expand Down
Binary file removed llama_assistant/resources/mic_icon.png
Binary file not shown.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "llama-assistant"
version = "0.1.19"
version = "0.1.20"
authors = [
{name = "Viet-Anh Nguyen", email = "[email protected]"},
]
Expand Down Expand Up @@ -52,7 +52,7 @@ include = ["llama_assistant*"]
exclude = ["tests*"]

[tool.setuptools.package-data]
"llama_assistant.resources" = ["*.png", "*.onnx"]
"llama_assistant.resources" = ["*.onnx"]


[tool.black]
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,4 @@ python_requires = >=3.8
where = .

[options.package_data]
llama_assistant.resources = *.png, *.onnx
llama_assistant.resources = *.onnx

0 comments on commit 2f993c1

Please sign in to comment.