Skip to content

Commit

Permalink
feat: add test connection for edit spec (#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
phv2312 authored Sep 8, 2024
1 parent fa881d4 commit dbb6bb2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
15 changes: 11 additions & 4 deletions libs/ktem/ktem/embeddings/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import yaml
from ktem.app import BasePage
from ktem.utils.file import YAMLNoDateSafeLoader
from theflow.utils.modules import deserialize

from .manager import embedding_models_manager

Expand Down Expand Up @@ -237,7 +238,7 @@ def on_register_events(self):

self.btn_test_connection.click(
self.check_connection,
inputs=[self.selected_emb_name],
inputs=[self.selected_emb_name, self.edit_spec],
outputs=[self.connection_logs],
)

Expand Down Expand Up @@ -330,14 +331,20 @@ def on_btn_delete_click(self):

return btn_delete, btn_delete_yes, btn_delete_no

def check_connection(self, selected_emb_name):
def check_connection(self, selected_emb_name, selected_spec):
log_content: str = ""

try:
log_content += f"- Testing model: {selected_emb_name}<br>"
yield log_content

emb = embedding_models_manager.get(selected_emb_name)
# Parse content & init model
info = deepcopy(embedding_models_manager.info()[selected_emb_name])

# Parse content & create dummy embedding
spec = yaml.load(selected_spec, Loader=YAMLNoDateSafeLoader)
info["spec"].update(spec)

emb = deserialize(info["spec"], safe=False)

if emb is None:
raise Exception(f"Can not found model: {selected_emb_name}")
Expand Down
18 changes: 14 additions & 4 deletions libs/ktem/ktem/llms/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import yaml
from ktem.app import BasePage
from ktem.utils.file import YAMLNoDateSafeLoader
from theflow.utils.modules import deserialize

from .manager import llms

Expand Down Expand Up @@ -59,7 +60,9 @@ def on_building_ui(self):
self.connection_logs = gr.HTML("Logs")

with gr.Column(scale=1):
self.btn_test_connection = gr.Button("Test")
self.btn_test_connection = gr.Button(
"Test",
)

with gr.Row(visible=False) as self._selected_panel_btn:
with gr.Column():
Expand Down Expand Up @@ -233,7 +236,7 @@ def on_register_events(self):

self.btn_test_connection.click(
self.check_connection,
inputs=[self.selected_llm_name],
inputs=[self.selected_llm_name, self.edit_spec],
outputs=[self.connection_logs],
)

Expand Down Expand Up @@ -326,14 +329,21 @@ def on_btn_delete_click(self):

return btn_delete, btn_delete_yes, btn_delete_no

def check_connection(self, selected_llm_name: str):
def check_connection(self, selected_llm_name: str, selected_spec):
log_content: str = ""

try:
log_content += f"- Testing model: {selected_llm_name}<br>"
yield log_content

llm = llms.get(key=selected_llm_name, default=None)
# Parse content & init model
info = deepcopy(llms.info()[selected_llm_name])

# Parse content & create dummy embedding
spec = yaml.load(selected_spec, Loader=YAMLNoDateSafeLoader)
info["spec"].update(spec)

llm = deserialize(info["spec"], safe=False)

if llm is None:
raise Exception(f"Can not found model: {selected_llm_name}")
Expand Down

0 comments on commit dbb6bb2

Please sign in to comment.