diff --git a/lua/model/core/chat.lua b/lua/model/core/chat.lua index f105686..7b798ca 100644 --- a/lua/model/core/chat.lua +++ b/lua/model/core/chat.lua @@ -8,6 +8,7 @@ local M = {} ---@field provider Provider The API provider for this prompt ---@field create fun(input: string, context: Context): string | ChatContents Converts input and context to the first message text or ChatContents ---@field run fun(messages: ChatMessage[], config: ChatConfig): table | fun(resolve: fun(params: table): nil ) ) Converts chat messages and config into completion request params +---@field runOptions? fun(): table Builds additional options to merge into chat prompt options. E.g. for auth tokens that shouldn't be written to the chat config header. ---@field system? string System instruction ---@field params? table Static request parameters ---@field options? table Provider options @@ -336,6 +337,10 @@ function M.run_chat(opts) local options = parsed.contents.config.options or {} local params = parsed.contents.config.params or {} + if type(chat_prompt.runOptions) == 'function' then + options = vim.tbl_deep_extend('force', options, chat_prompt.runOptions()) + end + if type(run_params) == 'function' then run_params(function(async_params) local merged_params = vim.tbl_deep_extend('force', params, async_params) diff --git a/lua/model/prompts/chats.lua b/lua/model/prompts/chats.lua index 7ded14c..964c6b6 100644 --- a/lua/model/prompts/chats.lua +++ b/lua/model/prompts/chats.lua @@ -9,6 +9,8 @@ local anthropic = require('model.providers.anthropic') local zephyr_fmt = require('model.format.zephyr') local starling_fmt = require('model.format.starling') +local util = require('model.util') + local function input_if_selection(input, context) return context.selection and input or '' end @@ -213,6 +215,17 @@ local chats = { }) end, }, + groq = vim.tbl_deep_extend('force', openai_chat, { + params = { + model = 'llama3-70b-8192', + }, + runOptions = function() + return { + url = 'https://api.groq.com/openai/v1/', + authorization = 'Bearer ' .. util.env('GROQ_API_KEY'), + } + end, + }), } return chats diff --git a/lua/model/util/init.lua b/lua/model/util/init.lua index 8e0aa41..f03ade2 100644 --- a/lua/model/util/init.lua +++ b/lua/model/util/init.lua @@ -61,7 +61,7 @@ local get_secret_once = M.memo(function(name) end) function M.env(name) - if M.secrets[name] then + if type(M.secrets[name]) == 'function' then return get_secret_once(name) else local value = vim.env[name]