From 77e891d76f6e32bbb8e9dbc29db516dfa506b4f4 Mon Sep 17 00:00:00 2001 From: Steven Sun Date: Sat, 25 May 2024 12:04:10 -0500 Subject: [PATCH] feat: add runOption, use with new groq chat groq is openai compatible but requires auth token, this lets us add an auth token after parsing chat prompt so we don't need to stick tokens in plaintext --- lua/model/core/chat.lua | 5 +++++ lua/model/prompts/chats.lua | 13 +++++++++++++ lua/model/util/init.lua | 2 +- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/lua/model/core/chat.lua b/lua/model/core/chat.lua index f105686..7b798ca 100644 --- a/lua/model/core/chat.lua +++ b/lua/model/core/chat.lua @@ -8,6 +8,7 @@ local M = {} ---@field provider Provider The API provider for this prompt ---@field create fun(input: string, context: Context): string | ChatContents Converts input and context to the first message text or ChatContents ---@field run fun(messages: ChatMessage[], config: ChatConfig): table | fun(resolve: fun(params: table): nil ) ) Converts chat messages and config into completion request params +---@field runOptions? fun(): table Builds additional options to merge into chat prompt options. E.g. for auth tokens that shouldn't be written to the chat config header. ---@field system? string System instruction ---@field params? table Static request parameters ---@field options? table Provider options @@ -336,6 +337,10 @@ function M.run_chat(opts) local options = parsed.contents.config.options or {} local params = parsed.contents.config.params or {} + if type(chat_prompt.runOptions) == 'function' then + options = vim.tbl_deep_extend('force', options, chat_prompt.runOptions()) + end + if type(run_params) == 'function' then run_params(function(async_params) local merged_params = vim.tbl_deep_extend('force', params, async_params) diff --git a/lua/model/prompts/chats.lua b/lua/model/prompts/chats.lua index 7ded14c..964c6b6 100644 --- a/lua/model/prompts/chats.lua +++ b/lua/model/prompts/chats.lua @@ -9,6 +9,8 @@ local anthropic = require('model.providers.anthropic') local zephyr_fmt = require('model.format.zephyr') local starling_fmt = require('model.format.starling') +local util = require('model.util') + local function input_if_selection(input, context) return context.selection and input or '' end @@ -213,6 +215,17 @@ local chats = { }) end, }, + groq = vim.tbl_deep_extend('force', openai_chat, { + params = { + model = 'llama3-70b-8192', + }, + runOptions = function() + return { + url = 'https://api.groq.com/openai/v1/', + authorization = 'Bearer ' .. util.env('GROQ_API_KEY'), + } + end, + }), } return chats diff --git a/lua/model/util/init.lua b/lua/model/util/init.lua index 8e0aa41..f03ade2 100644 --- a/lua/model/util/init.lua +++ b/lua/model/util/init.lua @@ -61,7 +61,7 @@ local get_secret_once = M.memo(function(name) end) function M.env(name) - if M.secrets[name] then + if type(M.secrets[name]) == 'function' then return get_secret_once(name) else local value = vim.env[name]