From 2eb5c45f4ddae5f686d766098470f59fdf8bccf2 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Mon, 22 Apr 2024 13:16:04 +0100 Subject: [PATCH 01/24] feat(ai-proxy): complete refactor of streaming subsystem --- kong/llm/drivers/anthropic.lua | 141 ++++++++- kong/llm/drivers/cohere.lua | 160 ++++------ kong/llm/drivers/openai.lua | 12 +- kong/llm/drivers/shared.lua | 154 ++++++++- kong/llm/init.lua | 251 --------------- kong/plugins/ai-proxy/handler.lua | 295 ++++++++++++++---- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 13 +- .../09-streaming_integration_spec.lua | 209 +++++++++++-- .../llm-v1-chat/requests/good-stream.json | 13 + .../anthropic/llm-v1-chat.json | 3 +- .../anthropic/llm-v1-completions.json | 3 +- .../expected-requests/cohere/llm-v1-chat.json | 6 +- .../cohere/llm-v1-completions.json | 3 +- .../openai/llm-v1-completions.txt | 2 +- 14 files changed, 798 insertions(+), 467 deletions(-) create mode 100644 spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good-stream.json diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index e41f6dd9d7f..ec6a717ef80 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -92,9 +92,10 @@ local transformers_to = { return nil, nil, err end - messages.temperature = (model.options and model.options.temperature) or nil - messages.max_tokens = (model.options and model.options.max_tokens) or nil - messages.model = model.name + messages.temperature = request_table.temperature or (model.options and model.options.temperature) or nil + messages.max_tokens = request_table.max_tokens or (model.options and model.options.max_tokens) or nil + messages.model = model.name or request_table.model + messages.stream = request_table.stream or false -- explicitly set this if nil return messages, "application/json", nil end, @@ -108,14 +109,140 @@ local transformers_to = { return nil, nil, err end - prompt.temperature = (model.options and model.options.temperature) or nil - prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or nil + prompt.temperature = request_table.temperature or (model.options and model.options.temperature) or nil + prompt.max_tokens_to_sample = request_table.max_tokens or (model.options and model.options.max_tokens) or nil prompt.model = model.name + prompt.model = model.name or request_table.model + prompt.stream = request_table.stream or false -- explicitly set this if nil return prompt, "application/json", nil end, } +local function delta_to_event(delta, model_info) + local data = { + choices = { + [1] = { + delta = { + content = (delta.delta + and delta.delta.text) + or (delta.content_block + and "") + or "", + }, + index = 0, + finish_reason = cjson.null, + logprobs = cjson.null, + }, + }, + id = kong + and kong.ctx + and kong.ctx.plugin + and kong.ctx.plugin.ai_proxy_anthropic_stream_id, + model = model_info.name, + object = "chat.completion.chunk", + } + + return cjson.encode(data), nil, nil +end + +local function start_to_event(event_data, model_info) + local meta = event_data.message or {} + + local metadata = { + prompt_tokens = meta.usage + and meta.usage.input_tokens + or nil, + completion_tokens = meta.usage + and meta.usage.output_tokens + or nil, + model = meta.model, + stop_reason = meta.stop_reason, + stop_sequence = meta.stop_sequence, + } + + local message = { + choices = { + [1] = { + delta = { + content = "", + role = meta.role, + }, + index = 0, + logprobs = cjson.null, + }, + }, + id = meta.id, + model = model_info.name, + object = 
"chat.completion.chunk", + system_fingerprint = cjson.null, + } + + message = cjson.encode(message) + kong.ctx.plugin.ai_proxy_anthropic_stream_id = meta.id + + return message, nil, metadata +end + +local function handle_stream_event(event_t, model_info, route_type) + ngx.log(ngx.WARN, event_t.event or "NO EVENT") + ngx.log(ngx.WARN, event_t.data or "NO DATA") + local event_id = event_t.event or "ping" + local event_data = cjson.decode(event_t.data) + + if event_id and event_data then + if event_id == "message_start" then + -- message_start and contains the token usage and model metadata + + if event_data and event_data.message then + return start_to_event(event_data, model_info) + else + return nil, "message_start is missing the metadata block", nil + end + + elseif event_id == "message_delta" then + -- message_delta contains and interim token count of the + -- last few frames / iterations + if event_data + and event_data.usage then + local meta = event_data.usage + + return nil, nil, { + prompt_tokens = nil, + completion_tokens = event_data.meta.usage + and event_data.meta.usage.output_tokens + or nil, + stop_reason = event_data.delta + and event_data.delta.stop_reason + or nil, + stop_sequence = event_data.delta + and event_data.delta.stop_sequence + or nil, + } + else + return nil, "message_delta is missing the metadata block", nil + end + + elseif event_id == "content_block_start" then + -- content_block_start is just an empty string and indicates + -- that we're getting an actual answer + return delta_to_event(event_data, model_info) + + elseif event_id == "content_block_delta" then + return delta_to_event(event_data, model_info) + + elseif event_id == "message_stop" then + return "[DONE]", nil, nil + + elseif event_id == "ping" then + return nil, nil, nil + + end + end + + return nil, "transformation to stream event failed or empty stream event received", nil +end + local transformers_from = { ["llm/v1/chat"] = function(response_string) local response_table, err = cjson.decode(response_string) @@ -199,6 +326,8 @@ local transformers_from = { return nil, "'completion' not in anthropic://llm/v1/chat response" end end, + + ["stream/llm/v1/chat"] = handle_stream_event, } function _M.from_format(response_string, model_info, route_type) @@ -210,7 +339,7 @@ function _M.from_format(response_string, model_info, route_type) return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type) end - local ok, response_string, err = pcall(transform, response_string) + local ok, response_string, err = pcall(transform, response_string, model_info, route_type) if not ok or err then return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 2788c749b46..5b519b17807 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -12,19 +12,19 @@ local table_new = require("table.new") local DRIVER_NAME = "cohere" -- -local function handle_stream_event(event_string, model_info, route_type) +local function handle_stream_event(event_t, model_info, route_type) local metadata - + -- discard empty frames, it should either be a random new line, or comment - if #event_string < 1 then + if (not event_t.data) or (#event_t.data < 1) then return end - local event, err = cjson.decode(event_string) + local event, err = cjson.decode(event_t.data) if err then return nil, "failed to decode event frame from cohere: " .. 
err, nil end - + local new_event if event.event_type == "stream-start" then @@ -89,11 +89,10 @@ local function handle_stream_event(event_string, model_info, route_type) end elseif event.event_type == "stream-end" then - -- return a metadata object, with a null event - metadata = { - -- prompt_tokens = event.response.token_count.prompt_tokens, - -- completion_tokens = event.response.token_count.response_tokens, + -- return a metadata object, with the OpenAI termination event + new_event = "[DONE]" + metadata = { completion_tokens = event.response and event.response.meta and event.response.meta.billed_units @@ -114,113 +113,82 @@ local function handle_stream_event(event_string, model_info, route_type) and event.token_count.prompt_tokens or 0, } - end if new_event then - new_event = cjson.encode(new_event) + if new_event ~= "[DONE]" then + new_event = cjson.encode(new_event) + end + return new_event, nil, metadata else return nil, nil, metadata -- caller code will handle "unrecognised" event types end end -local transformers_to = { - ["llm/v1/chat"] = function(request_table, model) - request_table.model = model.name - if request_table.prompt and request_table.messages then - return kong.response.exit(400, "cannot run a 'prompt' and a history of 'messages' at the same time - refer to schema") - - elseif request_table.messages then - -- we have to move all BUT THE LAST message into "chat_history" array - -- and move the LAST message (from 'user') into "message" string - if #request_table.messages > 1 then - local chat_history = table_new(#request_table.messages - 1, 0) - for i, v in ipairs(request_table.messages) do - -- if this is the last message prompt, don't add to history - if i < #request_table.messages then - local role - if v.role == "assistant" or v.role == "CHATBOT" then - role = "CHATBOT" - else - role = "USER" - end - - chat_history[i] = { - role = role, - message = v.content, - } +local function merge_fields(request_table, model) + model.options = model.options or {} + request_table.temperature = request_table.temperature or model.options.temperature + request_table.max_tokens = request_table.max_tokens or model.options.max_tokens + request_table.truncate = request_table.truncate or "END" + request_table.return_likelihoods = request_table.return_likelihoods or "NONE" + request_table.p = request_table.top_p or model.options.top_p + request_table.k = request_table.top_k or model.options.top_k + + return request_table +end + +local function handle_all(request_table, model) + request_table.model = model.name or request_table.model + request_table.stream = request_table.stream or false -- explicitly set this + + if request_table.prompt and request_table.messages then + return kong.response.exit(400, "cannot run a 'prompt' and a history of 'messages' at the same time - refer to schema") + + elseif request_table.messages then + -- we have to move all BUT THE LAST message into "chat_history" array + -- and move the LAST message (from 'user') into "message" string + if #request_table.messages > 1 then + local chat_history = table_new(#request_table.messages - 1, 0) + for i, v in ipairs(request_table.messages) do + -- if this is the last message prompt, don't add to history + if i < #request_table.messages then + local role + if v.role == "assistant" or v.role == "CHATBOT" then + role = "CHATBOT" + else + role = "USER" end + + chat_history[i] = { + role = role, + message = v.content, + } end - - request_table.chat_history = chat_history end - request_table.temperature = 
model.options.temperature - request_table.message = request_table.messages[#request_table.messages].content - request_table.messages = nil - - elseif request_table.prompt then - request_table.temperature = model.options.temperature - request_table.max_tokens = model.options.max_tokens - request_table.truncate = request_table.truncate or "END" - request_table.return_likelihoods = request_table.return_likelihoods or "NONE" - request_table.p = model.options.top_p - request_table.k = model.options.top_k - + request_table.chat_history = chat_history end - return request_table, "application/json", nil - end, + request_table.message = request_table.messages[#request_table.messages].content + request_table.messages = nil + request_table = merge_fields(request_table, model) - ["llm/v1/completions"] = function(request_table, model) - request_table.model = model.name - - if request_table.prompt and request_table.messages then - return kong.response.exit(400, "cannot run a 'prompt' and a history of 'messages' at the same time - refer to schema") - - elseif request_table.messages then - -- we have to move all BUT THE LAST message into "chat_history" array - -- and move the LAST message (from 'user') into "message" string - if #request_table.messages > 1 then - local chat_history = table_new(#request_table.messages - 1, 0) - for i, v in ipairs(request_table.messages) do - -- if this is the last message prompt, don't add to history - if i < #request_table.messages then - local role - if v.role == "assistant" or v.role == "CHATBOT" then - role = "CHATBOT" - else - role = "USER" - end - - chat_history[i] = { - role = role, - message = v.content, - } - end - end + elseif request_table.prompt then + request_table.prompt = request_table.prompt + request_table.messages = nil + request_table.message = nil + request_table = merge_fields(request_table, model) - request_table.chat_history = chat_history - end - - request_table.temperature = model.options.temperature - request_table.message = request_table.messages[#request_table.messages].content - request_table.messages = nil - - elseif request_table.prompt then - request_table.temperature = model.options.temperature - request_table.max_tokens = model.options.max_tokens - request_table.truncate = request_table.truncate or "END" - request_table.return_likelihoods = request_table.return_likelihoods or "NONE" - request_table.p = model.options.top_p - request_table.k = model.options.top_k + end - end + return request_table, "application/json", nil +end - return request_table, "application/json", nil - end, +local transformers_to = { + ["llm/v1/chat"] = handle_all, + ["llm/v1/completions"] = handle_all, } local transformers_from = { diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 27472be5c9a..79a3b79da60 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -11,16 +11,8 @@ local socket_url = require "socket.url" local DRIVER_NAME = "openai" -- -local function handle_stream_event(event_string) - if #event_string > 0 then - local lbl, val = event_string:match("(%w*): (.*)") - - if lbl == "data" then - return val - end - end - - return nil +local function handle_stream_event(event_t) + return event_t.data end local transformers_to = { diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 041062a724d..d7278ca3d4e 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -8,6 +8,16 @@ local os = os local parse_url = require("socket.url").parse -- +-- static +local str_find = 
string.find +local str_sub = string.sub +local tbl_insert = table.insert + +local function str_ltrim(s) -- remove leading whitespace from string. + return (s:gsub("^%s*", "")) +end +-- + local log_entry_keys = { TOKENS_CONTAINER = "usage", META_CONTAINER = "meta", @@ -106,7 +116,7 @@ _M.clear_response_headers = { local function handle_stream_event(event_table, model_info, route_type) if event_table.done then -- return analytics table - return nil, nil, { + return "[DONE]", nil, { prompt_tokens = event_table.prompt_eval_count or 0, completion_tokens = event_table.eval_count or 0, } @@ -143,6 +153,57 @@ local function handle_stream_event(event_table, model_info, route_type) end end +local function complex_split(str, delimiter) + local result = {} + local from = 1 + local delim_from, delim_to = string.find(str, delimiter, from) + while delim_from do + table.insert( result, string.sub(str, from , delim_from-1)) + from = delim_to + 1 + delim_from, delim_to = string.find(str, delimiter, from) + end + table.insert( result, string.sub(str, from)) + return result +end + +function _M.frame_to_events(frame) + local events = {} + + -- todo check if it's raw json and + -- just return the split up data frame + if string.sub(str_ltrim(frame), 1, 1) == "{" then + for event in frame:gmatch("[^\r\n]+") do + events[#events + 1] = { + data = event, + } + end + else + local event_lines = complex_split(frame, "\n") + local struct = { event = nil, id = nil, data = nil } + + for _, dat in ipairs(event_lines) do + if #dat < 1 then + events[#events + 1] = struct + struct = { event = nil, id = nil, data = nil } + end + + local s1, _ = str_find(dat, ":") -- find where the cut point is + + if s1 and s1 ~= 1 then + local field = str_sub(dat, 1, s1-1) -- returns "data " from data: hello world + local value = str_ltrim(str_sub(dat, s1+1)) -- returns "hello world" from data: hello world + + -- for now not checking if the value is already been set + if field == "event" then struct.event = value + elseif field == "id" then struct.id = value + elseif field == "data" then struct.data = value + end -- if + end -- if + end + end + + return events +end function _M.to_ollama(request_table, model) local input = {} @@ -234,8 +295,13 @@ function _M.from_ollama(response_string, model_info, route_type) end end + + if output and output ~= "[DONE]" then + output, err = cjson.encode(output) + end - return output and cjson.encode(output) or nil, nil, analytics + -- err maybe be nil from successful decode above + return output, err, analytics end function _M.pre_request(conf, request_table) @@ -396,4 +462,88 @@ function _M.http_request(url, body, method, headers, http_opts, buffered) end end +local function get_token_text(event_t) + -- chat + return + event_t and + event_t.choices and + #event_t.choices > 0 and + event_t.choices[1].delta and + event_t.choices[1].delta.content + + or + + -- completions + event_t and + event_t.choices and + #event_t.choices > 0 and + event_t.choices[1].text + + or "" +end + +-- Function to count the number of words in a string +local function count_words(str) + local count = 0 + for word in str:gmatch("%S+") do + count = count + 1 + end + return count +end + +-- Function to count the number of words or tokens based on the content type +local function count_prompt(content, tokens_factor) + local count = 0 + + if type(content) == "string" then + count = count_words(content) * tokens_factor + elseif type(content) == "table" then + for _, item in ipairs(content) do + if type(item) == "string" then + count = 
count + (count_words(item) * tokens_factor) + elseif type(item) == "number" then + count = count + 1 + elseif type(item) == "table" then + for _2, item2 in ipairs(item) do + if type(item2) == "number" then + count = count + 1 + else + return nil, "Invalid request format" + end + end + else + return nil, "Invalid request format" + end + end + else + return nil, "Invalid request format" + end + return count +end + +function _M.calculate_cost(query_body, tokens_models, tokens_factor) + local query_cost = 0 + local err + + if query_body.messages then + -- Calculate the cost based on the content type + for _, message in ipairs(query_body.messages) do + query_cost = query_cost + (count_words(message.content) * tokens_factor) + end + elseif query_body.prompt then + -- Calculate the cost based on the content type + query_cost, err = count_prompt(query_body.prompt, tokens_factor) + if err then + return nil, err + end + else + return nil, "No messages or prompt in query" + end + + -- Round the total cost quantified + query_cost = math.floor(query_cost + 0.5) + + return query_cost +end + return _M diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 5bc54531de5..e723d5faf11 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -291,103 +291,6 @@ local function identify_request(request) end end -local function get_token_text(event_t) - -- chat - return - event_t and - event_t.choices and - #event_t.choices > 0 and - event_t.choices[1].delta and - event_t.choices[1].delta.content - - or - - -- completions - event_t and - event_t.choices and - #event_t.choices > 0 and - event_t.choices[1].text - - or "" -end - --- Function to count the number of words in a string -local function count_words(str) - local count = 0 - for word in str:gmatch("%S+") do - count = count + 1 - end - return count -end - --- Function to count the number of words or tokens based on the content type -local function count_prompt(content, tokens_factor) - local count = 0 - - if type(content) == "string" then - count = count_words(content) * tokens_factor - elseif type(content) == "table" then - for _, item in ipairs(content) do - if type(item) == "string" then - count = count + (count_words(item) * tokens_factor) - elseif type(item) == "number" then - count = count + 1 - elseif type(item) == "table" then - for _2, item2 in ipairs(item) do - if type(item2) == "number" then - count = count + 1 - else - return nil, "Invalid request format" - end - end - else - return nil, "Invalid request format" - end - end - else - return nil, "Invalid request format" - end - return count -end - -function _M:calculate_cost(query_body, tokens_models, tokens_factor) - local query_cost = 0 - local err - - -- Check if max_tokens is provided in the request body - local max_tokens = query_body.max_tokens - - if not max_tokens then - if query_body.model and tokens_models then - max_tokens = tonumber(tokens_models[query_body.model]) - end - end - - if not max_tokens then - return nil, "No max_tokens in query and no key found in the plugin config for model: " .. 
query_body.model - end - - if query_body.messages then - -- Calculate the cost based on the content type - for _, message in ipairs(query_body.messages) do - query_cost = query_cost + (count_words(message.content) * tokens_factor) - end - elseif query_body.prompt then - -- Calculate the cost based on the content type - query_cost, err = count_prompt(query_body.prompt, tokens_factor) - if err then - return nil, err - end - else - return nil, "No messages or prompt in query" - end - - -- Round the total cost quantified - query_cost = math.floor(query_cost + 0.5) - - return query_cost -end - function _M.is_compatible(request, route_type) local format, err = identify_request(request) if err then @@ -401,160 +304,6 @@ function _M.is_compatible(request, route_type) return false, fmt("[%s] message format is not compatible with [%s] route type", format, route_type) end -function _M:handle_streaming_request(body) - -- convert it to the specified driver format - local request, _, err = self.driver.to_format(body, self.conf.model, self.conf.route_type) - if err then - return internal_server_error(err) - end - - -- run the shared logging/analytics/auth function - ai_shared.pre_request(self.conf, request) - - local prompt_tokens = 0 - local err - if not ai_shared.streaming_has_token_counts[self.conf.model.provider] then - -- Estimate the cost using KONG CX's counter implementation - prompt_tokens, err = self:calculate_cost(request, {}, 1.8) - if err then - return internal_server_error("unable to estimate request token cost: " .. err) - end - end - - -- send it to the ai service - local res, _, err, httpc = self.driver.subrequest(request, self.conf, self.http_opts, true) - if err then - return internal_server_error("failed to connect to " .. self.conf.model.provider .. " for streaming: " .. err) - end - if res.status ~= 200 then - err = "bad status code whilst opening streaming to " .. self.conf.model.provider .. ": " .. res.status - ngx.log(ngx.WARN, err) - return bad_request(err) - end - - -- get a big enough buffered ready to make sure we rip the entire chunk(s) each time - local reader = res.body_reader - local buffer_size = 35536 - local events - - -- we create a fake "kong response" table to pass to the telemetry handler later - local telemetry_fake_table = { - response = buf:new(), - usage = { - prompt_tokens = prompt_tokens, - completion_tokens = 0, - total_tokens = 0, - }, - } - - ngx.status = 200 - ngx.header["Content-Type"] = "text/event-stream" - ngx.header["Via"] = meta._SERVER_TOKENS - - for k, v in pairs(res.headers) do - if not streaming_skip_headers[lower(k)] then - ngx.header[k] = v - end - end - - -- server-sent events should ALWAYS be chunk encoded. - -- if they aren't then... we just won't support them. - repeat - -- receive next chunk - local buffer, err = reader(buffer_size) - if err then - ngx.log(ngx.ERR, "failed to read chunk of streaming buffer, ", err) - break - elseif not buffer then - break - end - - -- we need to rip each message from this chunk - events = {} - for s in buffer:gmatch("[^\r\n]+") do - table.insert(events, s) - end - - local metadata - local route_type = "stream/" .. 
self.conf.route_type - - -- then parse each into the standard inference format - for i, event in ipairs(events) do - local event_t - local token_t - - -- some LLMs do a final reply with token counts, and such - -- so we will stash them if supported - local formatted, err, this_metadata = self.driver.from_format(event, self.conf.model, route_type) - if err then - return internal_server_error(err) - end - - metadata = this_metadata or metadata - - -- handle event telemetry - if self.conf.logging.log_statistics then - - if not ai_shared.streaming_has_token_counts[self.conf.model.provider] then - event_t = cjson.decode(formatted) - token_t = get_token_text(event_t) - - -- incredibly loose estimate based on https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them - -- but this is all we can do until OpenAI fixes this... - -- - -- essentially, every 4 characters is a token, with minimum of 1 per event - telemetry_fake_table.usage.completion_tokens = - telemetry_fake_table.usage.completion_tokens + math.ceil(#strip(token_t) / 4) - - elseif metadata then - telemetry_fake_table.usage.completion_tokens = metadata.completion_tokens - telemetry_fake_table.usage.prompt_tokens = metadata.prompt_tokens - end - - end - - -- then stream to the client - if formatted then -- only stream relevant frames back to the user - if self.conf.logging.log_payloads then - -- append the "choice" to the buffer, for logging later. this actually works! - if not event_t then - event_t, err = cjson.decode(formatted) - end - - if err then - return internal_server_error("something wrong with decoding a specific token") - end - - if not token_t then - token_t = get_token_text(event_t) - end - - telemetry_fake_table.response:put(token_t) - end - - -- construct, transmit, and flush the frame - ngx.print("data: ", formatted, "\n\n") - ngx.flush(true) - end - end - - until not buffer - - local ok, err = httpc:set_keepalive() - if not ok then - -- continue even if keepalive gets killed - ngx.log(ngx.WARN, "setting keepalive failed: ", err) - end - - -- process telemetry - telemetry_fake_table.response = telemetry_fake_table.response:tostring() - - telemetry_fake_table.usage.total_tokens = telemetry_fake_table.usage.completion_tokens + - telemetry_fake_table.usage.prompt_tokens - - ai_shared.post_request(self.conf, telemetry_fake_table) -end - function _M:ai_introspect_body(request, system_prompt, http_opts, response_regex_match) local err, _ diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index b5886683fcc..aaa390bf586 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -7,6 +7,8 @@ local llm = require("kong.llm") local cjson = require("cjson.safe") local kong_utils = require("kong.tools.gzip") local kong_meta = require("kong.meta") +local buffer = require "string.buffer" +local strip = require("kong.tools.utils").strip -- @@ -33,6 +35,134 @@ local function internal_server_error(msg) return kong.response.exit(500, ERROR_MSG) end +local function get_token_text(event_t) + -- chat + return + event_t and + event_t.choices and + #event_t.choices > 0 and + event_t.choices[1].delta and + event_t.choices[1].delta.content + + or + + -- completions + event_t and + event_t.choices and + #event_t.choices > 0 and + event_t.choices[1].text + + or "" +end + +local function handle_streaming_frame(conf) + -- make a re-usable framebuffer + local framebuffer = buffer.new() + local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" + + local 
ai_driver = require("kong.llm.drivers." .. conf.model.provider)
+
+  -- create a buffer to store each response token/frame, on first pass
+  if conf.logging
+     and conf.logging.log_payloads
+     and (not kong.ctx.plugin.ai_stream_log_buffer) then
+    kong.ctx.plugin.ai_stream_log_buffer = buffer.new()
+  end
+
+  -- now handle each chunk/frame
+  local chunk = ngx.arg[1]
+  local finished = ngx.arg[2]
+
+  if type(chunk) == "string" and chunk ~= "" then
+    -- transform each one into flat format, skipping transformer errors
+    -- because we have already 200 OK'd the client by now
+
+    if (not finished) and (is_gzip) then
+      chunk = kong_utils.inflate_gzip(chunk)
+    end
+
+    local events = ai_shared.frame_to_events(chunk)
+
+    for _, event in ipairs(events) do
+      local formatted, _, metadata = ai_driver.from_format(event, conf.model, "stream/" .. conf.route_type)
+
+      local event_t, token_t, err
+
+      if formatted then  -- only stream relevant frames back to the user
+        if conf.logging and conf.logging.log_payloads and (formatted ~= "[DONE]") then
+          -- append the "choice" to the buffer, for logging later. this actually works!
+          if not event_t then
+            event_t, err = cjson.decode(formatted)
+          end
+
+          if not err then
+            if not token_t then
+              token_t = get_token_text(event_t)
+            end
+
+            kong.ctx.plugin.ai_stream_log_buffer:put(token_t)
+          end
+        end
+
+        -- handle event telemetry
+        if conf.logging and conf.logging.log_statistics then
+          if not ai_shared.streaming_has_token_counts[conf.model.provider] then
+            if formatted ~= "[DONE]" then
+              if not event_t then
+                event_t, err = cjson.decode(formatted)
+              end
+
+              if not err then
+                if not token_t then
+                  token_t = get_token_text(event_t)
+                end
+
+                -- incredibly loose estimate based on https://help.openai.com/en/articles/4936856-what-are-tokens-and-how-to-count-them
+                -- but this is all we can do until OpenAI fixes this...
+ -- + -- essentially, every 4 characters is a token, with minimum of 1*4 per event + kong.ctx.plugin.ai_stream_completion_tokens = + (kong.ctx.plugin.ai_stream_completion_tokens or 0) + math.ceil(#strip(token_t) / 4) + end + end + + elseif metadata then + kong.ctx.plugin.ai_stream_completion_tokens = metadata.completion_tokens or kong.ctx.plugin.ai_stream_completion_tokens + kong.ctx.plugin.ai_stream_prompt_tokens = metadata.prompt_tokens or kong.ctx.plugin.ai_stream_prompt_tokens + end + end + + framebuffer:put("data: ") + framebuffer:put(formatted or "") + framebuffer:put((formatted ~= "[DONE]") and "\n\n" or "") + end + end + end + + local response_frame = framebuffer:get() + framebuffer = nil + if (not finished) and (is_gzip) then + response_frame = kong_utils.deflate_gzip(response_frame) + end + + ngx.arg[1] = response_frame + + if finished then + local fake_response_t = { + response = kong.ctx.plugin.ai_stream_log_buffer:get(), + usage = { + prompt_tokens = kong.ctx.plugin.ai_stream_prompt_tokens or 0, + completion_tokens = kong.ctx.plugin.ai_stream_completion_tokens or 0, + total_tokens = (kong.ctx.plugin.ai_stream_prompt_tokens or 0) + + (kong.ctx.plugin.ai_stream_completion_tokens or 0), + } + } + + ngx.arg[1] = nil + ai_shared.post_request(conf, fake_response_t) + kong.ctx.plugin.ai_stream_log_buffer = nil + end +end function _M:header_filter(conf) if kong.ctx.shared.skip_response_transformer then @@ -49,6 +179,13 @@ function _M:header_filter(conf) return end + -- we use openai's streaming mode (SSE) + if kong.ctx.shared.ai_proxy_streaming_mode then + -- we are going to send plaintext event-stream frames for ALL models + kong.response.set_header("Content-Type", "text/event-stream") + return + end + local response_body = kong.service.response.get_raw_body() if not response_body then return @@ -57,24 +194,29 @@ function _M:header_filter(conf) local ai_driver = require("kong.llm.drivers." .. 
conf.model.provider) local route_type = conf.route_type - local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" - if is_gzip then - response_body = kong_utils.inflate_gzip(response_body) - end + -- if this is a 'streaming' request, we can't know the final + -- result of the response body, so we just proceed to body_filter + -- to translate each SSE event frame + if not kong.ctx.shared.ai_proxy_streaming_mode then + local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" + if is_gzip then + response_body = kong_utils.inflate_gzip(response_body) + end - local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type) - if err then - kong.ctx.plugin.ai_parser_error = true + local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type) + if err then + kong.ctx.plugin.ai_parser_error = true - ngx.status = 500 - ERROR_MSG.error.message = err + ngx.status = 500 + ERROR_MSG.error.message = err - kong.ctx.plugin.parsed_response = cjson.encode(ERROR_MSG) + kong.ctx.plugin.parsed_response = cjson.encode(ERROR_MSG) - elseif new_response_string then - -- preserve the same response content type; assume the from_format function - -- has returned the body in the appropriate response output format - kong.ctx.plugin.parsed_response = new_response_string + elseif new_response_string then + -- preserve the same response content type; assume the from_format function + -- has returned the body in the appropriate response output format + kong.ctx.plugin.parsed_response = new_response_string + end end ai_driver.post_request(conf) @@ -83,7 +225,7 @@ end function _M:body_filter(conf) -- if body_filter is called twice, then return - if kong.ctx.plugin.body_called then + if kong.ctx.plugin.body_called and (not kong.ctx.shared.ai_proxy_streaming_mode) then return end @@ -119,31 +261,33 @@ function _M:body_filter(conf) if (kong.response.get_status() ~= 200) and (not kong.ctx.plugin.ai_parser_error) then return end - - -- (kong.response.get_status() == 200) or (kong.ctx.plugin.ai_parser_error) - - -- all errors MUST be checked and returned in header_filter - -- we should receive a replacement response body from the same thread - - local original_request = kong.ctx.plugin.parsed_response - local deflated_request = original_request - - if deflated_request then - local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" - if is_gzip then - deflated_request = kong_utils.deflate_gzip(deflated_request) + + if kong.ctx.shared.ai_proxy_streaming_mode then + handle_streaming_frame(conf) + else + -- all errors MUST be checked and returned in header_filter + -- we should receive a replacement response body from the same thread + + local original_request = kong.ctx.plugin.parsed_response + local deflated_request = original_request + + if deflated_request then + local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" + if is_gzip then + deflated_request = kong_utils.deflate_gzip(deflated_request) + end + + kong.response.set_raw_body(deflated_request) end - kong.response.set_raw_body(deflated_request) - end + -- call with replacement body, or original body if nothing changed + local _, err = ai_shared.post_request(conf, original_request) + if err then + kong.log.warn("analytics phase failed for request, ", err) + end - -- call with replacement body, or original body if nothing changed - local _, err = ai_shared.post_request(conf, original_request) - if err then - kong.log.warn("analytics phase failed for request, 
", err) - end + end end - kong.ctx.plugin.body_called = true end @@ -166,6 +310,8 @@ function _M:access(conf) request_table = kong.request.get_body(content_type) + -- TODO octet stream check here + if not request_table then return bad_request("content-type header does not match request body") end @@ -178,50 +324,63 @@ function _M:access(conf) return bad_request(err) end - if request_table.stream or conf.model.options.response_streaming == "always" then - kong.ctx.shared.skip_response_transformer = true - - -- into sub-request streaming handler - -- everything happens in the access phase here - if conf.model.options.response_streaming == "deny" then + -- check if the user has asked for a stream, and/or if + -- we are forcing all requests to be of streaming type + if request_table.stream or + (conf.model.options and conf.model.options.response_streaming) == "always" then + -- this condition will only check if user has tried + -- to activate streaming mode within their request + if conf.model.options and conf.model.options.response_streaming == "deny" then return bad_request("response streaming is not enabled for this LLM") end - local llm_handler = ai_module:new(conf, {}) - llm_handler:handle_streaming_request(request_table) + -- store token cost estimate, on first pass + if not kong.ctx.plugin.ai_stream_prompt_tokens then + local prompt_tokens, err = ai_shared.calculate_cost(request_table or {}, {}, 1.8) + if err then + return internal_server_error("unable to estimate request token cost: " .. err) + end + + kong.ctx.plugin.ai_stream_prompt_tokens = prompt_tokens + end + + -- specific actions need to skip later for this to work + kong.ctx.shared.ai_proxy_streaming_mode = true + else kong.service.request.enable_buffering() + end - local ai_driver = require("kong.llm.drivers." .. conf.model.provider) - - -- execute pre-request hooks for this driver - local ok, err = ai_driver.pre_request(conf, request_table) - if not ok then - return bad_request(err) - end + local ai_driver = require("kong.llm.drivers." .. 
conf.model.provider) - -- transform the body to Kong-format for this provider/model - local parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf.model, route_type) - if err then - return bad_request(err) - end + -- execute pre-request hooks for this driver + local ok, err = ai_driver.pre_request(conf, request_table) + if not ok then + return bad_request(err) + end - -- execute pre-request hooks for "all" drivers before set new body - local ok, err = ai_shared.pre_request(conf, parsed_request_body) - if not ok then - return bad_request(err) - end + -- transform the body to Kong-format for this provider/model + local parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf.model, route_type) + if err then + return bad_request(err) + end - kong.service.request.set_body(parsed_request_body, content_type) + -- execute pre-request hooks for "all" drivers before set new body + local ok, err = ai_shared.pre_request(conf, parsed_request_body) + if not ok then + return bad_request(err) + end - -- now re-configure the request for this operation type - local ok, err = ai_driver.configure_request(conf) - if not ok then - return internal_server_error(err) - end + kong.service.request.set_body(parsed_request_body, content_type) - -- lights out, and away we go + -- now re-configure the request for this operation type + local ok, err = ai_driver.configure_request(conf) + if not ok then + return internal_server_error(err) end + + -- lights out, and away we go + end diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index 7773ee6c71d..a634230a81b 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -4,6 +4,7 @@ local pl_replace = require("pl.stringx").replace local cjson = require("cjson.safe") local fmt = string.format local llm = require("kong.llm") +local ai_shared = require("kong.llm.drivers.shared") local SAMPLE_LLM_V1_CHAT = { messages = { @@ -366,12 +367,20 @@ describe(PLUGIN_NAME .. ": (unit)", function() -- what we do is first put the SAME request message from the user, through the converter, for this provider/format it("converts to provider request format correctly", function() + -- load the real provider frame from file local real_stream_frame = pl_file.read(fmt("spec/fixtures/ai-proxy/unit/real-stream-frames/%s/%s.txt", config.provider, pl_replace(format_name, "/", "-"))) - local real_transformed_frame, err = driver.from_format(real_stream_frame, config, "stream/" .. format_name) + + -- use the shared function to produce an SSE format object + local real_transformed_frame, err = ai_shared.frame_to_events(real_stream_frame) + assert.is_nil(err) + -- transform the SSE frame into OpenAI format + real_transformed_frame, err = driver.from_format(real_transformed_frame[1], config, "stream/" .. 
format_name) + assert.is_nil(err) + real_transformed_frame, err = cjson.decode(real_transformed_frame) assert.is_nil(err) - real_transformed_frame = cjson.decode(real_transformed_frame) + -- check it's what we expeced assert.same(expected_stream_choices[format_name], real_transformed_frame.choices) end) diff --git a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua index 089dbbb671b..cef707bbecf 100644 --- a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua @@ -53,18 +53,18 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then [5] = 'data: { "choices": [ { "delta": {}, "finish_reason": "stop", "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', [6] = 'data: [DONE]', } - + local fmt = string.format local pl_file = require "pl.file" local json = require("cjson.safe") - + ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + local token = ngx.req.get_headers()["authorization"] local token_query = ngx.req.get_uri_args()["apikey"] - + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then ngx.req.read_body() local body, err = ngx.req.get_body_data() @@ -75,11 +75,10 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) else -- GOOD RESPONSE - + ngx.status = 200 ngx.header["Content-Type"] = "text/event-stream" - ngx.header["Transfer-Encoding"] = "chunked" - + for i, EVENT in ipairs(_EVENT_CHUNKS) do ngx.print(fmt("%s\n\n", EVENT)) end @@ -133,7 +132,6 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.status = 200 ngx.header["Content-Type"] = "text/event-stream" - ngx.header["Transfer-Encoding"] = "chunked" for i, EVENT in ipairs(_EVENT_CHUNKS) do ngx.print(fmt("%s\n\n", EVENT)) @@ -146,6 +144,73 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } } + location = "/anthropic/llm/v1/chat/good" { + content_by_lua_block { + local _EVENT_CHUNKS = { + [1] = 'event: message_start', + [2] = 'event: content_block_start', + [3] = 'event: ping', + [4] = 'event: content_block_delta', + [5] = 'event: content_block_delta', + [6] = 'event: content_block_delta', + [7] = 'event: content_block_delta', + [8] = 'event: content_block_delta', + [9] = 'event: content_block_stop', + [10] = 'event: message_delta', + [11] = 'event: message_stop', + } + + local _DATA_CHUNKS = { + [1] = 'data: {"type":"message_start","message":{"id":"msg_013NVLwA2ypoPDJAxqC3G7wg","type":"message","role":"assistant","model":"claude-2.1","stop_sequence":null,"usage":{"input_tokens":15,"output_tokens":1},"content":[],"stop_reason":null} }', + [2] = 'data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} }', + [3] = 'data: {"type": "ping"}', + [4] = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"1"} }', + [5] = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" +"} }', + [6] = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" 1"} }', + [7] = 'data: 
{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" ="} }', + [8] = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" 2"} }', + [9] = 'data: {"type":"content_block_stop","index":0 }', + [10] = '{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":9}}', + [11] = '{"type":"message_stop"}', + } + + local fmt = string.format + local pl_file = require "pl.file" + local json = require("cjson.safe") + + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + local token = ngx.req.get_headers()["api-key"] + local token_query = ngx.req.get_uri_args()["apikey"] + + if token == "anthropic-key" or token_query == "anthropic-key" or body.apikey == "anthropic-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + + if err or (body.messages == ngx.null) then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) + else + -- GOOD RESPONSE + + ngx.status = 200 + ngx.header["Content-Type"] = "text/event-stream" + + for i, EVENT in ipairs(_EVENT_CHUNKS) do + ngx.print(fmt("%s\n", EVENT)) + ngx.print(fmt("%s\n\n", _DATA_CHUNKS[i])) + end + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/unauthorized.json")) + end + } + } + location = "/openai/llm/v1/chat/bad" { content_by_lua_block { local fmt = string.format @@ -262,6 +327,43 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } -- + -- 200 chat anthropic + local anthropic_chat_good = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/anthropic/llm/v1/chat/good" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = anthropic_chat_good.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "anthropic-key", + }, + model = { + name = "claude-2.1", + provider = "anthropic", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/anthropic/llm/v1/chat/good", + anthropic_version = "2023-06-01", + }, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = anthropic_chat_good.id }, + config = { + path = "/dev/stdout", + }, + } + -- + -- 400 chat openai local openai_chat_bad = assert(bp.routes:insert { service = empty_service, @@ -333,8 +435,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then port = helpers.get_proxy_port(), }) if not ok then - ngx.log(ngx.ERR, "connection failed: ", err) - return + assert.is_nil(err) end -- Then send using `request`, supplying a path and `Host` header instead of a @@ -348,8 +449,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }) if not res then - ngx.log(ngx.ERR, "request failed: ", err) - return + assert.is_nil(err) end local reader = res.body_reader @@ -362,8 +462,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then -- receive next chunk local buffer, err = reader(buffer_size) if err then - ngx.log(ngx.ERR, err) - break + assert.is_falsy(err and err ~= "closed") end if buffer then @@ -399,8 +498,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then port = helpers.get_proxy_port(), }) if not ok then - ngx.log(ngx.ERR, "connection 
failed: ", err) - return + assert.is_nil(err) end -- Then send using `request`, supplying a path and `Host` header instead of a @@ -414,8 +512,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }) if not res then - ngx.log(ngx.ERR, "request failed: ", err) - return + assert.is_nil(err) end local reader = res.body_reader @@ -428,8 +525,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then -- receive next chunk local buffer, err = reader(buffer_size) if err then - ngx.log(ngx.ERR, err) - break + assert.is_falsy(err and err ~= "closed") end if buffer then @@ -451,10 +547,72 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then end until not buffer - assert.equal(#events, 16) + assert.equal(#events, 17) assert.equal(buf:tostring(), "1 + 1 = 2. This is the most basic example of addition.") end) + it("good stream request anthropic", function() + local httpc = http.new() + + local ok, err, _ = httpc:connect({ + scheme = "http", + host = helpers.mock_upstream_host, + port = helpers.get_proxy_port(), + }) + if not ok then + assert.is_nil(err) + end + + -- Then send using `request`, supplying a path and `Host` header instead of a + -- full URI. + local res, err = httpc:request({ + path = "/anthropic/llm/v1/chat/good", + body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good-stream.json"), + headers = { + ["content-type"] = "application/json", + ["accept"] = "application/json", + }, + }) + if not res then + assert.is_nil(err) + end + + local reader = res.body_reader + local buffer_size = 35536 + local events = {} + local buf = require("string.buffer").new() + + -- extract event + repeat + -- receive next chunk + local buffer, err = reader(buffer_size) + if err then + assert.is_falsy(err and err ~= "closed") + end + + if buffer then + -- we need to rip each message from this chunk + for s in buffer:gmatch("[^\r\n]+") do + local s_copy = s + s_copy = string.sub(s_copy,7) + s_copy = cjson.decode(s_copy) + + buf:put(s_copy + and s_copy.choices + and s_copy.choices + and s_copy.choices[1] + and s_copy.choices[1].delta + and s_copy.choices[1].delta.content + or "") + table.insert(events, s) + end + end + until not buffer + + assert.equal(#events, 7) + assert.equal(buf:tostring(), "1 + 1 = 2") + end) + it("bad request is returned to the client not-streamed", function() local httpc = http.new() @@ -464,8 +622,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then port = helpers.get_proxy_port(), }) if not ok then - ngx.log(ngx.ERR, "connection failed: ", err) - return + assert.is_nil(err) end -- Then send using `request`, supplying a path and `Host` header instead of a @@ -479,8 +636,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then }, }) if not res then - ngx.log(ngx.ERR, "request failed: ", err) - return + assert.is_nil(err) end local reader = res.body_reader @@ -492,8 +648,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then -- receive next chunk local buffer, err = reader(buffer_size) if err then - ngx.log(ngx.ERR, err) - break + assert.is_nil(err) end if buffer then diff --git a/spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good-stream.json b/spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good-stream.json new file mode 100644 index 00000000000..c05edd15b8a --- /dev/null +++ b/spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good-stream.json @@ -0,0 +1,13 @@ +{ + "messages": 
[ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "What is 1 + 1?" + } + ], + "stream": true +} \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/anthropic/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/anthropic/llm-v1-chat.json index 3ca2686ce8b..00756b5901d 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/anthropic/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/anthropic/llm-v1-chat.json @@ -24,5 +24,6 @@ ], "system": "You are a mathematician.", "max_tokens": 512, - "temperature": 0.5 + "temperature": 0.5, + "stream": false } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/anthropic/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/expected-requests/anthropic/llm-v1-completions.json index e43ab0c2e63..7af2711f65f 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/anthropic/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/anthropic/llm-v1-completions.json @@ -2,5 +2,6 @@ "model": "claude-2.1", "prompt": "Human: Explain why you can't divide by zero?\n\nAssistant:", "max_tokens_to_sample": 512, - "temperature": 0.5 + "temperature": 0.5, + "stream": false } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-chat.json index 24970854ade..b22c3e520a4 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-chat.json @@ -8,5 +8,9 @@ ], "message": "Why can't you divide by zero?", "model": "command", - "temperature": 0.5 + "max_tokens": 512, + "temperature": 0.5, + "truncate": "END", + "return_likelihoods": "NONE", + "stream": false } diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-completions.json index a1bbaa8591c..9ba71345bad 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-completions.json @@ -6,5 +6,6 @@ "p": 0.75, "k": 5, "return_likelihoods": "NONE", - "truncate": "END" + "truncate": "END", + "stream": false } diff --git a/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt index fac4fed43ff..e9e1b313fa1 100644 --- a/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt +++ b/spec/fixtures/ai-proxy/unit/real-stream-frames/openai/llm-v1-completions.txt @@ -1 +1 @@ -data: {"choices": [{"finish_reason": null,"index": 0,"logprobs": null,"text": "the answer"}],"created": 1711938803,"id": "cmpl-991m7YSJWEnzrBqk41In8Xer9RIEB","model": "gpt-3.5-turbo-instruct","object": "text_completion"} \ No newline at end of file +data: {"choices": [{"finish_reason": null,"index": 0,"logprobs": null,"text": "the answer"}],"created": 1711938803,"id": "cmpl-991m7YSJWEnzrBqk41In8Xer9RIEB","model": "gpt-3.5-turbo-instruct","object": "text_completion"} From 2ce45870972c9de8987b3555ba46633539fa84b7 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Mon, 22 Apr 2024 15:48:22 +0100 Subject: [PATCH 02/24] feat(ai-proxy): added variable stencil mechanism for routing --- kong/llm/drivers/anthropic.lua | 4 ++++ kong/llm/drivers/azure.lua | 32 
++++++++++++++++++++++++++------ kong/llm/drivers/cohere.lua | 4 ++++ kong/llm/drivers/llama2.lua | 4 ++++ kong/llm/drivers/mistral.lua | 4 ++++ kong/llm/drivers/openai.lua | 4 ++++ kong/tools/http.lua | 23 +++++++++++++++++++++++ 7 files changed, 69 insertions(+), 6 deletions(-) diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index ec6a717ef80..0d176d90712 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -6,6 +6,7 @@ local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local socket_url = require "socket.url" local buffer = require("string.buffer") +local ensure_valid_path = require("kong.tools.utils").ensure_valid_path -- -- globals @@ -473,6 +474,9 @@ function _M.configure_request(conf) kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) end + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = ensure_valid_path(parsed_url.path) + kong.service.request.set_header("anthropic-version", conf.model.options.anthropic_version) local auth_header_name = conf.auth and conf.auth.header_name diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index 7918cf166bc..dce37a00b79 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -6,6 +6,7 @@ local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local openai_driver = require("kong.llm.drivers.openai") local socket_url = require "socket.url" +local ensure_valid_path = require("kong.tools.utils").ensure_valid_path -- -- globals @@ -14,9 +15,21 @@ local DRIVER_NAME = "azure" _M.from_format = openai_driver.from_format _M.to_format = openai_driver.to_format -_M.pre_request = openai_driver.pre_request _M.header_filter_hooks = openai_driver.header_filter_hooks +function _M.pre_request(conf) + kong.service.request.set_header("Accept-Encoding", "gzip, identity") -- tell server not to send brotli + + -- for azure provider, all of these must/will be set by now + if conf.logging and conf.logging.log_statistics then + kong.log.set_serialize_value("ai.meta.azure_instance_id", conf.model.options.azure_instance) + kong.log.set_serialize_value("ai.meta.azure_deployment_id", conf.model.options.azure_deployment_id) + kong.log.set_serialize_value("ai.meta.azure_api_version", conf.model.options.azure_api_version) + end + + return true +end + function _M.post_request(conf) if ai_shared.clear_response_headers[DRIVER_NAME] then for i, v in ipairs(ai_shared.clear_response_headers[DRIVER_NAME]) do @@ -40,11 +53,12 @@ function _M.subrequest(body, conf, http_opts, return_res_table) end -- azure has non-standard URL format - local url = (conf.model.options and conf.model.options.upstream_url) - or fmt( + local url = fmt( "%s%s?api-version=%s", ai_shared.upstream_url_format[DRIVER_NAME]:format(conf.model.options.azure_instance, conf.model.options.azure_deployment_id), - ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, + conf.model.options + and conf.model.options.upstream_path + or ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, conf.model.options.azure_api_version or "2023-05-15" ) @@ -91,7 +105,9 @@ function _M.configure_request(conf) local url = fmt( "%s%s", ai_shared.upstream_url_format[DRIVER_NAME]:format(conf.model.options.azure_instance, conf.model.options.azure_deployment_id), - ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + conf.model.options + and conf.model.options.upstream_path + or 
ai_shared.operation_map[DRIVER_NAME][conf.route_type].path ) parsed_url = socket_url.parse(url) end @@ -100,6 +116,8 @@ function _M.configure_request(conf) kong.service.request.set_scheme(parsed_url.scheme) kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = ensure_valid_path(parsed_url.path) local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value @@ -114,7 +132,9 @@ function _M.configure_request(conf) local query_table = kong.request.get_query() -- technically min supported version - query_table["api-version"] = conf.model.options and conf.model.options.azure_api_version or "2023-05-15" + query_table["api-version"] = kong.request.get_query_arg("api-version") + or (conf.model.options and conf.model.options.azure_api_version) + or "2023-05-15" if auth_param_name and auth_param_value and auth_param_location == "query" then query_table[auth_param_name] = auth_param_value diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 5b519b17807..ce49ff2a424 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -6,6 +6,7 @@ local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local socket_url = require "socket.url" local table_new = require("table.new") +local ensure_valid_path = require("kong.tools.utils").ensure_valid_path -- -- globals @@ -470,6 +471,9 @@ function _M.configure_request(conf) kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) end + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = ensure_valid_path(parsed_url.path) + local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value local auth_param_name = conf.auth and conf.auth.param_name diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index bf3ee42ee74..4b835a82fc0 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -8,6 +8,7 @@ local ai_shared = require("kong.llm.drivers.shared") local openai_driver = require("kong.llm.drivers.openai") local socket_url = require "socket.url" local string_gsub = string.gsub +local ensure_valid_path = require("kong.tools.utils").ensure_valid_path -- -- globals @@ -269,6 +270,9 @@ function _M.configure_request(conf) kong.service.request.set_scheme(parsed_url.scheme) kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = ensure_valid_path(parsed_url.path) + local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value local auth_param_name = conf.auth and conf.auth.param_name diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index d091939eeb2..76c28bd8304 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -7,6 +7,7 @@ local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local openai_driver = require("kong.llm.drivers.openai") local socket_url = require "socket.url" +local ensure_valid_path = require("kong.tools.utils").ensure_valid_path -- -- globals @@ -155,6 +156,9 @@ function _M.configure_request(conf) kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) end + -- if the path is read from a URL capture, ensure that it is valid + 
parsed_url.path = ensure_valid_path(parsed_url.path) + local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value local auth_param_name = conf.auth and conf.auth.param_name diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 79a3b79da60..bcfbb9400e5 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -5,6 +5,7 @@ local cjson = require("cjson.safe") local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local socket_url = require "socket.url" +local ensure_valid_path = require("kong.tools.utils").ensure_valid_path -- -- globals @@ -228,6 +229,9 @@ function _M.configure_request(conf) kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) end + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = ensure_valid_path(parsed_url.path) + local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value local auth_param_name = conf.auth and conf.auth.param_name diff --git a/kong/tools/http.lua b/kong/tools/http.lua index 26a125fab76..de4b3fe8a30 100644 --- a/kong/tools/http.lua +++ b/kong/tools/http.lua @@ -553,5 +553,28 @@ do end end +do + local string_sub = string.sub + + --- + -- Ensures that a given path adheres to a valid format + -- for usage with PDK set_path, or a lua-resty-http client. + -- + -- The function returns the re-formatted path, in its valid form, + -- or returns the original string if nothing was changed. + -- + -- @param path string the path to ensure is valid + -- @return string the newly-formatted valid path, or the original path if nothing changed + function _M.ensure_valid_path(path) + if string_sub(path, 1, 1) ~= "/" then + path = "/" .. 
path + + elseif string_sub(path, 1, 2) == "//" then + path = string_sub(path, 2) + end + + return path + end +end return _M From 8ae0d51e28909e4fd0f41c5bf27fa85a8d72ef63 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Mon, 22 Apr 2024 18:31:14 +0100 Subject: [PATCH 03/24] feat(ai-proxy): folded in features from #12807 --- kong/llm/drivers/anthropic.lua | 26 +-- kong/llm/drivers/azure.lua | 7 +- kong/llm/drivers/cohere.lua | 30 ++-- kong/llm/drivers/llama2.lua | 20 +-- kong/llm/drivers/mistral.lua | 19 +- kong/llm/drivers/openai.lua | 103 ++++++----- kong/llm/drivers/shared.lua | 99 ++++++++++- kong/llm/init.lua | 29 ++-- kong/plugins/ai-proxy/handler.lua | 122 +++++++++---- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 164 +++++++++++++++++- .../02-openai_integration_spec.lua | 112 ++++++++++-- .../03-anthropic_integration_spec.lua | 17 -- .../04-cohere_integration_spec.lua | 17 -- .../38-ai-proxy/05-azure_integration_spec.lua | 17 -- .../06-mistral_integration_spec.lua | 19 -- .../azure/llm-v1-completions.json | 3 +- .../llama2/raw/llm-v1-chat.json | 3 +- .../llama2/raw/llm-v1-completions.json | 3 +- 18 files changed, 548 insertions(+), 262 deletions(-) diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 0d176d90712..8300db58056 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -457,26 +457,26 @@ end function _M.configure_request(conf) local parsed_url - if conf.route_type ~= "preserve" then - if conf.model.options.upstream_url then - parsed_url = socket_url.parse(conf.model.options.upstream_url) - else - parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) - parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + if conf.model.options.upstream_url then + parsed_url = socket_url.parse(conf.model.options.upstream_url) + else + parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) + parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path - if not parsed_url.path then - return nil, fmt("operation %s is not supported for anthropic provider", conf.route_type) - end + if not parsed_url.path then + return nil, fmt("operation %s is not supported for anthropic provider", conf.route_type) end - - kong.service.request.set_path(parsed_url.path) - kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) end -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = ensure_valid_path(parsed_url.path) + kong.service.request.set_path(parsed_url.path) + kong.service.request.set_scheme(parsed_url.scheme) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + + + kong.service.request.set_header("anthropic-version", conf.model.options.anthropic_version) local auth_header_name = conf.auth and conf.auth.header_name diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index dce37a00b79..a690a2593e2 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -111,13 +111,14 @@ function _M.configure_request(conf) ) parsed_url = socket_url.parse(url) end + + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = ensure_valid_path(parsed_url.path) kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) 
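
For readers skimming the patch, a standalone sketch of the normalisation that the new ensure_valid_path helper (kong/tools/http.lua, above) applies before each driver calls kong.service.request.set_path; the assertions are illustrative only and not part of the patch:

local function ensure_valid_path(path)
  if path:sub(1, 1) ~= "/" then
    path = "/" .. path            -- "v1/chat"   -> "/v1/chat"
  elseif path:sub(1, 2) == "//" then
    path = path:sub(2)            -- "//v1/chat" -> "/v1/chat"
  end
  return path
end

assert(ensure_valid_path("v1/chat")   == "/v1/chat")
assert(ensure_valid_path("//v1/chat") == "/v1/chat")
assert(ensure_valid_path("/v1/chat")  == "/v1/chat")

Note that, as written, only a single doubled leading slash is collapsed; deeper duplication is passed through unchanged.
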
- -- if the path is read from a URL capture, ensure that it is valid - parsed_url.path = ensure_valid_path(parsed_url.path) local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index ce49ff2a424..30ca9f61d5d 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -453,27 +453,25 @@ end -- returns err or nil function _M.configure_request(conf) local parsed_url - - if conf.route_type ~= "preserve" then - if conf.model.options.upstream_url then - parsed_url = socket_url.parse(conf.model.options.upstream_url) - else - parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) - parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path - - if not parsed_url.path then - return false, fmt("operation %s is not supported for cohere provider", conf.route_type) - end + + if conf.model.options.upstream_url then + parsed_url = socket_url.parse(conf.model.options.upstream_url) + else + parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) + parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + + if not parsed_url.path then + return false, fmt("operation %s is not supported for cohere provider", conf.route_type) end - - kong.service.request.set_path(parsed_url.path) - kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) end - + -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = ensure_valid_path(parsed_url.path) + kong.service.request.set_path(parsed_url.path) + kong.service.request.set_scheme(parsed_url.scheme) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value local auth_param_name = conf.auth and conf.auth.param_name diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index 4b835a82fc0..f2a951024ff 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -109,10 +109,11 @@ end local function to_raw(request_table, model) local messages = {} messages.parameters = {} - messages.parameters.max_new_tokens = model.options and model.options.max_tokens - messages.parameters.top_p = model.options and model.options.top_p or 1.0 - messages.parameters.top_k = model.options and model.options.top_k or 40 - messages.parameters.temperature = model.options and model.options.temperature + messages.parameters.max_new_tokens = request_table.max_tokens or (model.options and model.options.max_tokens) + messages.parameters.top_p = request_table.top_p or (model.options and model.options.top_p) + messages.parameters.top_k = request_table.top_k or (model.options and model.options.top_k) + messages.parameters.temperature = request_table.temperature or (model.options and model.options.temperature) + messages.parameters.stream = request_table.stream or false -- explicitly set this if request_table.prompt and request_table.messages then return kong.response.exit(400, "cannot run raw 'prompt' and chat history 'messages' requests at the same time - refer to schema") @@ -254,11 +255,6 @@ function _M.post_request(conf) end function _M.pre_request(conf, body) - -- check for user trying to bring own model - if body and body.model then - return false, "cannot use own model for this instance" - end - return 
true, nil end @@ -266,13 +262,13 @@ end function _M.configure_request(conf) local parsed_url = socket_url.parse(conf.model.options.upstream_url) + -- if the path is read from a URL capture, ensure that it is valid + parsed_url.path = ensure_valid_path(parsed_url.path) + kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) - -- if the path is read from a URL capture, ensure that it is valid - parsed_url.path = ensure_valid_path(parsed_url.path) - local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value local auth_param_name = conf.auth and conf.auth.param_name diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index 76c28bd8304..fe814a67da2 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -129,11 +129,6 @@ function _M.subrequest(body, conf, http_opts, return_res_table) end function _M.pre_request(conf, body) - -- check for user trying to bring own model - if body and body.model then - return nil, "cannot use own model for this instance" - end - return true, nil end @@ -147,18 +142,16 @@ end -- returns err or nil function _M.configure_request(conf) - if conf.route_type ~= "preserve" then - -- mistral shared openai operation paths - local parsed_url = socket_url.parse(conf.model.options.upstream_url) - - kong.service.request.set_path(parsed_url.path) - kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) - end + -- mistral shared operation paths + local parsed_url = socket_url.parse(conf.model.options.upstream_url) -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = ensure_valid_path(parsed_url.path) + kong.service.request.set_path(parsed_url.path) + kong.service.request.set_scheme(parsed_url.scheme) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value local auth_param_name = conf.auth and conf.auth.param_name diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index bcfbb9400e5..4830a937f25 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -12,44 +12,45 @@ local ensure_valid_path = require("kong.tools.utils").ensure_valid_path local DRIVER_NAME = "openai" -- +-- merge_defaults takes the model options, and sets any defaults defined, +-- if the caller hasn't explicitly set them +-- +-- we have already checked that "max_tokens" isn't overridden when it +-- is not allowed to do so. 
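
A minimal usage sketch of the merge behaviour described in the comment above, assuming plain Lua tables for the incoming request and the configured model options:

local MERGE_PROPERTIES = { "max_tokens", "temperature", "top_p", "top_k" }

-- client-supplied values win; plugin model options act only as fallback defaults
local function merge_defaults(request, options)
  for _, key in ipairs(MERGE_PROPERTIES) do
    request[key] = request[key] or (options and options[key]) or nil
  end
  return request
end

local merged = merge_defaults(
  { temperature = 0.2 },                   -- sent by the caller
  { temperature = 1.0, max_tokens = 256 }  -- plugin configuration
)
assert(merged.temperature == 0.2)  -- client value preserved
assert(merged.max_tokens == 256)   -- configured default filled in
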
+local _MERGE_PROPERTIES = { + [1] = "max_tokens", + [2] = "temperature", + [3] = "top_p", + [4] = "top_k", +} + +local function merge_defaults(request, options) + for i, v in ipairs(_MERGE_PROPERTIES) do + request[v] = request[v] or (options and options[v]) or nil + end + + return request +end + local function handle_stream_event(event_t) return event_t.data end local transformers_to = { - ["llm/v1/chat"] = function(request_table, model, max_tokens, temperature, top_p) - -- if user passed a prompt as a chat, transform it to a chat message - if request_table.prompt then - request_table.messages = { - { - role = "user", - content = request_table.prompt, - } - } - end - - local this = { - model = model, - messages = request_table.messages, - max_tokens = max_tokens, - temperature = temperature, - top_p = top_p, - stream = request_table.stream or false, - } - - return this, "application/json", nil + ["llm/v1/chat"] = function(request_table, model_info, route_type) + request_table = merge_defaults(request_table, model_info.options) + request_table.model = request_table.model or model_info.name + request_table.stream = request_table.stream or false -- explicitly set this + + return request_table, "application/json", nil end, - ["llm/v1/completions"] = function(request_table, model, max_tokens, temperature, top_p) - local this = { - prompt = request_table.prompt, - model = model, - max_tokens = max_tokens, - temperature = temperature, - stream = request_table.stream or false, - } + ["llm/v1/completions"] = function(request_table, model_info, route_type) + request_table = merge_defaults(request_table, model_info.options) + request_table.model = model_info.name + request_table.stream = request_table.stream or false -- explicitly set this - return this, "application/json", nil + return request_table, "application/json", nil end, } @@ -119,10 +120,7 @@ function _M.to_format(request_table, model_info, route_type) local ok, response_object, content_type, err = pcall( transformers_to[route_type], request_table, - model_info.name, - (model_info.options and model_info.options.max_tokens), - (model_info.options and model_info.options.temperature), - (model_info.options and model_info.options.top_p) + model_info ) if err or (not ok) then return nil, nil, fmt("error transforming to %s://%s", model_info.provider, route_type) @@ -199,10 +197,7 @@ function _M.post_request(conf) end function _M.pre_request(conf, body) - -- check for user trying to bring own model - if body and body.model then - return nil, "cannot use own model for this instance" - end + kong.service.request.set_header("Accept-Encoding", "gzip, identity") -- tell server not to send brotli return true, nil end @@ -211,27 +206,27 @@ end function _M.configure_request(conf) local parsed_url - if conf.route_type ~= "preserve" then - if (conf.model.options and conf.model.options.upstream_url) then - parsed_url = socket_url.parse(conf.model.options.upstream_url) - else - local path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path - if not path then - return nil, fmt("operation %s is not supported for openai provider", conf.route_type) - end - - parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) - parsed_url.path = path + if (conf.model.options and conf.model.options.upstream_url) then + parsed_url = socket_url.parse(conf.model.options.upstream_url) + else + local path = conf.model.options + and conf.model.options.upstream_path + or ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + if not path then + return 
nil, fmt("operation %s is not supported for openai provider", conf.route_type) end - - kong.service.request.set_path(parsed_url.path) - kong.service.request.set_scheme(parsed_url.scheme) - kong.service.set_target(parsed_url.host, tonumber(parsed_url.port)) + + parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) + parsed_url.path = path end - + -- if the path is read from a URL capture, ensure that it is valid parsed_url.path = ensure_valid_path(parsed_url.path) + kong.service.request.set_path(parsed_url.path) + kong.service.request.set_scheme(parsed_url.scheme) + kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) + local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value local auth_param_name = conf.auth and conf.auth.param_name diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index d7278ca3d4e..b5f45853575 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -6,12 +6,15 @@ local http = require("resty.http") local fmt = string.format local os = os local parse_url = require("socket.url").parse +local utils = require("kong.tools.utils") -- -- static -local str_find = string.find -local str_sub = string.sub -local tbl_insert = table.insert +local str_find = string.find +local str_sub = string.sub +local tbl_insert = table.insert +local string_match = string.match +local split = utils.split local function str_ltrim(s) -- remove leading whitespace from string. return (s:gsub("^%s*", "")) @@ -226,10 +229,10 @@ function _M.to_ollama(request_table, model) if model.options then input.options = {} - if model.options.max_tokens then input.options.num_predict = model.options.max_tokens end - if model.options.temperature then input.options.temperature = model.options.temperature end - if model.options.top_p then input.options.top_p = model.options.top_p end - if model.options.top_k then input.options.top_k = model.options.top_k end + input.options.num_predict = request_table.max_tokens or model.options.max_tokens + input.options.temperature = request_table.temperature or model.options.temperature + input.options.top_p = request_table.top_p or model.options.top_p + input.options.top_k = request_table.top_k or model.options.top_k end return input, "application/json", nil @@ -304,16 +307,96 @@ function _M.from_ollama(response_string, model_info, route_type) return output, err, analytics end +function _M.conf_from_request(kong_request, source, key) + if source == "uri_captures" then + return kong_request.get_uri_captures().named[key] + elseif source == "headers" then + return kong_request.get_header(key) + elseif source == "query_params" then + return kong_request.get_query_arg(key) + else + return nil, "source '" .. source .. "' is not supported" + end +end + +function _M.conf_from_request(kong_request, source, key) + if source == "uri_captures" then + return kong_request.get_uri_captures().named[key] + elseif source == "headers" then + return kong_request.get_header(key) + elseif source == "query_params" then + return kong_request.get_query_arg(key) + else + return nil, "source '" .. source .. 
"' is not supported" + end +end + +function _M.resolve_plugin_conf(kong_request, conf) + local err + local conf_m = utils.cycle_aware_deep_copy(conf) + + -- handle model name + local model_m = string_match(conf_m.model.name or "", '%$%((.-)%)') + if model_m then + local splitted = split(model_m, '.') + if #splitted ~= 2 then + return nil, "cannot parse expression for field 'model.name'" + end + + -- find the request parameter, with the configured name + model_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) + if err then + return nil, err + end + if not model_m then + return nil, "'" .. splitted[1] .. "', key '" .. splitted[2] .. "' was not provided" + end + + -- replace the value + conf_m.model.name = model_m + end + + -- handle all other options + for k, v in pairs(conf.model.options or {}) do + local prop_m = string_match(v or "", '%$%((.-)%)') + if prop_m then + local splitted = split(prop_m, '.') + if #splitted ~= 2 then + return nil, "cannot parse expression for field '" .. v .. "'" + end + + -- find the request parameter, with the configured name + prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) + if err then + return nil, err + end + if not prop_m then + return nil, splitted[1] .. " key " .. splitted[2] .. " was not provided" + end + + -- replace the value + conf_m.model.options[k] = prop_m + end + end + + return conf_m +end + function _M.pre_request(conf, request_table) -- process form/json body auth information local auth_param_name = conf.auth and conf.auth.param_name local auth_param_value = conf.auth and conf.auth.param_value local auth_param_location = conf.auth and conf.auth.param_location - if auth_param_name and auth_param_value and auth_param_location == "body" then + if auth_param_name and auth_param_value and auth_param_location == "body" and request_table then request_table[auth_param_name] = auth_param_value end + if conf.logging and conf.logging.log_statistics then + kong.log.set_serialize_value(log_entry_keys.REQUEST_MODEL, conf.model.name) + kong.log.set_serialize_value(log_entry_keys.PROVIDER_NAME, conf.model.provider) + end + -- if enabled AND request type is compatible, capture the input for analytics if conf.logging and conf.logging.log_payloads then kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body()) diff --git a/kong/llm/init.lua b/kong/llm/init.lua index e723d5faf11..57452c2d7d9 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -66,20 +66,17 @@ local model_options_schema = { type = "number", description = "Defines the matching temperature, if using chat or completion models.", required = false, - between = { 0.0, 5.0 }, - default = 1.0 }}, + between = { 0.0, 5.0 }}}, { top_p = { type = "number", description = "Defines the top-p probability mass, if supported.", required = false, - between = { 0, 1 }, - default = 1.0 }}, + between = { 0, 1 }}}, { top_k = { type = "integer", description = "Defines the top-k most likely tokens, if supported.", required = false, - between = { 0, 500 }, - default = 0 }}, + between = { 0, 500 }}}, { anthropic_version = { type = "string", description = "Defines the schema/API version, if using Anthropic provider.", @@ -111,6 +108,11 @@ local model_options_schema = { description = "Manually specify or override the full URL to the AI operation endpoints, " .. 
"when calling (self-)hosted models, or for running via a private endpoint.", required = false }}, + { upstream_path = { + description = "Manually specify or override the AI operation path, " + .. "used when e.g. using the 'preserve' route_type.", + type = "string", + required = false }}, } } @@ -157,9 +159,10 @@ _M.config_schema = { fields = { { route_type = { type = "string", - description = "The model's operation implementation, for this provider.", + description = "The model's operation implementation, for this provider. " .. + "Set to `preserve` to pass through without transformation.", required = true, - one_of = { "llm/v1/chat", "llm/v1/completions" } }}, + one_of = { "llm/v1/chat", "llm/v1/completions", "preserve" } }}, { auth = auth_schema }, { model = model_schema }, { logging = logging_schema }, @@ -184,12 +187,6 @@ _M.config_schema = { then_at_least_one_of = { "model.options.mistral_format" }, then_err = "must set %s for mistral provider" }}, - { conditional_at_least_one_of = { if_field = "model.provider", - if_match = { }, - then_at_least_one_of = { "model.name" }, - then_err = "Must set a model name. Refer to https://docs.konghq.com/hub/kong-inc/ai-proxy/ " .. - "for supported models." }}, - { conditional_at_least_one_of = { if_field = "model.provider", if_match = { one_of = { "anthropic" } }, then_at_least_one_of = { "model.options.anthropic_version" }, @@ -292,6 +289,10 @@ local function identify_request(request) end function _M.is_compatible(request, route_type) + if route_type == "preserve" then + return true + end + local format, err = identify_request(request) if err then return nil, err diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index aaa390bf586..7a5cd20f012 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -203,19 +203,23 @@ function _M:header_filter(conf) response_body = kong_utils.inflate_gzip(response_body) end - local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type) - if err then - kong.ctx.plugin.ai_parser_error = true - - ngx.status = 500 - ERROR_MSG.error.message = err - - kong.ctx.plugin.parsed_response = cjson.encode(ERROR_MSG) - - elseif new_response_string then - -- preserve the same response content type; assume the from_format function - -- has returned the body in the appropriate response output format - kong.ctx.plugin.parsed_response = new_response_string + if route_type == "preserve" then + kong.ctx.plugin.parsed_response = response_body + else + local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type) + if err then + kong.ctx.plugin.ai_parser_error = true + + ngx.status = 500 + ERROR_MSG.error.message = err + + kong.ctx.plugin.parsed_response = cjson.encode(ERROR_MSG) + + elseif new_response_string then + -- preserve the same response content type; assume the from_format function + -- has returned the body in the appropriate response output format + kong.ctx.plugin.parsed_response = new_response_string + end end end @@ -229,7 +233,7 @@ function _M:body_filter(conf) return end - if kong.ctx.shared.skip_response_transformer then + if kong.ctx.shared.skip_response_transformer and (route_type ~= "preserve") then local response_body if kong.ctx.shared.parsed_response then response_body = kong.ctx.shared.parsed_response @@ -262,32 +266,33 @@ function _M:body_filter(conf) return end - if kong.ctx.shared.ai_proxy_streaming_mode then - handle_streaming_frame(conf) - else + if route_type ~= "preserve" then 
+ if kong.ctx.shared.ai_proxy_streaming_mode then + handle_streaming_frame(conf) + else -- all errors MUST be checked and returned in header_filter -- we should receive a replacement response body from the same thread + local original_request = kong.ctx.plugin.parsed_response + local deflated_request = original_request + + if deflated_request then + local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" + if is_gzip then + deflated_request = kong_utils.deflate_gzip(deflated_request) + end - local original_request = kong.ctx.plugin.parsed_response - local deflated_request = original_request - - if deflated_request then - local is_gzip = kong.response.get_header("Content-Encoding") == "gzip" - if is_gzip then - deflated_request = kong_utils.deflate_gzip(deflated_request) + kong.response.set_raw_body(deflated_request) end - kong.response.set_raw_body(deflated_request) - end - - -- call with replacement body, or original body if nothing changed - local _, err = ai_shared.post_request(conf, original_request) - if err then - kong.log.warn("analytics phase failed for request, ", err) + -- call with replacement body, or original body if nothing changed + local _, err = ai_shared.post_request(conf, original_request) + if err then + kong.log.warn("analytics phase failed for request, ", err) + end end - end end + kong.ctx.plugin.body_called = true end @@ -298,6 +303,7 @@ function _M:access(conf) kong.ctx.plugin.operation = route_type local request_table + local multipart = false -- we may have received a replacement / decorated request body from another AI plugin if kong.ctx.shared.replacement_request then @@ -313,7 +319,41 @@ function _M:access(conf) -- TODO octet stream check here if not request_table then - return bad_request("content-type header does not match request body") + if not string.find(content_type, "multipart/form-data", nil, true) then + return bad_request("content-type header does not match request body") + end + + multipart = true -- this may be a large file upload, so we have to proxy it directly + end + end + + -- resolve the real plugin config values + local conf_m, err = ai_shared.resolve_plugin_conf(kong.request, conf) + if err then + return bad_request(err) + end + + -- copy from the user request if present + if (not multipart) and (not conf_m.model.name) and (request_table.model) then + conf_m.model.name = request_table.model + elseif multipart then + conf_m.model.name = "NOT_SPECIFIED" + end + + -- model is stashed in the copied plugin conf, for consistency in transformation functions + if not conf_m.model.name then + return bad_request("model parameter not found in request, nor in gateway configuration") + end + + -- stash for analytics later + kong.ctx.plugin.llm_model_requested = conf_m.model.name + + -- check the incoming format is the same as the configured LLM format + if not multipart then + local compatible, err = llm.is_compatible(request_table, route_type) + if not compatible then + kong.ctx.shared.skip_response_transformer = true + return bad_request(err) end end @@ -360,9 +400,14 @@ function _M:access(conf) end -- transform the body to Kong-format for this provider/model - local parsed_request_body, content_type, err = ai_driver.to_format(request_table, conf.model, route_type) - if err then - return bad_request(err) + local parsed_request_body, content_type, err + if route_type ~= "preserve" and (not multipart) then + -- transform the body to Kong-format for this provider/model + parsed_request_body, content_type, err = 
ai_driver.to_format(request_table, conf_m.model, route_type) + if err then + kong.ctx.shared.skip_response_transformer = true + return bad_request(err) + end end -- execute pre-request hooks for "all" drivers before set new body @@ -371,11 +416,14 @@ function _M:access(conf) return bad_request(err) end - kong.service.request.set_body(parsed_request_body, content_type) + if route_type ~= "preserve" then + kong.service.request.set_body(parsed_request_body, content_type) + end -- now re-configure the request for this operation type local ok, err = ai_driver.configure_request(conf) if not ok then + kong.ctx.shared.skip_response_transformer = true return internal_server_error(err) end diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua index a634230a81b..687a76e3961 100644 --- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua +++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua @@ -119,7 +119,9 @@ local FORMATS = { options = { max_tokens = 512, temperature = 0.5, - llama2_format = "ollama", + llama2_format = "raw", + top_p = 1, + top_k = 40, }, }, ["llm/v1/completions"] = { @@ -221,6 +223,166 @@ local expected_stream_choices = { } describe(PLUGIN_NAME .. ": (unit)", function() + it("resolves referenceable plugin configuration from request context", function() + local fake_request = { + ["get_header"] = function(header_name) + local headers = { + ["from_header_1"] = "header_value_here_1", + ["from_header_2"] = "header_value_here_2", + } + return headers[header_name] + end, + + ["get_uri_captures"] = function() + return { + ["named"] = { + ["uri_cap_1"] = "cap_value_here_1", + ["uri_cap_2"] = "cap_value_here_2", + }, + } + end, + + ["get_query_arg"] = function(query_arg_name) + local query_args = { + ["arg_1"] = "arg_value_here_1", + ["arg_2"] = "arg_value_here_2", + } + return query_args[query_arg_name] + end, + } + + local fake_config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "azure-key", + }, + model = { + name = "gpt-3.5-turbo", + provider = "azure", + options = { + max_tokens = 256, + temperature = 1.0, + azure_instance = "$(uri_captures.uri_cap_1)", + azure_deployment_id = "$(headers.from_header_1)", + azure_api_version = "$(query_params.arg_1)", + }, + }, + } + + local result, err = ai_shared.resolve_plugin_conf(fake_request, fake_config) + assert.is_falsy(err) + assert.same(result.model.options, { + ['azure_api_version'] = 'arg_value_here_1', + ['azure_deployment_id'] = 'header_value_here_1', + ['azure_instance'] = 'cap_value_here_1', + ['max_tokens'] = 256, + ['temperature'] = 1, + }) + end) + + it("resolves referenceable model name from request context", function() + local fake_request = { + ["get_header"] = function(header_name) + local headers = { + ["from_header_1"] = "header_value_here_1", + ["from_header_2"] = "header_value_here_2", + } + return headers[header_name] + end, + + ["get_uri_captures"] = function() + return { + ["named"] = { + ["uri_cap_1"] = "cap_value_here_1", + ["uri_cap_2"] = "cap_value_here_2", + }, + } + end, + + ["get_query_arg"] = function(query_arg_name) + local query_args = { + ["arg_1"] = "arg_value_here_1", + ["arg_2"] = "arg_value_here_2", + } + return query_args[query_arg_name] + end, + } + + local fake_config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "azure-key", + }, + model = { + name = "$(uri_captures.uri_cap_2)", + provider = "azure", + options = { + max_tokens = 256, + temperature = 1.0, + azure_instance = 
"string-1", + azure_deployment_id = "string-2", + azure_api_version = "string-3", + }, + }, + } + + local result, err = ai_shared.resolve_plugin_conf(fake_request, fake_config) + assert.is_falsy(err) + assert.same("cap_value_here_2", result.model.name) + end) + + it("returns appropriate error when referenceable plugin configuration is missing from request context", function() + local fake_request = { + ["get_header"] = function(header_name) + local headers = { + ["from_header_1"] = "header_value_here_1", + ["from_header_2"] = "header_value_here_2", + } + return headers[header_name] + end, + + ["get_uri_captures"] = function() + return { + ["named"] = { + ["uri_cap_1"] = "cap_value_here_1", + ["uri_cap_2"] = "cap_value_here_2", + }, + } + end, + + ["get_query_arg"] = function(query_arg_name) + local query_args = { + ["arg_1"] = "arg_value_here_1", + ["arg_2"] = "arg_value_here_2", + } + return query_args[query_arg_name] + end, + } + + local fake_config = { + route_type = "llm/v1/chat", + auth = { + header_name = "api-key", + header_value = "azure-key", + }, + model = { + name = "gpt-3.5-turbo", + provider = "azure", + options = { + max_tokens = 256, + temperature = 1.0, + azure_instance = "$(uri_captures.uri_cap_3)", + azure_deployment_id = "$(headers.from_header_1)", + azure_api_version = "$(query_params.arg_1)", + }, + }, + } + + local _, err = ai_shared.resolve_plugin_conf(fake_request, fake_config) + assert.same("uri_captures key uri_cap_3 was not provided", err) + end) it("llm/v1/chat message is compatible with llm/v1/chat route", function() local compatible, err = llm.is_compatible(SAMPLE_LLM_V1_CHAT, "llm/v1/chat") diff --git a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua index c81d8ab1255..b7a55183dca 100644 --- a/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/02-openai_integration_spec.lua @@ -184,6 +184,33 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } } + location = "/llm/v1/embeddings/good" { + content_by_lua_block { + local pl_file = require "pl.file" + local json = require("cjson.safe") + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + local token = ngx.req.get_headers()["authorization"] + local token_query = ngx.req.get_uri_args()["apikey"] + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then + ngx.req.read_body() + local body, err = ngx.req.get_body_data() + body, err = json.decode(body) + if err then + ngx.status = 400 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) + elseif body.input == "The food was delicious and the waiter" + and body.model == "text-embedding-ada-002" then + ngx.status = 200 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-embeddings/responses/good.json")) + end + else + ngx.status = 401 + ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/unauthorized.json")) + end + } + } } ]] @@ -403,6 +430,74 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then } -- + -- 200 embeddings (preserve route mode) good + local chat_good = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/embeddings/good" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good.id }, + config = { + route_type = "preserve", + 
auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, + model = { + provider = "openai", + options = { + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/embeddings/good" + }, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = chat_good.id }, + config = { + path = "/dev/stdout", + }, + } + -- + + -- 200 chat good but no model set + local chat_good = assert(bp.routes:insert { + service = empty_service, + protocols = { "http" }, + strip_path = true, + paths = { "/openai/llm/v1/chat/good-no-model-param" } + }) + bp.plugins:insert { + name = PLUGIN_NAME, + route = { id = chat_good.id }, + config = { + route_type = "llm/v1/chat", + auth = { + header_name = "Authorization", + header_value = "Bearer openai-key", + }, + model = { + provider = "openai", + options = { + max_tokens = 256, + temperature = 1.0, + upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/good" + }, + }, + }, + } + bp.plugins:insert { + name = "file-log", + route = { id = chat_good.id }, + config = { + path = "/dev/stdout", + }, + } + -- + -- 200 completions good using post body key local completions_good_post_body_key = assert(bp.routes:insert { service = empty_service, @@ -718,23 +813,6 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.is_truthy(json.error) assert.equals(json.error.code, "invalid_api_key") end) - - it("tries to override model", function() - local r = client:get("/openai/llm/v1/chat/good", { - headers = { - ["content-type"] = "application/json", - ["accept"] = "application/json", - }, - body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"), - }) - - local body = assert.res_status(400, r) - local json = cjson.decode(body) - - -- check this is in the 'kong' response format - assert.is_truthy(json.error) - assert.equals(json.error.message, "cannot use own model for this instance") - end) end) describe("openai llm/v1/chat", function() diff --git a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua index 43067360b18..c3cdc525c61 100644 --- a/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua @@ -522,23 +522,6 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.is_truthy(json.error) assert.equals(json.error.type, "authentication_error") end) - - it("tries to override model", function() - local r = client:get("/anthropic/llm/v1/chat/good", { - headers = { - ["content-type"] = "application/json", - ["accept"] = "application/json", - }, - body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good_own_model.json"), - }) - - local body = assert.res_status(400, r) - local json = cjson.decode(body) - - -- check this is in the 'kong' response format - assert.is_truthy(json.error) - assert.equals(json.error.message, "cannot use own model for this instance") - end) end) describe("anthropic llm/v1/chat", function() diff --git a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua index 33023874373..721cf97566e 100644 --- a/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/04-cohere_integration_spec.lua @@ -398,23 +398,6 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then -- check this is 
in the 'kong' response format assert.equals(json.message, "invalid api token") end) - - it("tries to override model", function() - local r = client:get("/cohere/llm/v1/chat/good", { - headers = { - ["content-type"] = "application/json", - ["accept"] = "application/json", - }, - body = pl_file.read("spec/fixtures/ai-proxy/cohere/llm-v1-chat/requests/good_own_model.json"), - }) - - local body = assert.res_status(400, r) - local json = cjson.decode(body) - - -- check this is in the 'kong' response format - assert.is_truthy(json.error) - assert.equals(json.error.message, "cannot use own model for this instance") - end) end) describe("cohere llm/v1/chat", function() diff --git a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua index 96d9645a401..a8efe9b21a1 100644 --- a/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/05-azure_integration_spec.lua @@ -413,23 +413,6 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then assert.is_truthy(json.error) assert.equals(json.error.code, "invalid_api_key") end) - - it("tries to override model", function() - local r = client:get("/azure/llm/v1/chat/good", { - headers = { - ["content-type"] = "application/json", - ["accept"] = "application/json", - }, - body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"), - }) - - local body = assert.res_status(400, r) - local json = cjson.decode(body) - - -- check this is in the 'kong' response format - assert.is_truthy(json.error) - assert.equals(json.error.message, "cannot use own model for this instance") - end) end) describe("azure llm/v1/chat", function() diff --git a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua index 49612408f1d..3c711cd83b4 100644 --- a/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/06-mistral_integration_spec.lua @@ -320,25 +320,6 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then if client then client:close() end end) - describe("mistral general", function() - it("tries to override model", function() - local r = client:get("/mistral/llm/v1/chat/good", { - headers = { - ["content-type"] = "application/json", - ["accept"] = "application/json", - }, - body = pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/requests/good_own_model.json"), - }) - - local body = assert.res_status(400, r) - local json = cjson.decode(body) - - -- check this is in the 'kong' response format - assert.is_truthy(json.error) - assert.equals(json.error.message, "cannot use own model for this instance") - end) - end) - describe("mistral llm/v1/chat", function() it("good request", function() local r = client:get("/mistral/llm/v1/chat/good", { diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-completions.json index bc7368bb7d4..31c90e2be24 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/azure/llm-v1-completions.json @@ -3,5 +3,6 @@ "model": "gpt-3.5-turbo-instruct", "max_tokens": 512, "temperature": 0.5, - "stream": false + "stream": false, + "top_p": 1 } diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/llama2/raw/llm-v1-chat.json 
b/spec/fixtures/ai-proxy/unit/expected-requests/llama2/raw/llm-v1-chat.json index e299158374e..71c517de499 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/llama2/raw/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/llama2/raw/llm-v1-chat.json @@ -4,6 +4,7 @@ "max_new_tokens": 512, "temperature": 0.5, "top_k": 40, - "top_p": 1 + "top_p": 1, + "stream": false } } \ No newline at end of file diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/llama2/raw/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/expected-requests/llama2/raw/llm-v1-completions.json index 72403f18e25..d4011887a6d 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/llama2/raw/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/llama2/raw/llm-v1-completions.json @@ -3,7 +3,6 @@ "parameters": { "max_new_tokens": 512, "temperature": 0.5, - "top_k": 40, - "top_p": 1 + "stream": false } } \ No newline at end of file From 6db54a25e311ba79c27a6dab4f8bf7fae759eb57 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Mon, 22 Apr 2024 19:02:10 +0100 Subject: [PATCH 04/24] fix(ai-proxy): azure missing url override --- kong/llm/drivers/anthropic.lua | 2 -- kong/llm/drivers/azure.lua | 6 ++--- kong/llm/drivers/shared.lua | 21 --------------- kong/llm/init.lua | 27 ------------------- kong/plugins/ai-proxy/handler.lua | 11 ++++---- .../01-transformer_spec.lua | 2 +- 6 files changed, 10 insertions(+), 59 deletions(-) diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 8300db58056..7b7b12be38d 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -206,8 +206,6 @@ local function handle_stream_event(event_t, model_info, route_type) -- last few frames / iterations if event_data and event_data.usage then - local meta = event_data.usage - return nil, nil, { prompt_tokens = nil, completion_tokens = event_data.meta.usage diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index a690a2593e2..1cf8e801042 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -53,7 +53,8 @@ function _M.subrequest(body, conf, http_opts, return_res_table) end -- azure has non-standard URL format - local url = fmt( + local url = (conf.model.options and conf.model.options.upstream_url) + or fmt( "%s%s?api-version=%s", ai_shared.upstream_url_format[DRIVER_NAME]:format(conf.model.options.azure_instance, conf.model.options.azure_deployment_id), conf.model.options @@ -99,7 +100,7 @@ function _M.configure_request(conf) local parsed_url if conf.model.options.upstream_url then - parsed_url = socket_url.parse(conf.model.options.upstream_url) + parsed_url = socket_url.parse("http://127.0.0.1:8080/jack/t") else -- azure has non-standard URL format local url = fmt( @@ -119,7 +120,6 @@ function _M.configure_request(conf) kong.service.request.set_scheme(parsed_url.scheme) kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443)) - local auth_header_name = conf.auth and conf.auth.header_name local auth_header_value = conf.auth and conf.auth.header_value local auth_param_name = conf.auth and conf.auth.param_name diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index b5f45853575..f4121f2e65f 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -12,7 +12,6 @@ local utils = require("kong.tools.utils") -- static local str_find = string.find local str_sub = string.sub -local tbl_insert = table.insert local string_match = string.match local split = utils.split 
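
Since this hunk touches the shared driver module that also hosts conf_from_request and resolve_plugin_conf, a hedged, standalone approximation of how a "$(source.key)" reference is resolved at request time; the request accessors and key names below are illustrative stand-ins, not part of the patch:

-- stand-ins for the Kong PDK request accessors used by conf_from_request
local fake_request = {
  get_uri_captures = function() return { named = { azure_instance = "my-instance" } } end,
  get_header       = function(name) return ({ ["x-deployment"] = "gpt35-deployment" })[name] end,
  get_query_arg    = function(name) return ({ ["api-version"] = "2023-05-15" })[name] end,
}

-- resolve a single "$(source.key)" reference, mirroring the driver logic
local function resolve(request, value)
  local expr = value:match("%$%((.-)%)")
  if not expr then
    return value                               -- plain literal: used as-is
  end
  local source, key = expr:match("^([^.]+)%.(.+)$")
  if source == "uri_captures" then
    return request.get_uri_captures().named[key]
  elseif source == "headers" then
    return request.get_header(key)
  elseif source == "query_params" then
    return request.get_query_arg(key)
  end
  return nil, "source '" .. tostring(source) .. "' is not supported"
end

print(resolve(fake_request, "$(uri_captures.azure_instance)"))   --> my-instance
print(resolve(fake_request, "$(headers.x-deployment)"))          --> gpt35-deployment
print(resolve(fake_request, "$(query_params.api-version)"))      --> 2023-05-15
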
@@ -545,26 +544,6 @@ function _M.http_request(url, body, method, headers, http_opts, buffered) end end -local function get_token_text(event_t) - -- chat - return - event_t and - event_t.choices and - #event_t.choices > 0 and - event_t.choices[1].delta and - event_t.choices[1].delta.content - - or - - -- completions - event_t and - event_t.choices and - #event_t.choices > 0 and - event_t.choices[1].text - - or "" -end - -- Function to count the number of words in a string local function count_words(str) local count = 0 diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 57452c2d7d9..6d08274a17e 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -3,11 +3,7 @@ local typedefs = require("kong.db.schema.typedefs") local fmt = string.format local cjson = require("cjson.safe") local re_match = ngx.re.match -local buf = require("string.buffer") -local lower = string.lower -local meta = require "kong.meta" local ai_shared = require("kong.llm.drivers.shared") -local strip = require("kong.tools.utils").strip -- local _M = {} @@ -230,15 +226,6 @@ _M.config_schema = { }, } -local streaming_skip_headers = { - ["connection"] = true, - ["content-type"] = true, - ["keep-alive"] = true, - ["set-cookie"] = true, - ["transfer-encoding"] = true, - ["via"] = true, -} - local formats_compatible = { ["llm/v1/chat"] = { ["llm/v1/chat"] = true, @@ -248,20 +235,6 @@ local formats_compatible = { }, } -local function bad_request(msg) - ngx.log(ngx.WARN, msg) - ngx.status = 400 - ngx.header["Content-Type"] = "application/json" - ngx.say(cjson.encode({ error = { message = msg } })) -end - -local function internal_server_error(msg) - ngx.log(ngx.ERR, msg) - ngx.status = 500 - ngx.header["Content-Type"] = "application/json" - ngx.say(cjson.encode({ error = { message = msg } })) -end - local function identify_request(request) -- primitive request format determination local formats = {} diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 7a5cd20f012..9bd9928af04 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -2,7 +2,6 @@ local _M = {} -- imports local ai_shared = require("kong.llm.drivers.shared") -local ai_module = require("kong.llm") local llm = require("kong.llm") local cjson = require("cjson.safe") local kong_utils = require("kong.tools.gzip") @@ -78,7 +77,7 @@ local function handle_streaming_frame(conf) -- because we have already 200 OK'd the client by now if (not finished) and (is_gzip) then - event = kong_utils.inflate_gzip(chunk) + chunk = kong_utils.inflate_gzip(chunk) end local events = ai_shared.frame_to_events(chunk) @@ -86,7 +85,9 @@ local function handle_streaming_frame(conf) for _, event in ipairs(events) do local formatted, _, metadata = ai_driver.from_format(event, conf.model, "stream/" .. 
conf.route_type) - local event_t, token_t, err + local event_t = nil + local token_t = nil + local err if formatted then -- only stream relevant frames back to the user if conf.logging and conf.logging.log_payloads and (formatted ~= "[DONE]") then @@ -140,7 +141,6 @@ local function handle_streaming_frame(conf) end local response_frame = framebuffer:get() - framebuffer = nil if (not finished) and (is_gzip) then response_frame = kong_utils.deflate_gzip(response_frame) end @@ -233,6 +233,8 @@ function _M:body_filter(conf) return end + local route_type = conf.route_type + if kong.ctx.shared.skip_response_transformer and (route_type ~= "preserve") then local response_body if kong.ctx.shared.parsed_response then @@ -251,7 +253,6 @@ function _M:body_filter(conf) end local ai_driver = require("kong.llm.drivers." .. conf.model.provider) - local route_type = conf.route_type local new_response_string, err = ai_driver.from_format(response_body, conf.model, route_type) if err then diff --git a/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua index 227de9553f4..9ccd084d031 100644 --- a/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua +++ b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua @@ -64,7 +64,7 @@ local FORMATS = { options = { max_tokens = 512, temperature = 0.5, - upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/azure" + upstream_url = "http://" .. helpers.mock_upstream_host .. ":" .. MOCK_PORT .. "/chat/azure", }, }, auth = { From 460227754248cc2437b168adb9d394b70ea77141 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Mon, 22 Apr 2024 19:15:45 +0100 Subject: [PATCH 05/24] feat(ai-proxy): changelog --- changelog/unreleased/kong/ai-proxy-client-params.yml | 6 ++++++ changelog/unreleased/kong/ai-proxy-preserve-mode.yml | 6 ++++++ 2 files changed, 12 insertions(+) create mode 100644 changelog/unreleased/kong/ai-proxy-client-params.yml create mode 100644 changelog/unreleased/kong/ai-proxy-preserve-mode.yml diff --git a/changelog/unreleased/kong/ai-proxy-client-params.yml b/changelog/unreleased/kong/ai-proxy-client-params.yml new file mode 100644 index 00000000000..2d76256a190 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-client-params.yml @@ -0,0 +1,6 @@ +message: | + AI Proxy now reads most prompt tuning parameters from the client, whilst the + plugin config 'model options' are now just defaults. This fixes support for + using the respective provider's native SDK. +type: feature +scope: Plugin diff --git a/changelog/unreleased/kong/ai-proxy-preserve-mode.yml b/changelog/unreleased/kong/ai-proxy-preserve-mode.yml new file mode 100644 index 00000000000..ff7af94add2 --- /dev/null +++ b/changelog/unreleased/kong/ai-proxy-preserve-mode.yml @@ -0,0 +1,6 @@ +message: | + AI Proxy now has a 'preserve' route_type option, where the requests and responses + are passed directly to the upstream LLM. This is to enable compatilibity with any + and all models and SDKs, that may be used when calling the AI services. 
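
To make the new 'preserve' mode concrete, a trimmed configuration sketch modelled on the embeddings test added earlier in this series; the route id, credential and upstream host/path are placeholders:

-- proxy the request body untouched to the upstream LLM, no format transformation
bp.plugins:insert {
  name  = "ai-proxy",
  route = { id = route.id },   -- assumes `route` was created beforehand
  config = {
    route_type = "preserve",
    auth = {
      header_name  = "Authorization",
      header_value = "Bearer my-upstream-key",
    },
    model = {
      provider = "openai",
      options = {
        -- full upstream URL, including the operation path that should be preserved
        upstream_url = "http://upstream.example.test:8080/v1/embeddings",
      },
    },
  },
}

In this mode the plugin still sets the upstream target and injects the configured auth header, but skips the request and response format transformations entirely.
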
+type: feature +scope: Plugin From 64f6a28969ab24e7067d087c333d338f246634af Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Tue, 23 Apr 2024 16:02:53 +0100 Subject: [PATCH 06/24] fix)(ai-proxy): #12903 fixes rollup --- kong/llm/drivers/anthropic.lua | 90 +++++++++---------- kong/llm/drivers/azure.lua | 11 ++- kong/llm/drivers/cohere.lua | 4 +- kong/llm/drivers/llama2.lua | 3 +- kong/llm/drivers/mistral.lua | 4 +- kong/llm/drivers/openai.lua | 4 +- .../09-streaming_integration_spec.lua | 30 +++---- 7 files changed, 71 insertions(+), 75 deletions(-) diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 7b7b12be38d..8ca0a971003 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -6,7 +6,7 @@ local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local socket_url = require "socket.url" local buffer = require("string.buffer") -local ensure_valid_path = require("kong.tools.utils").ensure_valid_path +local string_gsub = string.gsub -- -- globals @@ -186,60 +186,58 @@ local function start_to_event(event_data, model_info) end local function handle_stream_event(event_t, model_info, route_type) - ngx.log(ngx.WARN, event_t.event or "NO EVENT") - ngx.log(ngx.WARN, event_t.data or "NO DATA") - local event_id = event_t.event or "ping" + local event_id = event_t.event local event_data = cjson.decode(event_t.data) - if event_id and event_data then - if event_id == "message_start" then - -- message_start and contains the token usage and model metadata + if not event_id or not event_data then + return nil, "transformation to stream event failed or empty stream event received", nil + end - if event_data and event_data.message then - return start_to_event(event_data, model_info) - else - return nil, "message_start is missing the metadata block", nil - end + if event_id == "message_start" then + -- message_start and contains the token usage and model metadata - elseif event_id == "message_delta" then - -- message_delta contains and interim token count of the - -- last few frames / iterations - if event_data - and event_data.usage then - return nil, nil, { - prompt_tokens = nil, - completion_tokens = event_data.meta.usage - and event_data.meta.usage.output_tokens - or nil, - stop_reason = event_data.delta - and event_data.delta.stop_reason - or nil, - stop_sequence = event_data.delta - and event_data.delta.stop_sequence - or nil, - } - else - return nil, "message_delta is missing the metadata block", nil - end + if event_data and event_data.message then + return start_to_event(event_data, model_info) + else + return nil, "message_start is missing the metadata block", nil + end - elseif event_id == "content_block_start" then - -- content_block_start is just an empty string and indicates - -- that we're getting an actual answer - return delta_to_event(event_data, model_info) + elseif event_id == "message_delta" then + -- message_delta contains and interim token count of the + -- last few frames / iterations + if event_data + and event_data.usage then + return nil, nil, { + prompt_tokens = nil, + completion_tokens = event_data.meta.usage + and event_data.meta.usage.output_tokens + or nil, + stop_reason = event_data.delta + and event_data.delta.stop_reason + or nil, + stop_sequence = event_data.delta + and event_data.delta.stop_sequence + or nil, + } + else + return nil, "message_delta is missing the metadata block", nil + end - elseif event_id == "content_block_delta" then - return delta_to_event(event_data, model_info) + elseif 
event_id == "content_block_start" then + -- content_block_start is just an empty string and indicates + -- that we're getting an actual answer + return delta_to_event(event_data, model_info) - elseif event_id == "message_stop" then - return "[DONE]", nil, nil + elseif event_id == "content_block_delta" then + return delta_to_event(event_data, model_info) - elseif event_id == "ping" then - return nil, nil, nil + elseif event_id == "message_stop" then + return "[DONE]", nil, nil - end - end + elseif event_id == "ping" then + return nil, nil, nil - return nil, "transformation to stream event failed or empty stream event received", nil + end end local transformers_from = { @@ -467,7 +465,7 @@ function _M.configure_request(conf) end -- if the path is read from a URL capture, ensure that it is valid - parsed_url.path = ensure_valid_path(parsed_url.path) + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index 1cf8e801042..f38c6b4029b 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -6,7 +6,7 @@ local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local openai_driver = require("kong.llm.drivers.openai") local socket_url = require "socket.url" -local ensure_valid_path = require("kong.tools.utils").ensure_valid_path +local string_gsub = string.gsub -- -- globals @@ -60,7 +60,7 @@ function _M.subrequest(body, conf, http_opts, return_res_table) conf.model.options and conf.model.options.upstream_path or ai_shared.operation_map[DRIVER_NAME][conf.route_type].path, - conf.model.options.azure_api_version or "2023-05-15" + conf.model.options.azure_api_version ) local method = ai_shared.operation_map[DRIVER_NAME][conf.route_type].method @@ -100,7 +100,7 @@ function _M.configure_request(conf) local parsed_url if conf.model.options.upstream_url then - parsed_url = socket_url.parse("http://127.0.0.1:8080/jack/t") + parsed_url = socket_url.parse(conf.model.options.upstream_url) else -- azure has non-standard URL format local url = fmt( @@ -113,8 +113,8 @@ function _M.configure_request(conf) parsed_url = socket_url.parse(url) end - -- if the path is read from a URL capture, ensure that it is valid - parsed_url.path = ensure_valid_path(parsed_url.path) + -- if the path is read from a URL capture, 3re that it is valid + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) @@ -135,7 +135,6 @@ function _M.configure_request(conf) -- technically min supported version query_table["api-version"] = kong.request.get_query_arg("api-version") or (conf.model.options and conf.model.options.azure_api_version) - or "2023-05-15" if auth_param_name and auth_param_value and auth_param_location == "query" then query_table[auth_param_name] = auth_param_value diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 30ca9f61d5d..717110c02bd 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -6,7 +6,7 @@ local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local socket_url = require "socket.url" local table_new = require("table.new") -local ensure_valid_path = require("kong.tools.utils").ensure_valid_path +local string_gsub = string.gsub -- -- globals @@ -466,7 +466,7 @@ function _M.configure_request(conf) end -- if the path is read from 
a URL capture, ensure that it is valid - parsed_url.path = ensure_valid_path(parsed_url.path) + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index f2a951024ff..1a3e844b578 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -8,7 +8,6 @@ local ai_shared = require("kong.llm.drivers.shared") local openai_driver = require("kong.llm.drivers.openai") local socket_url = require "socket.url" local string_gsub = string.gsub -local ensure_valid_path = require("kong.tools.utils").ensure_valid_path -- -- globals @@ -263,7 +262,7 @@ function _M.configure_request(conf) local parsed_url = socket_url.parse(conf.model.options.upstream_url) -- if the path is read from a URL capture, ensure that it is valid - parsed_url.path = ensure_valid_path(parsed_url.path) + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index fe814a67da2..8bcc2b9a33a 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -7,7 +7,7 @@ local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local openai_driver = require("kong.llm.drivers.openai") local socket_url = require "socket.url" -local ensure_valid_path = require("kong.tools.utils").ensure_valid_path +local string_gsub = string.gsub -- -- globals @@ -146,7 +146,7 @@ function _M.configure_request(conf) local parsed_url = socket_url.parse(conf.model.options.upstream_url) -- if the path is read from a URL capture, ensure that it is valid - parsed_url.path = ensure_valid_path(parsed_url.path) + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 4830a937f25..c7ede088a56 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -5,7 +5,7 @@ local cjson = require("cjson.safe") local fmt = string.format local ai_shared = require("kong.llm.drivers.shared") local socket_url = require "socket.url" -local ensure_valid_path = require("kong.tools.utils").ensure_valid_path +local string_gsub = string.gsub -- -- globals @@ -221,7 +221,7 @@ function _M.configure_request(conf) end -- if the path is read from a URL capture, ensure that it is valid - parsed_url.path = ensure_valid_path(parsed_url.path) + parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") kong.service.request.set_path(parsed_url.path) kong.service.request.set_scheme(parsed_url.scheme) diff --git a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua index cef707bbecf..7dc325d8f8f 100644 --- a/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua +++ b/spec/03-plugins/38-ai-proxy/09-streaming_integration_spec.lua @@ -53,32 +53,32 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then [5] = 'data: { "choices": [ { "delta": {}, "finish_reason": "stop", "index": 0, "logprobs": null } ], "created": 1712538905, "id": "chatcmpl-9BXtBvU8Tsw1U7CarzV71vQEjvYwq", "model": "gpt-4-0613", "object": "chat.completion.chunk", "system_fingerprint": null}', [6] = 'data: [DONE]', } - + local fmt = 
string.format local pl_file = require "pl.file" local json = require("cjson.safe") - + ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + local token = ngx.req.get_headers()["authorization"] local token_query = ngx.req.get_uri_args()["apikey"] - + if token == "Bearer openai-key" or token_query == "openai-key" or body.apikey == "openai-key" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) else -- GOOD RESPONSE - + ngx.status = 200 ngx.header["Content-Type"] = "text/event-stream" - + for i, EVENT in ipairs(_EVENT_CHUNKS) do ngx.print(fmt("%s\n\n", EVENT)) end @@ -123,7 +123,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -170,8 +170,8 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then [7] = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" ="} }', [8] = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" 2"} }', [9] = 'data: {"type":"content_block_stop","index":0 }', - [10] = '{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":9}}', - [11] = '{"type":"message_stop"}', + [10] = 'data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":9}}', + [11] = 'data: {"type":"message_stop"}', } local fmt = string.format @@ -189,16 +189,16 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) else -- GOOD RESPONSE - + ngx.status = 200 ngx.header["Content-Type"] = "text/event-stream" - + for i, EVENT in ipairs(_EVENT_CHUNKS) do ngx.print(fmt("%s\n", EVENT)) ngx.print(fmt("%s\n\n", _DATA_CHUNKS[i])) @@ -228,7 +228,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then ngx.req.read_body() local body, err = ngx.req.get_body_data() body, err = json.decode(body) - + if err or (body.messages == ngx.null) then ngx.status = 400 ngx.print(pl_file.read("spec/fixtures/ai-proxy/openai/llm-v1-chat/responses/bad_request.json")) @@ -609,7 +609,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then end until not buffer - assert.equal(#events, 7) + assert.equal(#events, 8) assert.equal(buf:tostring(), "1 + 1 = 2") end) From bc251f972fad9129d0b5705e395a28410801160c Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Wed, 24 Apr 2024 07:25:24 +0100 Subject: [PATCH 07/24] feat(ai-proxy): make merge config defaults a shared feature --- kong/llm/drivers/anthropic.lua | 2 + kong/llm/drivers/cohere.lua | 59 ++-- kong/llm/drivers/llama2.lua | 10 +- kong/llm/drivers/mistral.lua | 2 + kong/llm/drivers/openai.lua | 12 +- kong/llm/drivers/shared.lua | 32 +- spec/03-plugins/38-ai-proxy/01-unit_spec.lua | 275 +++++++++++------- 
.../expected-requests/cohere/llm-v1-chat.json | 3 +- .../cohere/llm-v1-completions.json | 2 - 9 files changed, 243 insertions(+), 154 deletions(-) diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 8ca0a971003..8d5a791657f 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -356,6 +356,8 @@ function _M.to_format(request_table, model_info, route_type) return request_table, nil, nil end + request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type) + if not transformers_to[route_type] then return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type) end diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 717110c02bd..f31608a5b8a 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -11,6 +11,12 @@ local string_gsub = string.gsub -- globals local DRIVER_NAME = "cohere" + +local _CHAT_ROLES = { + ["system"] = "CHATBOT", + ["assistant"] = "CHATBOT", + ["user"] = "USER", +} -- local function handle_stream_event(event_t, model_info, route_type) @@ -128,25 +134,22 @@ local function handle_stream_event(event_t, model_info, route_type) end -local function merge_fields(request_table, model) - model.options = model.options or {} - request_table.temperature = request_table.temperature or model.options.temperature - request_table.max_tokens = request_table.max_tokens or model.options.max_tokens - request_table.truncate = request_table.truncate or "END" - request_table.return_likelihoods = request_table.return_likelihoods or "NONE" - request_table.p = request_table.top_p or model.options.top_p - request_table.k = request_table.top_k or model.options.top_k - - return request_table -end - -local function handle_all(request_table, model) +local function handle_json_inference_event(request_table, model) + request_table.temperature = request_table.temperature + request_table.max_tokens = request_table.max_tokens + + request_table.p = request_table.top_p + request_table.k = request_table.top_k + + request_table.top_p = nil + request_table.top_k = nil + request_table.model = model.name or request_table.model request_table.stream = request_table.stream or false -- explicitly set this - + if request_table.prompt and request_table.messages then return kong.response.exit(400, "cannot run a 'prompt' and a history of 'messages' at the same time - refer to schema") - + elseif request_table.messages then -- we have to move all BUT THE LAST message into "chat_history" array -- and move the LAST message (from 'user') into "message" string @@ -156,40 +159,37 @@ local function handle_all(request_table, model) -- if this is the last message prompt, don't add to history if i < #request_table.messages then local role - if v.role == "assistant" or v.role == "CHATBOT" then - role = "CHATBOT" + if v.role == "assistant" or v.role == _CHAT_ROLES.assistant then + role = _CHAT_ROLES.assistant else - role = "USER" + role = _CHAT_ROLES.user end - + chat_history[i] = { role = role, message = v.content, } end end - + request_table.chat_history = chat_history end - + request_table.message = request_table.messages[#request_table.messages].content request_table.messages = nil - request_table = merge_fields(request_table, model) - + elseif request_table.prompt then request_table.prompt = request_table.prompt request_table.messages = nil request_table.message = nil - request_table = merge_fields(request_table, model) - end - + return request_table, 
"application/json", nil end local transformers_to = { - ["llm/v1/chat"] = handle_all, - ["llm/v1/completions"] = handle_all, + ["llm/v1/chat"] = handle_json_inference_event, + ["llm/v1/completions"] = handle_json_inference_event, } local transformers_from = { @@ -362,6 +362,8 @@ function _M.to_format(request_table, model_info, route_type) return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type) end + request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type) + local ok, response_object, content_type, err = pcall( transformers_to[route_type], request_table, @@ -492,5 +494,4 @@ function _M.configure_request(conf) return true, nil end - return _M diff --git a/kong/llm/drivers/llama2.lua b/kong/llm/drivers/llama2.lua index 1a3e844b578..b2c41db7776 100644 --- a/kong/llm/drivers/llama2.lua +++ b/kong/llm/drivers/llama2.lua @@ -108,10 +108,10 @@ end local function to_raw(request_table, model) local messages = {} messages.parameters = {} - messages.parameters.max_new_tokens = request_table.max_tokens or (model.options and model.options.max_tokens) - messages.parameters.top_p = request_table.top_p or (model.options and model.options.top_p) - messages.parameters.top_k = request_table.top_k or (model.options and model.options.top_k) - messages.parameters.temperature = request_table.temperature or (model.options and model.options.temperature) + messages.parameters.max_new_tokens = request_table.max_tokens + messages.parameters.top_p = request_table.top_p + messages.parameters.top_k = request_table.top_k + messages.parameters.temperature = request_table.temperature messages.parameters.stream = request_table.stream or false -- explicitly set this if request_table.prompt and request_table.messages then @@ -179,6 +179,8 @@ function _M.to_format(request_table, model_info, route_type) return openai_driver.to_format(request_table, model_info, route_type) end + request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type) + -- dynamically call the correct transformer local ok, response_object, content_type, err = pcall( transformers_to[fmt("%s/%s", route_type, model_info.options.llama2_format)], diff --git a/kong/llm/drivers/mistral.lua b/kong/llm/drivers/mistral.lua index 8bcc2b9a33a..8a98ef6258f 100644 --- a/kong/llm/drivers/mistral.lua +++ b/kong/llm/drivers/mistral.lua @@ -66,6 +66,8 @@ function _M.to_format(request_table, model_info, route_type) return nil, nil, fmt("no transformer available to format %s://%s", model_info.provider, transformer_type) end + request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type) + -- dynamically call the correct transformer local ok, response_object, content_type, err = pcall( transformers_to[transformer_type], diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index c7ede088a56..7535a3252f4 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -24,21 +24,12 @@ local _MERGE_PROPERTIES = { [4] = "top_k", } -local function merge_defaults(request, options) - for i, v in ipairs(_MERGE_PROPERTIES) do - request[v] = request[v] or (options and options[v]) or nil - end - - return request -end - local function handle_stream_event(event_t) return event_t.data end local transformers_to = { ["llm/v1/chat"] = function(request_table, model_info, route_type) - request_table = merge_defaults(request_table, model_info.options) request_table.model = request_table.model or 
model_info.name
     request_table.stream = request_table.stream or false  -- explicitly set this
 
@@ -46,7 +37,6 @@ local transformers_to = {
   end,
 
   ["llm/v1/completions"] = function(request_table, model_info, route_type)
-    request_table = merge_defaults(request_table, model_info.options)
     request_table.model = model_info.name
     request_table.stream = request_table.stream or false  -- explicitly set this
 
@@ -117,6 +107,8 @@ function _M.to_format(request_table, model_info, route_type)
     return nil, nil, fmt("no transformer for %s://%s", model_info.provider, route_type)
   end
 
+  request_table = ai_shared.merge_config_defaults(request_table, model_info.options, model_info.route_type)
+
   local ok, response_object, content_type, err = pcall(
     transformers_to[route_type],
     request_table,
diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua
index f4121f2e65f..b8ac9b7ff27 100644
--- a/kong/llm/drivers/shared.lua
+++ b/kong/llm/drivers/shared.lua
@@ -114,6 +114,30 @@ _M.clear_response_headers = {
   },
 }
 
+---
+-- Takes an already 'standardised' input and merges
+-- any missing fields with their defaults, as defined
+-- in the plugin config.
+--
+-- It is supposed to be completely provider-agnostic,
+-- and only operates to assist the Kong operator in
+-- giving their users and admins a pre-tuned set of
+-- default options for any AI inference request.
+--
+-- @param {table} request kong-format inference request conforming to one of many supported formats
+-- @param {table} options the 'config.model.options' table from any Kong AI plugin
+-- @return {table} the input 'request' table, but with (missing) default options merged in
+-- @return {string} error if any is thrown - the request should be terminated if this is not nil
+function _M.merge_config_defaults(request, options, request_format)
+  if options then
+    request.temperature = request.temperature or options.temperature
+    request.max_tokens = request.max_tokens or options.max_tokens
+    request.top_p = request.top_p or options.top_p
+    request.top_k = request.top_k or options.top_k
+  end
+
+  return request, nil
+end
 
 local function handle_stream_event(event_table, model_info, route_type)
   if event_table.done then
@@ -228,10 +252,10 @@ function _M.to_ollama(request_table, model)
 
   if model.options then
     input.options = {}
 
-    input.options.num_predict = request_table.max_tokens or model.options.max_tokens
-    input.options.temperature = request_table.temperature or model.options.temperature
-    input.options.top_p = request_table.top_p or model.options.top_p
-    input.options.top_k = request_table.top_k or model.options.top_k
+    input.options.num_predict = request_table.max_tokens
+    input.options.temperature = request_table.temperature
+    input.options.top_p = request_table.top_p
+    input.options.top_k = request_table.top_k
   end
 
   return input, "application/json", nil
diff --git a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua
index 687a76e3961..9ff754a1407 100644
--- a/spec/03-plugins/38-ai-proxy/01-unit_spec.lua
+++ b/spec/03-plugins/38-ai-proxy/01-unit_spec.lua
@@ -19,6 +19,24 @@ local SAMPLE_LLM_V1_CHAT = {
   },
 }
 
+local SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS = {
+  messages = {
+    [1] = {
+      role = "system",
+      content = "You are a mathematician."
+    },
+    [2] = {
+      role = "assistant",
+      content = "What is 1 + 1?"
+ }, + }, + max_tokens = 256, + temperature = 0.1, + top_p = 0.2, + some_extra_param = "string_val", + another_extra_param = 0.5, +} + local SAMPLE_DOUBLE_FORMAT = { messages = { [1] = { @@ -36,143 +54,172 @@ local SAMPLE_DOUBLE_FORMAT = { local FORMATS = { openai = { ["llm/v1/chat"] = { - name = "gpt-4", - provider = "openai", - options = { - max_tokens = 512, - temperature = 0.5, + config = { + name = "gpt-4", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + }, }, }, ["llm/v1/completions"] = { - name = "gpt-3.5-turbo-instruct", - provider = "openai", - options = { - max_tokens = 512, - temperature = 0.5, + config = { + name = "gpt-3.5-turbo-instruct", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + }, }, }, }, cohere = { ["llm/v1/chat"] = { - name = "command", - provider = "cohere", - options = { - max_tokens = 512, - temperature = 0.5, + config = { + name = "command", + provider = "cohere", + options = { + max_tokens = 512, + temperature = 0.5, + top_p = 1.0 + }, }, }, ["llm/v1/completions"] = { - name = "command", - provider = "cohere", - options = { - max_tokens = 512, - temperature = 0.5, - top_p = 0.75, - top_k = 5, + config = { + name = "command", + provider = "cohere", + options = { + max_tokens = 512, + temperature = 0.5, + top_p = 0.75, + top_k = 5, + }, }, }, }, anthropic = { ["llm/v1/chat"] = { - name = "claude-2.1", - provider = "anthropic", - options = { - max_tokens = 512, - temperature = 0.5, - top_p = 1.0, + config = { + name = "claude-2.1", + provider = "anthropic", + options = { + max_tokens = 512, + temperature = 0.5, + top_p = 1.0, + }, }, }, ["llm/v1/completions"] = { - name = "claude-2.1", - provider = "anthropic", - options = { - max_tokens = 512, - temperature = 0.5, - top_p = 1.0, + config = { + name = "claude-2.1", + provider = "anthropic", + options = { + max_tokens = 512, + temperature = 0.5, + top_p = 1.0, + }, }, }, }, azure = { ["llm/v1/chat"] = { - name = "gpt-4", - provider = "azure", - options = { - max_tokens = 512, - temperature = 0.5, - top_p = 1.0, + config = { + name = "gpt-4", + provider = "azure", + options = { + max_tokens = 512, + temperature = 0.5, + top_p = 1.0, + }, }, }, ["llm/v1/completions"] = { - name = "gpt-3.5-turbo-instruct", + config = { + name = "gpt-3.5-turbo-instruct", provider = "azure", - options = { - max_tokens = 512, - temperature = 0.5, - top_p = 1.0, + options = { + max_tokens = 512, + temperature = 0.5, + top_p = 1.0, + }, }, }, }, llama2_raw = { ["llm/v1/chat"] = { - name = "llama2", - provider = "llama2", - options = { - max_tokens = 512, - temperature = 0.5, - llama2_format = "raw", - top_p = 1, - top_k = 40, + config = { + name = "llama2", + provider = "llama2", + options = { + max_tokens = 512, + temperature = 0.5, + llama2_format = "raw", + top_p = 1, + top_k = 40, + }, }, }, ["llm/v1/completions"] = { - name = "llama2", - provider = "llama2", - options = { - max_tokens = 512, - temperature = 0.5, - llama2_format = "raw", + config = { + name = "llama2", + provider = "llama2", + options = { + max_tokens = 512, + temperature = 0.5, + llama2_format = "raw", + }, }, }, }, llama2_ollama = { ["llm/v1/chat"] = { - name = "llama2", - provider = "llama2", - options = { - max_tokens = 512, - temperature = 0.5, - llama2_format = "ollama", + config = { + name = "llama2", + provider = "llama2", + options = { + max_tokens = 512, + temperature = 0.5, + llama2_format = "ollama", + }, }, }, ["llm/v1/completions"] = { - name = "llama2", - provider = "llama2", - options = { 
- max_tokens = 512, - temperature = 0.5, - llama2_format = "ollama", + config = { + name = "llama2", + provider = "llama2", + options = { + max_tokens = 512, + temperature = 0.5, + llama2_format = "ollama", + }, }, }, }, mistral_openai = { ["llm/v1/chat"] = { - name = "mistral-tiny", - provider = "mistral", - options = { - max_tokens = 512, - temperature = 0.5, - mistral_format = "openai", + config = { + name = "mistral-tiny", + provider = "mistral", + options = { + max_tokens = 512, + temperature = 0.5, + mistral_format = "openai", + }, }, }, }, mistral_ollama = { ["llm/v1/chat"] = { - name = "mistral-tiny", - provider = "mistral", - options = { - max_tokens = 512, - temperature = 0.5, - mistral_format = "ollama", + config = { + name = "mistral-tiny", + provider = "mistral", + options = { + max_tokens = 512, + temperature = 0.5, + mistral_format = "ollama", + }, }, }, }, @@ -422,7 +469,7 @@ describe(PLUGIN_NAME .. ": (unit)", function() describe(k .. " format test", function() local actual_request_table - local driver = require("kong.llm.drivers." .. l.provider) + local driver = require("kong.llm.drivers." .. l.config.provider) -- what we do is first put the SAME request message from the user, through the converter, for this provider/format @@ -437,20 +484,20 @@ describe(PLUGIN_NAME .. ": (unit)", function() -- send it local content_type, err - actual_request_table, content_type, err = driver.to_format(request_table, l, k) + actual_request_table, content_type, err = driver.to_format(request_table, l.config, k) assert.is_nil(err) assert.not_nil(content_type) -- load the expected outbound request to this provider local filename - if l.provider == "llama2" then - filename = fmt("spec/fixtures/ai-proxy/unit/expected-requests/%s/%s/%s.json", l.provider, l.options.llama2_format, pl_replace(k, "/", "-")) + if l.config.provider == "llama2" then + filename = fmt("spec/fixtures/ai-proxy/unit/expected-requests/%s/%s/%s.json", l.config.provider, l.config.options.llama2_format, pl_replace(k, "/", "-")) - elseif l.provider == "mistral" then - filename = fmt("spec/fixtures/ai-proxy/unit/expected-requests/%s/%s/%s.json", l.provider, l.options.mistral_format, pl_replace(k, "/", "-")) + elseif l.config.provider == "mistral" then + filename = fmt("spec/fixtures/ai-proxy/unit/expected-requests/%s/%s/%s.json", l.config.provider, l.config.options.mistral_format, pl_replace(k, "/", "-")) else - filename = fmt("spec/fixtures/ai-proxy/unit/expected-requests/%s/%s.json", l.provider, pl_replace(k, "/", "-")) + filename = fmt("spec/fixtures/ai-proxy/unit/expected-requests/%s/%s.json", l.config.provider, pl_replace(k, "/", "-")) end @@ -470,20 +517,20 @@ describe(PLUGIN_NAME .. 
": (unit)", function() -- load what the endpoint would really response with local filename - if l.provider == "llama2" then - filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s/%s.json", l.provider, l.options.llama2_format, pl_replace(k, "/", "-")) + if l.config.provider == "llama2" then + filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s/%s.json", l.config.provider, l.config.options.llama2_format, pl_replace(k, "/", "-")) - elseif l.provider == "mistral" then - filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s/%s.json", l.provider, l.options.mistral_format, pl_replace(k, "/", "-")) + elseif l.config.provider == "mistral" then + filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s/%s.json", l.config.provider, l.config.options.mistral_format, pl_replace(k, "/", "-")) else - filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s.json", l.provider, pl_replace(k, "/", "-")) + filename = fmt("spec/fixtures/ai-proxy/unit/real-responses/%s/%s.json", l.config.provider, pl_replace(k, "/", "-")) end local virtual_response_json = pl_file.read(filename) -- convert to kong format (emulate on response phase hook) - local actual_response_json, err = driver.from_format(virtual_response_json, l, k) + local actual_response_json, err = driver.from_format(virtual_response_json, l.config, k) assert.is_nil(err) local actual_response_table, err = cjson.decode(actual_response_json) @@ -491,14 +538,14 @@ describe(PLUGIN_NAME .. ": (unit)", function() -- load the expected response body local filename - if l.provider == "llama2" then - filename = fmt("spec/fixtures/ai-proxy/unit/expected-responses/%s/%s/%s.json", l.provider, l.options.llama2_format, pl_replace(k, "/", "-")) + if l.config.provider == "llama2" then + filename = fmt("spec/fixtures/ai-proxy/unit/expected-responses/%s/%s/%s.json", l.config.provider, l.config.options.llama2_format, pl_replace(k, "/", "-")) - elseif l.provider == "mistral" then - filename = fmt("spec/fixtures/ai-proxy/unit/expected-responses/%s/%s/%s.json", l.provider, l.options.mistral_format, pl_replace(k, "/", "-")) + elseif l.config.provider == "mistral" then + filename = fmt("spec/fixtures/ai-proxy/unit/expected-responses/%s/%s/%s.json", l.config.provider, l.config.options.mistral_format, pl_replace(k, "/", "-")) else - filename = fmt("spec/fixtures/ai-proxy/unit/expected-responses/%s/%s.json", l.provider, pl_replace(k, "/", "-")) + filename = fmt("spec/fixtures/ai-proxy/unit/expected-responses/%s/%s.json", l.config.provider, pl_replace(k, "/", "-")) end local expected_response_json = pl_file.read(filename) @@ -509,8 +556,6 @@ describe(PLUGIN_NAME .. ": (unit)", function() assert.same(expected_response_table.choices[1].message, actual_response_table.choices[1].message) assert.same(actual_response_table.model, expected_response_table.model) end) - - end) end end) @@ -578,4 +623,28 @@ describe(PLUGIN_NAME .. 
": (unit)", function() assert.equal(err, "no transformer available to format mistral://llm/v1/chatnopenotsupported/ollama") end) + + it("produces a correct default config merge", function() + local formatted, err = ai_shared.merge_config_defaults( + SAMPLE_LLM_V1_CHAT_WITH_SOME_OPTS, + { + max_tokens = 1024, + top_p = 1.0, + }, + "llm/v1/chat" + ) + + formatted.messages = nil -- not needed for config merge + + assert.is_nil(err) + assert.same({ + max_tokens = 256, + temperature = 0.1, + top_p = 0.2, + some_extra_param = "string_val", + another_extra_param = 0.5, + }, formatted) + end) + + end) diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-chat.json b/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-chat.json index b22c3e520a4..46ae65625c8 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-chat.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-chat.json @@ -10,7 +10,6 @@ "model": "command", "max_tokens": 512, "temperature": 0.5, - "truncate": "END", - "return_likelihoods": "NONE", + "p": 1.0, "stream": false } diff --git a/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-completions.json b/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-completions.json index 9ba71345bad..400114b7e01 100644 --- a/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-completions.json +++ b/spec/fixtures/ai-proxy/unit/expected-requests/cohere/llm-v1-completions.json @@ -5,7 +5,5 @@ "temperature": 0.5, "p": 0.75, "k": 5, - "return_likelihoods": "NONE", - "truncate": "END", "stream": false } From ee642d140cb9d0888dea610e696e97f2ea85d0ac Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Wed, 24 Apr 2024 07:39:11 +0100 Subject: [PATCH 08/24] fix(ai-proxy): function docs --- kong/llm/drivers/shared.lua | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index b8ac9b7ff27..7afccc0a093 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -179,6 +179,13 @@ local function handle_stream_event(event_table, model_info, route_type) end end +--- +-- Splits up a string by delimiter, whilst preserving +-- empty lines, double line-breaks. +-- +-- @param {string} str input string to split +-- @param {string} delimiter delimeter (can be complex string) to split by +-- @return {table} n number of split results, or empty table local function complex_split(str, delimiter) local result = {} local from = 1 @@ -192,6 +199,20 @@ local function complex_split(str, delimiter) return result end +--- +-- Splits a HTTPS data chunk or frame into individual +-- SSE-format messages, see: +-- https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format +-- +-- For compatibility, it also looks for the first character being '{' which +-- indicates that the input is not text/event-stream format, but instead a chunk +-- of delimited application/json, which some providers return, in which case +-- it simply splits the frame into separate JSON messages and appends 'data: ' +-- as if it were an SSE message. 
+-- +-- @param {string} frame input string to format into SSE events +-- @param {string} delimiter delimeter (can be complex string) to split by +-- @return {table} n number of split SSE messages, or empty table function _M.frame_to_events(frame) local events = {} From fc147680e8be6737b35dec97f07b161b7d5e4f8d Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Wed, 24 Apr 2024 07:47:45 +0100 Subject: [PATCH 09/24] fix(lint): ai-proxy lint --- kong/llm/drivers/openai.lua | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index 7535a3252f4..c7e8b4a6016 100644 --- a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -12,18 +12,6 @@ local string_gsub = string.gsub local DRIVER_NAME = "openai" -- --- merge_defaults takes the model options, and sets any defaults defined, --- if the caller hasn't explicitly set them --- --- we have already checked that "max_tokens" isn't overridden when it --- is not allowed to do so. -local _MERGE_PROPERTIES = { - [1] = "max_tokens", - [2] = "temperature", - [3] = "top_p", - [4] = "top_k", -} - local function handle_stream_event(event_t) return event_t.data end From a26dfb93c54e8c046aeeb7cd979421f591e9c462 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Wed, 24 Apr 2024 16:23:12 +0100 Subject: [PATCH 10/24] feat(ai-proxy): add 3.7 compatibility checkers --- kong/clustering/compat/checkers.lua | 55 +++++++++ kong/llm/init.lua | 2 +- .../09-hybrid_mode/09-config-compat_spec.lua | 114 ++++++++++++++++++ 3 files changed, 170 insertions(+), 1 deletion(-) diff --git a/kong/clustering/compat/checkers.lua b/kong/clustering/compat/checkers.lua index 2a900f0728e..6c361dc853e 100644 --- a/kong/clustering/compat/checkers.lua +++ b/kong/clustering/compat/checkers.lua @@ -23,6 +23,61 @@ end local compatible_checkers = { + { 3007000000, --[[ 3.7.0.0 ]] + function(config_table, dp_version, log_suffix) + local has_update + + for _, plugin in ipairs(config_table.plugins or {}) do + if plugin.name == 'ai-proxy' then + local config = plugin.config + if config.model and config.model.options then + if config.model.options.response_streaming then + config.model.options.response_streaming = nil + log_warn_message('configures ' .. plugin.name .. ' plugin with' .. + ' response_streaming == nil, because it is not supported' .. + ' in this release', + dp_version, log_suffix) + has_update = true + end + + if config.model.options.upstream_path then + config.model.options.upstream_path = nil + log_warn_message('configures ' .. plugin.name .. ' plugin with' .. + ' upstream_path == nil, because it is not supported' .. + ' in this release', + dp_version, log_suffix) + has_update = true + end + end + + if config.route_type == "preserve" then + config.route_type = "llm/v1/chat" + log_warn_message('configures ' .. plugin.name .. ' plugin with' .. + ' route_type == "llm/v1/chat", because preserve' .. + ' mode is not supported in this release', + dp_version, log_suffix) + has_update = true + end + end + + if plugin.name == 'ai-request-transformer' or plugin.name == 'ai-response-transformer' then + local config = plugin.config + if config.llm.model + and config.llm.model.options + and config.llm.model.options.upstream_path then + config.llm.model.options.upstream_path = nil + log_warn_message('configures ' .. plugin.name .. ' plugin with' .. + ' upstream_path == nil, because it is not supported' .. 
+ ' in this release', + dp_version, log_suffix) + has_update = true + end + end + end + + return has_update + end, + }, { 3006000000, --[[ 3.6.0.0 ]] function(config_table, dp_version, log_suffix) local has_update diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 6d08274a17e..8559b0306f6 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -49,7 +49,7 @@ local model_options_schema = { fields = { { response_streaming = { type = "string", - description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via WebSocket.", + description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server events.", required = true, default = "allow", one_of = { "allow", "deny", "always" } }}, diff --git a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua index b64af4cf5c6..96f41bfd03d 100644 --- a/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua +++ b/spec/02-integration/09-hybrid_mode/09-config-compat_spec.lua @@ -472,6 +472,120 @@ describe("CP/DP config compat transformations #" .. strategy, function() end) end) end) + + describe("ai plugins", function() + it("[ai-proxy] sets unsupported AI LLM properties to nil or defaults", function() + -- [[ 3.7.x ]] -- + local ai_proxy = admin.plugins:insert { + name = "ai-proxy", + enabled = true, + config = { + route_type = "preserve", -- becomes 'llm/v1/chat' + auth = { + header_name = "header", + header_value = "value", + }, + model = { + name = "any-model-name", + provider = "openai", + options = { + max_tokens = 512, + temperature = 0.5, + response_streaming = "allow", -- becomes nil + upstream_path = "/anywhere", -- becomes nil + }, + }, + }, + } + -- ]] + + local expected_ai_proxy_prior_37 = utils.cycle_aware_deep_copy(ai_proxy) + expected_ai_proxy_prior_37.config.model.options.response_streaming = nil + expected_ai_proxy_prior_37.config.model.options.upstream_path = nil + expected_ai_proxy_prior_37.config.route_type = "llm/v1/chat" + + do_assert(utils.uuid(), "3.6.0", expected_ai_proxy_prior_37) + + -- cleanup + admin.plugins:remove({ id = ai_proxy.id }) + end) + + it("[ai-request-transformer] sets unsupported AI LLM properties to nil or defaults", function() + -- [[ 3.7.x ]] -- + local ai_request_transformer = admin.plugins:insert { + name = "ai-request-transformer", + enabled = true, + config = { + prompt = "Convert my message to XML.", + llm = { + route_type = "llm/v1/chat", + auth = { + header_name = "header", + header_value = "value", + }, + model = { + name = "any-model-name", + provider = "azure", + options = { + azure_instance = "azure-1", + azure_deployment_id = "azdep-1", + azure_api_version = "2023-01-01", + max_tokens = 512, + temperature = 0.5, + upstream_path = "/anywhere", -- becomes nil + }, + }, + }, + }, + } + -- ]] + + local expected_ai_request_transformer_prior_37 = utils.cycle_aware_deep_copy(ai_request_transformer) + expected_ai_request_transformer_prior_37.config.llm.model.options.upstream_path = nil + + do_assert(utils.uuid(), "3.6.0", expected_ai_request_transformer_prior_37) + + -- cleanup + admin.plugins:remove({ id = ai_request_transformer.id }) + end) + + it("[ai-response-transformer] sets unsupported AI LLM properties to nil or defaults", function() + -- [[ 3.7.x ]] -- + local ai_response_transformer = admin.plugins:insert { + name = "ai-response-transformer", + enabled = true, + config = { + prompt = "Convert my message to XML.", + llm = { + route_type = 
"llm/v1/chat", + auth = { + header_name = "header", + header_value = "value", + }, + model = { + name = "any-model-name", + provider = "cohere", + options = { + azure_api_version = "2023-01-01", + max_tokens = 512, + temperature = 0.5, + upstream_path = "/anywhere", -- becomes nil + }, + }, + }, + }, + } + -- ]] + + local expected_ai_response_transformer_prior_37 = utils.cycle_aware_deep_copy(ai_response_transformer) + expected_ai_response_transformer_prior_37.config.llm.model.options.upstream_path = nil + + do_assert(utils.uuid(), "3.6.0", expected_ai_response_transformer_prior_37) + + -- cleanup + admin.plugins:remove({ id = ai_response_transformer.id }) + end) + end) end) end) From 8c0fba5efd15d0fd8bec8513e407097e304a1a35 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Wed, 24 Apr 2024 23:20:31 +0100 Subject: [PATCH 11/24] fix(ai-proxy): no longer require streaming conf option --- kong/llm/init.lua | 2 +- kong/plugins/ai-proxy/handler.lua | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/kong/llm/init.lua b/kong/llm/init.lua index 8559b0306f6..ef8310c835b 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -50,7 +50,7 @@ local model_options_schema = { { response_streaming = { type = "string", description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server events.", - required = true, + required = false, default = "allow", one_of = { "allow", "deny", "always" } }}, { max_tokens = { diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 9bd9928af04..5dda883a1ed 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -301,6 +301,7 @@ end function _M:access(conf) -- store the route_type in ctx for use in response parsing local route_type = conf.route_type + kong.ctx.plugin.operation = route_type local request_table @@ -317,8 +318,6 @@ function _M:access(conf) request_table = kong.request.get_body(content_type) - -- TODO octet stream check here - if not request_table then if not string.find(content_type, "multipart/form-data", nil, true) then return bad_request("content-type header does not match request body") From 28dd7c4400071ce1bb7c3e10ad46e3a60bad2c36 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Thu, 25 Apr 2024 11:56:46 +0100 Subject: [PATCH 12/24] fix(ai-proxy): latest fix rollup --- kong/llm/drivers/shared.lua | 35 ++----------------------------- kong/plugins/ai-proxy/handler.lua | 2 +- kong/tools/http.lua | 24 --------------------- 3 files changed, 3 insertions(+), 58 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 7afccc0a093..d74f163a6ea 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -47,6 +47,7 @@ local openai_override = os.getenv("OPENAI_TEST_PORT") _M.streaming_has_token_counts = { ["cohere"] = true, ["llama2"] = true, + ["anthropic"] = true, } _M.upstream_url_format = { @@ -179,26 +180,6 @@ local function handle_stream_event(event_table, model_info, route_type) end end ---- --- Splits up a string by delimiter, whilst preserving --- empty lines, double line-breaks. 
--- --- @param {string} str input string to split --- @param {string} delimiter delimeter (can be complex string) to split by --- @return {table} n number of split results, or empty table -local function complex_split(str, delimiter) - local result = {} - local from = 1 - local delim_from, delim_to = string.find(str, delimiter, from) - while delim_from do - table.insert( result, string.sub(str, from , delim_from-1)) - from = delim_to + 1 - delim_from, delim_to = string.find(str, delimiter, from) - end - table.insert( result, string.sub(str, from)) - return result -end - --- -- Splits a HTTPS data chunk or frame into individual -- SSE-format messages, see: @@ -225,7 +206,7 @@ function _M.frame_to_events(frame) } end else - local event_lines = complex_split(frame, "\n") + local event_lines = split(frame, "\n") local struct = { event = nil, id = nil, data = nil } for _, dat in ipairs(event_lines) do @@ -363,18 +344,6 @@ function _M.conf_from_request(kong_request, source, key) end end -function _M.conf_from_request(kong_request, source, key) - if source == "uri_captures" then - return kong_request.get_uri_captures().named[key] - elseif source == "headers" then - return kong_request.get_header(key) - elseif source == "query_params" then - return kong_request.get_query_arg(key) - else - return nil, "source '" .. source .. "' is not supported" - end -end - function _M.resolve_plugin_conf(kong_request, conf) local err local conf_m = utils.cycle_aware_deep_copy(conf) diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 5dda883a1ed..aace72068bd 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -149,7 +149,7 @@ local function handle_streaming_frame(conf) if finished then local fake_response_t = { - response = kong.ctx.plugin.ai_stream_log_buffer:get(), + response = kong.ctx.plugin.ai_stream_log_buffer and kong.ctx.plugin.ai_stream_log_buffer:get(), usage = { prompt_tokens = kong.ctx.plugin.ai_stream_prompt_tokens or 0, completion_tokens = kong.ctx.plugin.ai_stream_completion_tokens or 0, diff --git a/kong/tools/http.lua b/kong/tools/http.lua index de4b3fe8a30..34ca72ccdc2 100644 --- a/kong/tools/http.lua +++ b/kong/tools/http.lua @@ -553,28 +553,4 @@ do end end -do - local string_sub = string.sub - - --- - -- Ensures that a given path adheres to a valid format - -- for usage with PDK set_path, or a lua-resty-http client. - -- - -- The function returns the re-formatted path, in its valid form, - -- or returns the original string if nothing was changed. - -- - -- @param path string the path to ensure is valid - -- @return string the newly-formatted valid path, or the original path if nothing changed - function _M.ensure_valid_path(path) - if string_sub(path, 1, 1) ~= "/" then - path = "/" .. 
path - - elseif string_sub(path, 1, 2) == "//" then - path = string_sub(path, 2) - end - - return path - end -end - return _M From a1d432a134a697db37ae61e6433bcf7115b11eaa Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Thu, 25 Apr 2024 12:56:48 +0100 Subject: [PATCH 13/24] fix(ai-proxy): latest fix rollup --- kong/llm/init.lua | 2 +- kong/plugins/ai-proxy/handler.lua | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kong/llm/init.lua b/kong/llm/init.lua index ef8310c835b..6ac1a1ff0b9 100644 --- a/kong/llm/init.lua +++ b/kong/llm/init.lua @@ -49,7 +49,7 @@ local model_options_schema = { fields = { { response_streaming = { type = "string", - description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server events.", + description = "Whether to 'optionally allow', 'deny', or 'always' (force) the streaming of answers via server sent events.", required = false, default = "allow", one_of = { "allow", "deny", "always" } }}, diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index aace72068bd..2229c89c7b0 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -229,7 +229,7 @@ end function _M:body_filter(conf) -- if body_filter is called twice, then return - if kong.ctx.plugin.body_called and (not kong.ctx.shared.ai_proxy_streaming_mode) then + if kong.ctx.plugin.body_called and not kong.ctx.shared.ai_proxy_streaming_mode then return end From 35de2c753b4257411f2050c7e2b63c3f424101c8 Mon Sep 17 00:00:00 2001 From: Antoine Jacquemin Date: Thu, 25 Apr 2024 16:12:24 +0200 Subject: [PATCH 14/24] Add calculate cost logic --- kong/llm/drivers/shared.lua | 123 ++++++++++++++++++++---------------- 1 file changed, 70 insertions(+), 53 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index d74f163a6ea..a40a9031272 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -415,6 +415,10 @@ function _M.pre_request(conf, request_table) kong.log.set_serialize_value(log_entry_keys.REQUEST_BODY, kong.request.get_raw_body()) end + -- log tokens prompt for reports and billing + local prompt_tokens, err = calculate_cost(request_table, {}, 1.0) or 0 + kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens + return true, nil end @@ -436,70 +440,74 @@ function _M.post_request(conf, response_object) end -- analytics and logging - if conf.logging and conf.logging.log_statistics then - local provider_name = conf.model.provider + local provider_name = conf.model.provider - local plugin_name = conf.__key__:match('plugins:(.-):') - if not plugin_name or plugin_name == "" then - return nil, "no plugin name is being passed by the plugin" - end + local plugin_name = conf.__key__:match('plugins:(.-):') + if not plugin_name or plugin_name == "" then + return nil, "no plugin name is being passed by the plugin" + end - -- check if we already have analytics in this context - local request_analytics = kong.ctx.shared.analytics + -- check if we already have analytics in this context + local request_analytics = kong.ctx.shared.analytics - -- create a new structure if not - if not request_analytics then - request_analytics = {} - end + -- create a new structure if not + if not request_analytics then + request_analytics = {} + end - -- check if we already have analytics for this provider - local request_analytics_plugin = request_analytics[plugin_name] - - -- create a new structure if not - if not request_analytics_plugin 
then - request_analytics_plugin = { - [log_entry_keys.META_CONTAINER] = {}, - [log_entry_keys.PAYLOAD_CONTAINER] = {}, - [log_entry_keys.TOKENS_CONTAINER] = { - [log_entry_keys.PROMPT_TOKEN] = 0, - [log_entry_keys.COMPLETION_TOKEN] = 0, - [log_entry_keys.TOTAL_TOKENS] = 0, - }, - } - end + -- check if we already have analytics for this provider + local request_analytics_plugin = request_analytics[plugin_name] + + -- create a new structure if not + if not request_analytics_plugin then + request_analytics_plugin = { + [log_entry_keys.META_CONTAINER] = {}, + [log_entry_keys.PAYLOAD_CONTAINER] = {}, + [log_entry_keys.TOKENS_CONTAINER] = { + [log_entry_keys.PROMPT_TOKEN] = 0, + [log_entry_keys.COMPLETION_TOKEN] = 0, + [log_entry_keys.TOTAL_TOKENS] = 0, + }, + } + end - -- Set the model, response, and provider names in the current try context - request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name - request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name - request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name - request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id + -- Set the model, response, and provider names in the current try context + request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.REQUEST_MODEL] = conf.model.name + request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.RESPONSE_MODEL] = response_object.model or conf.model.name + request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PROVIDER_NAME] = provider_name + request_analytics_plugin[log_entry_keys.META_CONTAINER][log_entry_keys.PLUGIN_ID] = conf.__plugin_id - -- Capture openai-format usage stats from the transformed response body - if response_object.usage then - if response_object.usage.prompt_tokens then - request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.PROMPT_TOKEN] = request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.PROMPT_TOKEN] + response_object.usage.prompt_tokens - end - if response_object.usage.completion_tokens then - request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.COMPLETION_TOKEN] = request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.COMPLETION_TOKEN] + response_object.usage.completion_tokens - end - if response_object.usage.total_tokens then - request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.TOTAL_TOKENS] = request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.TOTAL_TOKENS] + response_object.usage.total_tokens - end + -- Capture openai-format usage stats from the transformed response body + if response_object.usage then + if response_object.usage.prompt_tokens then + request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.PROMPT_TOKEN] = request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.PROMPT_TOKEN] + response_object.usage.prompt_tokens end - - -- Log response body if logging payloads is enabled - if conf.logging and conf.logging.log_payloads then - request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER][log_entry_keys.RESPONSE_BODY] = body_string + if response_object.usage.completion_tokens then + request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.COMPLETION_TOKEN] = 
request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.COMPLETION_TOKEN] + response_object.usage.completion_tokens + end + if response_object.usage.total_tokens then + request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.TOTAL_TOKENS] = request_analytics_plugin[log_entry_keys.TOKENS_CONTAINER][log_entry_keys.TOTAL_TOKENS] + response_object.usage.total_tokens end + end - -- Update context with changed values - request_analytics[plugin_name] = request_analytics_plugin - kong.ctx.shared.analytics = request_analytics + -- Log response body if logging payloads is enabled + if conf.logging and conf.logging.log_payloads then + request_analytics_plugin[log_entry_keys.PAYLOAD_CONTAINER][log_entry_keys.RESPONSE_BODY] = body_string + end + + -- Update context with changed values + request_analytics[plugin_name] = request_analytics_plugin + kong.ctx.shared.analytics = request_analytics + if conf.logging and conf.logging.log_statistics then -- Log analytics data kong.log.set_serialize_value(fmt("%s.%s", "ai", plugin_name), request_analytics_plugin) end + -- log tokens prompt for reports and billing + local response_tokens, err = calculate_cost(response_object, {}, 1.0) or 0 + kong.ctx.shared.ai_response_tokens = (kong.ctx.shared.ai_response_tokens or 0) + response_tokens + return nil end @@ -597,11 +605,20 @@ local function count_prompt(content, tokens_factor) return count end -function _M.calculate_cost(query_body, tokens_models, tokens_factor) +local function calculate_cost(query_body, tokens_models, tokens_factor) local query_cost = 0 local err - if query_body.messages then + if query_body.choices then + -- Calculate the cost based on the content type + for _, choice in ipairs(query_body.choices) do + if choice.message.content then + query_cost = query_cost + (count_words(choice.message.content) * tokens_factor) + elseif choice.text then + query_cost = query_cost + (count_words(choice.text) * tokens_factor) + end + end + elseif query_body.messages then -- Calculate the cost based on the content type for _, message in ipairs(query_body.messages) do query_cost = query_cost + (count_words(message.content) * tokens_factor) @@ -618,7 +635,7 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor) -- Round the total cost quantified query_cost = math.floor(query_cost + 0.5) - + return query_cost end From bb845c2c84cd6253bbf32113196876e9383073f0 Mon Sep 17 00:00:00 2001 From: Antoine Jacquemin Date: Thu, 25 Apr 2024 16:19:54 +0200 Subject: [PATCH 15/24] Fix calculate cost function --- kong/llm/drivers/shared.lua | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index a40a9031272..048b0a06078 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -416,7 +416,7 @@ function _M.pre_request(conf, request_table) end -- log tokens prompt for reports and billing - local prompt_tokens, err = calculate_cost(request_table, {}, 1.0) or 0 + local prompt_tokens, err = _M.calculate_cost(request_table, {}, 1.0) or 0 kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens return true, nil @@ -505,7 +505,7 @@ function _M.post_request(conf, response_object) end -- log tokens prompt for reports and billing - local response_tokens, err = calculate_cost(response_object, {}, 1.0) or 0 + local response_tokens, err = _M.calculate_cost(response_object, {}, 1.0) or 0 kong.ctx.shared.ai_response_tokens = (kong.ctx.shared.ai_response_tokens or 0) 
+ response_tokens return nil @@ -605,7 +605,7 @@ local function count_prompt(content, tokens_factor) return count end -local function calculate_cost(query_body, tokens_models, tokens_factor) +function _M.calculate_cost(query_body, tokens_models, tokens_factor) local query_cost = 0 local err @@ -635,7 +635,7 @@ local function calculate_cost(query_body, tokens_models, tokens_factor) -- Round the total cost quantified query_cost = math.floor(query_cost + 0.5) - + return query_cost end From 47235e0a057a5c1a1e72838e35b81c83e499d533 Mon Sep 17 00:00:00 2001 From: Antoine Jacquemin Date: Thu, 25 Apr 2024 16:26:02 +0200 Subject: [PATCH 16/24] Fix error --- kong/llm/drivers/shared.lua | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 048b0a06078..abde9832d46 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -417,6 +417,9 @@ function _M.pre_request(conf, request_table) -- log tokens prompt for reports and billing local prompt_tokens, err = _M.calculate_cost(request_table, {}, 1.0) or 0 + if err then + kong.log.warn("failed calculating cost for prompt tokens: ", err) + end kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens return true, nil @@ -504,8 +507,11 @@ function _M.post_request(conf, response_object) kong.log.set_serialize_value(fmt("%s.%s", "ai", plugin_name), request_analytics_plugin) end - -- log tokens prompt for reports and billing + -- log tokens response for reports and billing local response_tokens, err = _M.calculate_cost(response_object, {}, 1.0) or 0 + if err then + kong.log.warn("failed calculating cost for response tokens: ", err) + end kong.ctx.shared.ai_response_tokens = (kong.ctx.shared.ai_response_tokens or 0) + response_tokens return nil From be9c9b8bc15dd6f681a82ead07fe2d1b4af522f6 Mon Sep 17 00:00:00 2001 From: Antoine Jacquemin Date: Thu, 25 Apr 2024 16:36:50 +0200 Subject: [PATCH 17/24] fix(ai-proxy): fix lint --- kong/llm/drivers/shared.lua | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index abde9832d46..c1c9775e0ca 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -608,7 +608,7 @@ local function count_prompt(content, tokens_factor) else return nil, "Invalid request format" end - return count + return count, nil end function _M.calculate_cost(query_body, tokens_models, tokens_factor) @@ -642,7 +642,7 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor) -- Round the total cost quantified query_cost = math.floor(query_cost + 0.5) - return query_cost + return query_cost, nil end return _M From 33bc5a731922d5e890bff1e07f3bbc8799b12c41 Mon Sep 17 00:00:00 2001 From: Antoine Jacquemin Date: Thu, 25 Apr 2024 16:48:12 +0200 Subject: [PATCH 18/24] fix(ai-proxy): fix lint --- kong/llm/drivers/shared.lua | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index c1c9775e0ca..42da0dcdc84 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -416,7 +416,8 @@ function _M.pre_request(conf, request_table) end -- log tokens prompt for reports and billing - local prompt_tokens, err = _M.calculate_cost(request_table, {}, 1.0) or 0 + local prompt_tokens, err + prompt_tokens, err = _M.calculate_cost(request_table, {}, 1.0) or 0 if err then kong.log.warn("failed calculating cost for prompt tokens: ", err) end @@ 
-508,7 +509,8 @@ function _M.post_request(conf, response_object) end -- log tokens response for reports and billing - local response_tokens, err = _M.calculate_cost(response_object, {}, 1.0) or 0 + local response_tokens, err + response_tokens, err = _M.calculate_cost(response_object, {}, 1.0) or 0 if err then kong.log.warn("failed calculating cost for response tokens: ", err) end From 2aef21982b11504799672dfa47d9776e6ddb966b Mon Sep 17 00:00:00 2001 From: Antoine Jacquemin Date: Thu, 25 Apr 2024 16:55:42 +0200 Subject: [PATCH 19/24] fix(ai-proxy): fix lint --- kong/llm/drivers/shared.lua | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 42da0dcdc84..d4e3d6bee76 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -416,8 +416,7 @@ function _M.pre_request(conf, request_table) end -- log tokens prompt for reports and billing - local prompt_tokens, err - prompt_tokens, err = _M.calculate_cost(request_table, {}, 1.0) or 0 + local prompt_tokens, err = _M.calculate_cost(request_table, {}, 1.0) if err then kong.log.warn("failed calculating cost for prompt tokens: ", err) end @@ -509,8 +508,7 @@ function _M.post_request(conf, response_object) end -- log tokens response for reports and billing - local response_tokens, err - response_tokens, err = _M.calculate_cost(response_object, {}, 1.0) or 0 + local response_tokens, err = _M.calculate_cost(response_object, {}, 1.0) if err then kong.log.warn("failed calculating cost for response tokens: ", err) end From 45a062a91e7db8aef037bb44ce5741de6a9c584c Mon Sep 17 00:00:00 2001 From: Antoine Jacquemin Date: Thu, 25 Apr 2024 17:30:25 +0200 Subject: [PATCH 20/24] fix(ai-proxy): fix bad request case --- kong/llm/drivers/shared.lua | 2 ++ 1 file changed, 2 insertions(+) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index d4e3d6bee76..19a86492710 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -419,6 +419,7 @@ function _M.pre_request(conf, request_table) local prompt_tokens, err = _M.calculate_cost(request_table, {}, 1.0) if err then kong.log.warn("failed calculating cost for prompt tokens: ", err) + prompt_tokens = 0 end kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens @@ -511,6 +512,7 @@ function _M.post_request(conf, response_object) local response_tokens, err = _M.calculate_cost(response_object, {}, 1.0) if err then kong.log.warn("failed calculating cost for response tokens: ", err) + response_tokens = 0 end kong.ctx.shared.ai_response_tokens = (kong.ctx.shared.ai_response_tokens or 0) + response_tokens From ea61c0dfc01c98efff9b0b89dc9a8e6234fe7dc3 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Thu, 25 Apr 2024 16:46:53 +0100 Subject: [PATCH 21/24] fix(ai-proxy): cohere analytics --- kong/llm/drivers/cohere.lua | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index f31608a5b8a..17dd9435e2d 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -250,9 +250,20 @@ local transformers_from = { messages.id = response_table.generation_id local stats = { - completion_tokens = response_table.token_count and response_table.token_count.response_tokens or nil, - prompt_tokens = response_table.token_count and response_table.token_count.prompt_tokens or nil, - total_tokens = response_table.token_count and response_table.token_count.total_tokens or nil, 
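The churn across these last few fixes comes down to one Lua detail: in an expression like calculate_cost(...) or 0, the call is truncated to its first return value, so the second return (err) is always nil and the warning branch can never fire. Dropping the or 0 and defaulting to zero only after checking err, as the series settles on above, restores the error path. A small illustration of the pitfall, using throwaway names:

-- Lua truncates a function call to a single value inside a larger expression.
local function failing_estimator()
  return nil, "unparseable body"
end

local tokens, err = failing_estimator() or 0
print(tokens, err)    --> 0    nil   (the error is silently discarded)

local tokens2, err2 = failing_estimator()
print(tokens2, err2)  --> nil  unparseable body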
+ completion_tokens = response_table.meta + and response_table.meta.billed_units + and response_table.meta.billed_units.output_tokens + or nil, + + prompt_tokens = response_table.meta + and response_table.meta.billed_units + and response_table.meta.billed_units.input_tokens + or nil, + + total_tokens = response_table.meta + and response_table.meta.billed_units + and (response_table.meta.billed_units.output_tokens + response_table.meta.billed_units.input_tokens) + or nil, } messages.usage = stats From 35c0f95d1c38169e99f97c23dad0260309877a17 Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Thu, 25 Apr 2024 17:33:06 +0100 Subject: [PATCH 22/24] fix(ai-proxy): missing null checks from rebase; missing conf_m plugin config replacement --- kong/llm/drivers/anthropic.lua | 6 +++++- kong/llm/drivers/azure.lua | 6 ++++-- kong/llm/drivers/cohere.lua | 6 +++++- kong/llm/drivers/openai.lua | 8 +++++--- kong/llm/drivers/shared.lua | 20 ++++++++++++-------- kong/plugins/ai-proxy/handler.lua | 12 ++++++------ 6 files changed, 37 insertions(+), 21 deletions(-) diff --git a/kong/llm/drivers/anthropic.lua b/kong/llm/drivers/anthropic.lua index 8d5a791657f..8eb206b8c1f 100644 --- a/kong/llm/drivers/anthropic.lua +++ b/kong/llm/drivers/anthropic.lua @@ -459,7 +459,11 @@ function _M.configure_request(conf) parsed_url = socket_url.parse(conf.model.options.upstream_url) else parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) - parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + parsed_url.path = conf.model.options + and conf.model.options.upstream_path + or ai_shared.operation_map[DRIVER_NAME][conf.route_type] + and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + or "/" if not parsed_url.path then return nil, fmt("operation %s is not supported for anthropic provider", conf.route_type) diff --git a/kong/llm/drivers/azure.lua b/kong/llm/drivers/azure.lua index f38c6b4029b..390a96256cb 100644 --- a/kong/llm/drivers/azure.lua +++ b/kong/llm/drivers/azure.lua @@ -108,11 +108,13 @@ function _M.configure_request(conf) ai_shared.upstream_url_format[DRIVER_NAME]:format(conf.model.options.azure_instance, conf.model.options.azure_deployment_id), conf.model.options and conf.model.options.upstream_path - or ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + or ai_shared.operation_map[DRIVER_NAME][conf.route_type] + and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + or "/" ) parsed_url = socket_url.parse(url) end - + -- if the path is read from a URL capture, 3re that it is valid parsed_url.path = string_gsub(parsed_url.path, "^/*", "/") diff --git a/kong/llm/drivers/cohere.lua b/kong/llm/drivers/cohere.lua index 17dd9435e2d..79aa0ca5010 100644 --- a/kong/llm/drivers/cohere.lua +++ b/kong/llm/drivers/cohere.lua @@ -471,7 +471,11 @@ function _M.configure_request(conf) parsed_url = socket_url.parse(conf.model.options.upstream_url) else parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) - parsed_url.path = ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + parsed_url.path = conf.model.options + and conf.model.options.upstream_path + or ai_shared.operation_map[DRIVER_NAME][conf.route_type] + and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + or "/" if not parsed_url.path then return false, fmt("operation %s is not supported for cohere provider", conf.route_type) diff --git a/kong/llm/drivers/openai.lua b/kong/llm/drivers/openai.lua index c7e8b4a6016..9f4965ece0d 100644 --- 
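The Cohere analytics fix above reads billed token counts from meta.billed_units rather than the older token_count block, then exposes them in the OpenAI-style usage shape the rest of the plugin expects. A standalone sketch of that mapping, with a hypothetical normalize_cohere_usage name and nil-safe access throughout:

-- Illustrative mapping from a Cohere response body to an OpenAI-style usage
-- block; fields stay nil when the billing metadata is absent.
local function normalize_cohere_usage(response_table)
  local billed = response_table.meta and response_table.meta.billed_units

  local prompt_tokens     = billed and billed.input_tokens or nil
  local completion_tokens = billed and billed.output_tokens or nil

  local total_tokens
  if prompt_tokens and completion_tokens then
    total_tokens = prompt_tokens + completion_tokens
  end

  return {
    prompt_tokens     = prompt_tokens,
    completion_tokens = completion_tokens,
    total_tokens      = total_tokens,
  }
end

Computing total_tokens from the two parts, rather than trusting a provider-supplied total, keeps the three fields consistent with each other, which is also what the hunk above does.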
a/kong/llm/drivers/openai.lua +++ b/kong/llm/drivers/openai.lua @@ -185,17 +185,19 @@ end -- returns err or nil function _M.configure_request(conf) local parsed_url - + if (conf.model.options and conf.model.options.upstream_url) then parsed_url = socket_url.parse(conf.model.options.upstream_url) else local path = conf.model.options and conf.model.options.upstream_path - or ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + or ai_shared.operation_map[DRIVER_NAME][conf.route_type] + and ai_shared.operation_map[DRIVER_NAME][conf.route_type].path + or "/" if not path then return nil, fmt("operation %s is not supported for openai provider", conf.route_type) end - + parsed_url = socket_url.parse(ai_shared.upstream_url_format[DRIVER_NAME]) parsed_url.path = path end diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 19a86492710..117db353ac4 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -377,7 +377,7 @@ function _M.resolve_plugin_conf(kong_request, conf) if #splitted ~= 2 then return nil, "cannot parse expression for field '" .. v .. "'" end - + -- find the request parameter, with the configured name prop_m, err = _M.conf_from_request(kong_request, splitted[1], splitted[2]) if err then @@ -416,12 +416,14 @@ function _M.pre_request(conf, request_table) end -- log tokens prompt for reports and billing - local prompt_tokens, err = _M.calculate_cost(request_table, {}, 1.0) - if err then - kong.log.warn("failed calculating cost for prompt tokens: ", err) - prompt_tokens = 0 + if conf.route_type ~= "preserve" then + local prompt_tokens, err = _M.calculate_cost(request_table, {}, 1.0) + if err then + kong.log.warn("failed calculating cost for prompt tokens: ", err) + prompt_tokens = 0 + end + kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens end - kong.ctx.shared.ai_prompt_tokens = (kong.ctx.shared.ai_prompt_tokens or 0) + prompt_tokens return true, nil end @@ -617,6 +619,10 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor) local query_cost = 0 local err + if not query_body then + return nil, "cannot calculate tokens on empty request" + end + if query_body.choices then -- Calculate the cost based on the content type for _, choice in ipairs(query_body.choices) do @@ -637,8 +643,6 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor) if err then return nil, err end - else - return nil, "No messages or prompt in query" end -- Round the total cost quantified diff --git a/kong/plugins/ai-proxy/handler.lua b/kong/plugins/ai-proxy/handler.lua index 2229c89c7b0..cb2baabc4d7 100644 --- a/kong/plugins/ai-proxy/handler.lua +++ b/kong/plugins/ai-proxy/handler.lua @@ -366,11 +366,11 @@ function _M:access(conf) -- check if the user has asked for a stream, and/or if -- we are forcing all requests to be of streaming type - if request_table.stream or - (conf.model.options and conf.model.options.response_streaming) == "always" then + if request_table and request_table.stream or + (conf_m.model.options and conf_m.model.options.response_streaming) == "always" then -- this condition will only check if user has tried -- to activate streaming mode within their request - if conf.model.options and conf.model.options.response_streaming == "deny" then + if conf_m.model.options and conf_m.model.options.response_streaming == "deny" then return bad_request("response streaming is not enabled for this LLM") end @@ -394,7 +394,7 @@ function _M:access(conf) local ai_driver = 
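The null checks added across the anthropic, azure, cohere and openai drivers all follow the same and/or chain: prefer a configured upstream_path, fall back to the operation map entry for the route type if one exists, and finally to "/". Factored out as a helper (hypothetical name, illustrative only), the precedence reads:

-- Illustrative helper: resolve the upstream path without indexing nil tables.
-- operation_map is assumed to look like { ["llm/v1/chat"] = { path = "..." } }.
local function resolve_upstream_path(conf, operation_map)
  return (conf.model.options and conf.model.options.upstream_path)
      or (operation_map[conf.route_type] and operation_map[conf.route_type].path)
      or "/"
end

The usual and/or caveat about falsy values does not bite here, since every operand in the chain is a string or a table.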
require("kong.llm.drivers." .. conf.model.provider) -- execute pre-request hooks for this driver - local ok, err = ai_driver.pre_request(conf, request_table) + local ok, err = ai_driver.pre_request(conf_m, request_table) if not ok then return bad_request(err) end @@ -411,7 +411,7 @@ function _M:access(conf) end -- execute pre-request hooks for "all" drivers before set new body - local ok, err = ai_shared.pre_request(conf, parsed_request_body) + local ok, err = ai_shared.pre_request(conf_m, parsed_request_body) if not ok then return bad_request(err) end @@ -421,7 +421,7 @@ function _M:access(conf) end -- now re-configure the request for this operation type - local ok, err = ai_driver.configure_request(conf) + local ok, err = ai_driver.configure_request(conf_m) if not ok then kong.ctx.shared.skip_response_transformer = true return internal_server_error(err) From 5d19ded6c88847cbaa3a91aa6dce245e9f44781f Mon Sep 17 00:00:00 2001 From: Jack Tysoe Date: Thu, 25 Apr 2024 17:38:03 +0100 Subject: [PATCH 23/24] fix(ai-proxy): another null check --- kong/llm/drivers/shared.lua | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kong/llm/drivers/shared.lua b/kong/llm/drivers/shared.lua index 117db353ac4..f4696f116c1 100644 --- a/kong/llm/drivers/shared.lua +++ b/kong/llm/drivers/shared.lua @@ -626,7 +626,7 @@ function _M.calculate_cost(query_body, tokens_models, tokens_factor) if query_body.choices then -- Calculate the cost based on the content type for _, choice in ipairs(query_body.choices) do - if choice.message.content then + if choice.message and choice.message.content then query_cost = query_cost + (count_words(choice.message.content) * tokens_factor) elseif choice.text then query_cost = query_cost + (count_words(choice.text) * tokens_factor) From 748f3612705297e55332e86d9ad218db359f69b0 Mon Sep 17 00:00:00 2001 From: Antoine Jacquemin Date: Thu, 25 Apr 2024 18:59:02 +0200 Subject: [PATCH 24/24] fix(ai-proxy) fix missing __key__ --- .../39-ai-request-transformer/01-transformer_spec.lua | 7 +++++++ .../40-ai-response-transformer/01-transformer_spec.lua | 1 + 2 files changed, 8 insertions(+) diff --git a/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua index 9ccd084d031..51d84a43992 100644 --- a/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua +++ b/spec/03-plugins/39-ai-request-transformer/01-transformer_spec.lua @@ -9,6 +9,7 @@ local PLUGIN_NAME = "ai-request-transformer" local FORMATS = { openai = { + __key__ = "ai-request-transformer", route_type = "llm/v1/chat", model = { name = "gpt-4", @@ -25,6 +26,7 @@ local FORMATS = { }, }, cohere = { + __key__ = "ai-request-transformer", route_type = "llm/v1/chat", model = { name = "command", @@ -41,6 +43,7 @@ local FORMATS = { }, }, anthropic = { + __key__ = "ai-request-transformer", route_type = "llm/v1/chat", model = { name = "claude-2.1", @@ -57,6 +60,7 @@ local FORMATS = { }, }, azure = { + __key__ = "ai-request-transformer", route_type = "llm/v1/chat", model = { name = "gpt-4", @@ -73,6 +77,7 @@ local FORMATS = { }, }, llama2 = { + __key__ = "ai-request-transformer", route_type = "llm/v1/chat", model = { name = "llama2", @@ -90,6 +95,7 @@ local FORMATS = { }, }, mistral = { + __key__ = "ai-request-transformer", route_type = "llm/v1/chat", model = { name = "mistral", @@ -110,6 +116,7 @@ local FORMATS = { local OPENAI_NOT_JSON = { route_type = "llm/v1/chat", + __key__ = "ai-request-transformer", model = { name = "gpt-4", provider = 
"openai", diff --git a/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua b/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua index 6409fbcafef..d436ad53644 100644 --- a/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua +++ b/spec/03-plugins/40-ai-response-transformer/01-transformer_spec.lua @@ -8,6 +8,7 @@ local MOCK_PORT = helpers.get_available_port() local PLUGIN_NAME = "ai-response-transformer" local OPENAI_INSTRUCTIONAL_RESPONSE = { + __key__ = "ai-response-transformer", route_type = "llm/v1/chat", model = { name = "gpt-4",