
Commit

enable per org model family selection

janaka committed Sep 25, 2023
1 parent 4b9d300 · commit 4c1c5d6
Showing 4 changed files with 24 additions and 21 deletions.
6 changes: 3 additions & 3 deletions source/docq/manage_spaces.py
@@ -52,10 +52,10 @@ def _init() -> None:
        connection.commit()


-def _create_index(documents: List[Document]) -> GPTVectorStoreIndex:
+def _create_index(documents: List[Document], org_id: int) -> GPTVectorStoreIndex:
    # Use default storage and service context to initialise index purely for persisting
    return GPTVectorStoreIndex.from_documents(
-        documents, storage_context=_get_default_storage_context(), service_context=_get_service_context()
+        documents, storage_context=_get_default_storage_context(), service_context=_get_service_context(org_id)
    )


@@ -71,7 +71,7 @@ def reindex(space: SpaceKey) -> None:
            log.debug("get datasource instance")
            documents = SpaceDataSources[ds_type].value.load(space, ds_configs)
            log.debug("docs to index, %s", len(documents))
-            index = _create_index(documents)
+            index = _create_index(documents, space.org_id)
            _persist_index(index, space)
        except Exception as e:
            log.exception("Error indexing space %s: %s", space, e)
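The change here is the core of the commit: index construction resolves model settings per organisation instead of reading one global setting. A minimal, self-contained sketch of the pattern (ServiceConfig, OrgModelRegistry, and build_index are illustrative stand-ins, not docq APIs):

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ServiceConfig:
    """Per-org model choices, analogous to a docq service context."""

    chat_model: str
    embed_model: str


@dataclass
class OrgModelRegistry:
    """Maps org ids to model settings, with a global default fallback."""

    default: ServiceConfig
    per_org: Dict[int, ServiceConfig] = field(default_factory=dict)

    def resolve(self, org_id: int) -> ServiceConfig:
        # Orgs with no saved setting fall back to the default vendor,
        # mirroring the fallback in get_selected_model_settings (next file).
        return self.per_org.get(org_id, self.default)


def build_index(documents: List[str], org_id: int, registry: OrgModelRegistry) -> dict:
    # Index construction now takes the owning org's id and resolves that
    # org's models, rather than using one process-wide configuration.
    config = registry.resolve(org_id)
    return {"org_id": org_id, "embed_model": config.embed_model, "doc_count": len(documents)}


registry = OrgModelRegistry(default=ServiceConfig("gpt-35-turbo", "text-embedding-ada-002"))
registry.per_org[42] = ServiceConfig("gpt-4", "text-embedding-ada-002")
print(build_index(["doc one", "doc two"], org_id=42, registry=registry))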
4 changes: 2 additions & 2 deletions source/docq/model_selection/main.py
@@ -90,14 +90,14 @@ class ModelUsageSettings:

def get_selected_model_settings(org_id: int) -> dict:
    """Get the settings for the saved model."""
-    saved_setting = get_organisation_settings(org_id, SystemSettingsKey.MODEL_VENDOR.value)
+    saved_setting = get_organisation_settings(org_id, SystemSettingsKey.MODEL_VENDOR)

    return LLM_MODELS[saved_setting] if saved_setting else LLM_MODELS[ModelVendors.AZURE_OPENAI]


def set_selected_model(org_id: int, model_vendor: ModelVendors) -> None:
    """Save the selected model."""
-    update_organisation_settings(org_id, SystemSettingsKey.MODEL_VENDOR.name, model_vendor.value)
+    update_organisation_settings(SystemSettingsKey.MODEL_VENDOR.name, model_vendor.value, org_id)

    log.debug("Selected Model: %s", model_vendor)
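Two contract details change in this file: get_organisation_settings now takes the SystemSettingsKey enum member itself rather than its .value, and update_organisation_settings moves org_id to the last parameter. A hedged, in-memory sketch of that contract (the dict-backed store stands in for docq's organisation settings table; the real signatures may differ in detail):

from enum import Enum
from typing import Dict, Optional, Tuple


class SystemSettingsKey(Enum):
    MODEL_VENDOR = "MODEL_VENDOR"


# Stand-in for the settings table: (org_id, key name) -> stored value.
_settings: Dict[Tuple[int, str], str] = {}


def update_organisation_settings(key: str, value: str, org_id: int) -> None:
    # org_id now trails the key/value pair, as in the call in the diff above.
    _settings[(org_id, key)] = value


def get_organisation_settings(org_id: int, key: SystemSettingsKey) -> Optional[str]:
    # Accepts the enum member; the lookup normalises to its name internally.
    return _settings.get((org_id, key.name))


update_organisation_settings(SystemSettingsKey.MODEL_VENDOR.name, "AZURE_OPENAI", org_id=1)
print(get_organisation_settings(1, SystemSettingsKey.MODEL_VENDOR))  # AZURE_OPENAI
print(get_organisation_settings(2, SystemSettingsKey.MODEL_VENDOR))  # None -> caller falls back to default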
2 changes: 1 addition & 1 deletion source/docq/run_queries.py
@@ -161,7 +161,7 @@ def query(
    history = _retrieve_last_n_history(feature, thread_id)
    log.debug("is_chat: %s", is_chat)
    try:
-        response = run_chat(input_, history) if is_chat else run_ask(input_, history, space, spaces)
+        response = run_chat(input_, history, space.org_id) if is_chat else run_ask(input_, history, space, spaces)
        log.debug("Response: %s", response)

    except Exception as e:
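For orientation, a pared-down sketch of the dispatch inside query (this SpaceKey and both handlers are simplified stand-ins; the real run_ask also accepts additional shared spaces):

from dataclasses import dataclass


@dataclass
class SpaceKey:
    """Simplified stand-in; the real docq SpaceKey carries more fields."""

    org_id: int
    space_id: int


def run_chat(input_: str, history: str, org_id: int) -> str:
    # Plain chat resolves the org's selected chat model via org_id.
    return f"[chat org={org_id}] {input_}"


def run_ask(input_: str, history: str, space: SpaceKey) -> str:
    # Ask routes through the space's index; org scope rides along on the key.
    return f"[ask org={space.org_id}] {input_}"


def query(input_: str, history: str, space: SpaceKey, is_chat: bool) -> str:
    # The org scope travels with the SpaceKey, so chat can resolve the same
    # per-org models that indexing used.
    return run_chat(input_, history, space.org_id) if is_chat else run_ask(input_, history, space)


print(query("hello", "", SpaceKey(org_id=7, space_id=1), is_chat=True))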
33 changes: 18 additions & 15 deletions source/docq/support/llm.py
@@ -70,8 +70,8 @@
# return OpenAI(temperature=0, model_name="text-davinci-003")


-def _get_chat_model() -> ChatOpenAI:
-    selected_model = get_selected_model_settings()
+def _get_chat_model(org_id: int) -> ChatOpenAI:
+    selected_model = get_selected_model_settings(org_id)

    if selected_model and selected_model["CHAT"]:
        if selected_model["CHAT"].model_vendor == ModelVendors.AZURE_OPENAI:
@@ -95,8 +95,8 @@ def _get_chat_model() -> ChatOpenAI:
    return model


-def _get_embed_model() -> LangchainEmbedding:
-    selected_model = get_selected_model_settings()
+def _get_embed_model(org_id: int) -> LangchainEmbedding:
+    selected_model = get_selected_model_settings(org_id)
    if selected_model and selected_model["EMBED"]:
        if selected_model["EMBED"].model_vendor == ModelVendors.AZURE_OPENAI:
            embedding_llm = LangchainEmbedding(
@@ -124,8 +124,8 @@ def _get_embed_model() -> LangchainEmbedding:
    return embedding_llm


-def _get_llm_predictor() -> LLMPredictor:
-    return LLMPredictor(llm=_get_chat_model())
+def _get_llm_predictor(org_id: int) -> LLMPredictor:
+    return LLMPredictor(llm=_get_chat_model(org_id))


def _get_default_storage_context() -> StorageContext:
@@ -136,21 +136,21 @@ def _get_storage_context(space: SpaceKey) -> StorageContext:
    return StorageContext.from_defaults(persist_dir=get_index_dir(space))


-def _get_service_context() -> ServiceContext:
+def _get_service_context(org_id: int) -> ServiceContext:
    log.debug(
        "EXPERIMENTS['INCLUDE_EXTRACTED_METADATA']['enabled']: %s", EXPERIMENTS["INCLUDE_EXTRACTED_METADATA"]["enabled"]
    )

    if EXPERIMENTS["INCLUDE_EXTRACTED_METADATA"]["enabled"]:
        return ServiceContext.from_defaults(
-            llm_predictor=_get_llm_predictor(),
+            llm_predictor=_get_llm_predictor(org_id),
            node_parser=_get_node_parser(),
-            embed_model=_get_embed_model(),
+            embed_model=_get_embed_model(org_id),
        )
    else:
        return ServiceContext.from_defaults(
-            llm_predictor=_get_llm_predictor(),
-            embed_model=_get_embed_model(),
+            llm_predictor=_get_llm_predictor(org_id),
+            embed_model=_get_embed_model(org_id),
        )


@@ -176,10 +176,13 @@ def _get_node_parser() -> SimpleNodeParser:

def _load_index_from_storage(space: SpaceKey) -> GPTVectorStoreIndex:
    # set service context explicitly for multi model compatibility
-    return load_index_from_storage(storage_context=_get_storage_context(space), service_context=_get_service_context())
+    return load_index_from_storage(
+        storage_context=_get_storage_context(space), service_context=_get_service_context(space.org_id)
+    )


-def run_chat(input_: str, history: str) -> BaseMessage:
+def run_chat(input_: str, history: str, org_id: int) -> BaseMessage:
    """Chat directly with a LLM with history."""
    # prompt = ChatPromptTemplate.from_messages(
    #     [
@@ -188,7 +191,7 @@ def run_chat(input_: str, history: str) -> BaseMessage:
    #     ]
    # )
    # output = _get_chat_model()(prompt.format_prompt(history=history, input=input_).to_messages())
-    engine = SimpleChatEngine.from_defaults(service_context=_get_service_context())
+    engine = SimpleChatEngine.from_defaults(service_context=_get_service_context(org_id))
    output = engine.chat(input_)

    log.debug("(Chat) Q: %s, A: %s", input_, output)
@@ -256,7 +259,7 @@ def run_ask(input_: str, history: str, space: SpaceKey = None, spaces: list[Spac
    # No additional spaces i.e. likely to be against a user's documents in their personal space.
    if space is None:
        log.debug("runs_ask(): space is None. executing run_chat(), not ASK.")
-        output = run_chat(input_, history)
+        output = run_chat(input_, history, space.org_id)
    else:
        index = _load_index_from_storage(space)
        engine = index.as_chat_engine(
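Taken together, the llm.py changes thread org_id from the public entry points (run_chat, run_ask, index loading) down to model construction. A rough sketch of the vendor branch in _get_chat_model (the classes, model names, and the even/odd lookup below are purely illustrative; the real code instantiates AzureChatOpenAI or ChatOpenAI from saved per-org settings):

from dataclasses import dataclass
from enum import Enum
from typing import Dict


class ModelVendors(Enum):
    AZURE_OPENAI = "AZURE_OPENAI"
    OPENAI = "OPENAI"


@dataclass
class ModelUsageSettings:
    model_vendor: ModelVendors
    model_name: str


def get_selected_model_settings(org_id: int) -> Dict[str, ModelUsageSettings]:
    # Toy lookup: pretend even-numbered orgs chose Azure. The real function
    # reads the org's saved MODEL_VENDOR setting and falls back to a default.
    vendor = ModelVendors.AZURE_OPENAI if org_id % 2 == 0 else ModelVendors.OPENAI
    return {
        "CHAT": ModelUsageSettings(vendor, "gpt-35-turbo"),
        "EMBED": ModelUsageSettings(vendor, "text-embedding-ada-002"),
    }


def _get_chat_model(org_id: int) -> str:
    selected_model = get_selected_model_settings(org_id)
    chat = selected_model["CHAT"]
    if chat.model_vendor == ModelVendors.AZURE_OPENAI:
        return f"AzureChatOpenAI(deployment_name={chat.model_name!r})"
    return f"ChatOpenAI(model_name={chat.model_name!r})"


print(_get_chat_model(2))  # Azure path
print(_get_chat_model(3))  # OpenAI path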
