feat(ai-proxy): add streaming support and transformers #12792

Merged
merged 17 commits on Apr 12, 2024
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/feat-ai-proxy-add-streaming.yml
@@ -0,0 +1,4 @@
message: |
  **AI-Proxy**: add support for streaming event-by-event responses back to client on supported providers
scope: Plugin
type: feature
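
In practice, a client opts in by setting the OpenAI-style stream flag in its request body, and the plugin then relays upstream tokens as server-sent events instead of buffering the whole response. A rough sketch of the exchange (payloads illustrative, not taken from this diff):

request body fragment:
  { "messages": [...], "stream": true }

response, one SSE frame per generated token:
  data: {"choices":[{"delta":{"content":"Hel"},"index":0}],"object":"chat.completion.chunk"}
  data: {"choices":[{"delta":{"content":"lo"},"index":0}],"object":"chat.completion.chunk"}
  data: [DONE]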
4 changes: 2 additions & 2 deletions kong/llm/drivers/anthropic.lua
@@ -272,13 +272,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
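
The subrequest change above repeats in every driver in this PR: return_res_table is forwarded to the shared HTTP helper, and the lua-resty-http client comes back as a fourth return value so a streaming caller can keep reading the still-open connection. A minimal caller sketch, assuming lua-resty-http semantics (handle_frame is a hypothetical per-event handler):

local res, status, err, httpc = _M.subrequest(body, conf, http_opts, true)
if err then
  return nil, err
end

-- res.body_reader is lua-resty-http's streaming iterator; each call pulls
-- the next chunk off the socket instead of waiting for the full body
repeat
  local chunk, read_err = res.body_reader()
  if chunk then
    handle_frame(chunk)  -- hypothetical: parse/forward one SSE frame
  end
until not chunk or read_err

httpc:set_keepalive()  -- hand the connection back to the pool when finished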
5 changes: 2 additions & 3 deletions kong/llm/drivers/azure.lua
@@ -59,13 +59,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
@@ -82,7 +82,6 @@ end

-- returns err or nil
function _M.configure_request(conf)

local parsed_url

if conf.model.options.upstream_url then
126 changes: 121 additions & 5 deletions kong/llm/drivers/cohere.lua
@@ -12,6 +12,119 @@ local table_new = require("table.new")
local DRIVER_NAME = "cohere"
--

local function handle_stream_event(event_string, model_info, route_type)
local metadata

-- discard empty frames; they are either stray newlines or comments
if #event_string < 1 then
return
end

local event, err = cjson.decode(event_string)
if err then
return nil, "failed to decode event frame from cohere: " .. err, nil
end

local new_event

if event.event_type == "stream-start" then
kong.ctx.plugin.ai_proxy_cohere_stream_id = event.generation_id

-- ignore the rest of this one
new_event = {
choices = {
[1] = {
delta = {
content = "",
role = "assistant",
},
index = 0,
},
},
id = event.generation_id,
model = model_info.name,
object = "chat.completion.chunk",
}

elseif event.event_type == "text-generation" then
-- this is a token
if route_type == "stream/llm/v1/chat" then
new_event = {
choices = {
[1] = {
delta = {
content = event.text or "",
},
index = 0,
finish_reason = cjson.null,
logprobs = cjson.null,
},
},
id = kong
and kong.ctx
and kong.ctx.plugin
and kong.ctx.plugin.ai_proxy_cohere_stream_id,
model = model_info.name,
object = "chat.completion.chunk",
}

elseif route_type == "stream/llm/v1/completions" then
new_event = {
choices = {
[1] = {
text = event.text or "",
index = 0,
finish_reason = cjson.null,
logprobs = cjson.null,
},
},
id = kong
and kong.ctx
and kong.ctx.plugin
and kong.ctx.plugin.ai_proxy_cohere_stream_id,
model = model_info.name,
object = "text_completion",
}

end

elseif event.event_type == "stream-end" then
-- return a metadata object, with a null event
metadata = {
-- prompt_tokens = event.response.token_count.prompt_tokens,
-- completion_tokens = event.response.token_count.response_tokens,

completion_tokens = event.response
and event.response.meta
and event.response.meta.billed_units
and event.response.meta.billed_units.output_tokens
or
event.response
and event.response.token_count
and event.response.token_count.response_tokens
or 0,

prompt_tokens = event.response
and event.response.meta
and event.response.meta.billed_units
and event.response.meta.billed_units.input_tokens
or
event.response
and event.response.token_count
and event.response.token_count.prompt_tokens
or 0,
}

end

if new_event then
new_event = cjson.encode(new_event)
return new_event, nil, metadata
else
return nil, nil, metadata -- caller code will handle "unrecognised" event types
end
end

local transformers_to = {
["llm/v1/chat"] = function(request_table, model)
request_table.model = model.name
@@ -193,7 +306,7 @@ local transformers_from = {

if response_table.prompt and response_table.generations then
-- this is a "co.generate"

for i, v in ipairs(response_table.generations) do
prompt.choices[i] = {
index = (i-1),
@@ -243,6 +356,9 @@

return cjson.encode(prompt)
end,

["stream/llm/v1/chat"] = handle_stream_event,
["stream/llm/v1/completions"] = handle_stream_event,
}

function _M.from_format(response_string, model_info, route_type)
@@ -253,7 +369,7 @@
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info)
local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type)
if not ok or err then
return nil, fmt("transformation failed from type %s://%s: %s",
model_info.provider,
@@ -262,7 +378,7 @@
)
end

return response_string, nil
return response_string, nil, metadata
end

function _M.to_format(request_table, model_info, route_type)
@@ -344,13 +460,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
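
To make the Cohere mapping concrete, here is roughly what handle_stream_event yields for each frame type ("command" stands in for model_info.name; values illustrative):

-- upstream "text-generation" frame:
--   {"event_type":"text-generation","text":"Hello"}
-- chunk returned for route_type "stream/llm/v1/chat":
--   {"choices":[{"delta":{"content":"Hello"},"index":0,"finish_reason":null,"logprobs":null}],
--    "id":"<generation_id saved at stream-start>","model":"command","object":"chat.completion.chunk"}
--
-- "stream-start" emits an empty assistant delta and stashes event.generation_id
-- in kong.ctx.plugin for the frames that follow; "stream-end" returns
-- (nil, nil, metadata), the metadata carrying prompt_tokens and
-- completion_tokens for the logging/analytics phase.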
14 changes: 8 additions & 6 deletions kong/llm/drivers/llama2.lua
@@ -133,6 +133,8 @@ local transformers_from = {
["llm/v1/completions/raw"] = from_raw,
["llm/v1/chat/ollama"] = ai_shared.from_ollama,
["llm/v1/completions/ollama"] = ai_shared.from_ollama,
["stream/llm/v1/chat/ollama"] = ai_shared.from_ollama,
["stream/llm/v1/completions/ollama"] = ai_shared.from_ollama,
}

local transformers_to = {
@@ -155,8 +157,8 @@ function _M.from_format(response_string, model_info, route_type)
if not transformers_from[transformer_type] then
return nil, fmt("no transformer available from format %s://%s", model_info.provider, transformer_type)
end
local ok, response_string, err = pcall(

local ok, response_string, err, metadata = pcall(
transformers_from[transformer_type],
response_string,
model_info,
@@ -166,7 +168,7 @@
return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, err or "unexpected_error")
end

return response_string, nil
return response_string, nil, metadata
end

function _M.to_format(request_table, model_info, route_type)
@@ -217,13 +219,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
@@ -265,7 +267,7 @@ function _M.configure_request(conf)

kong.service.request.set_path(parsed_url.path)
kong.service.request.set_scheme(parsed_url.scheme)
kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))
kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))

local auth_header_name = conf.auth and conf.auth.header_name
local auth_header_value = conf.auth and conf.auth.header_value
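
The set_target tweak at the bottom guards against upstream URLs with no explicit port, for which parsed_url.port is nil and set_target would previously have been handed a nil port. A quick illustration (hostname made up):

local socket_url = require "socket.url"

local parsed_url = socket_url.parse("https://my-llama2-server/v1/chat")
print(parsed_url.port)                   --> nil, the URL has no explicit port
print(tonumber(parsed_url.port) or 443)  --> 443, the new fallback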
6 changes: 4 additions & 2 deletions kong/llm/drivers/mistral.lua
@@ -17,6 +17,8 @@ local DRIVER_NAME = "mistral"
local transformers_from = {
["llm/v1/chat/ollama"] = ai_shared.from_ollama,
["llm/v1/completions/ollama"] = ai_shared.from_ollama,
["stream/llm/v1/chat/ollama"] = ai_shared.from_ollama,
["stream/llm/v1/completions/ollama"] = ai_shared.from_ollama,
}

local transformers_to = {
@@ -104,13 +106,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
25 changes: 21 additions & 4 deletions kong/llm/drivers/openai.lua
@@ -11,6 +11,18 @@ local socket_url = require "socket.url"
local DRIVER_NAME = "openai"
--

local function handle_stream_event(event_string)
if #event_string > 0 then
local lbl, val = event_string:match("(%w*): (.*)")

if lbl == "data" then
return val
end
end

return nil
end

local transformers_to = {
["llm/v1/chat"] = function(request_table, model, max_tokens, temperature, top_p)
-- if user passed a prompt as a chat, transform it to a chat message
@@ -29,8 +41,9 @@ local transformers_to = {
max_tokens = max_tokens,
temperature = temperature,
top_p = top_p,
stream = request_table.stream or false,
}

return this, "application/json", nil
end,

@@ -40,6 +53,7 @@
model = model,
max_tokens = max_tokens,
temperature = temperature,
stream = request_table.stream or false,
}

return this, "application/json", nil
@@ -52,7 +66,7 @@
if err then
return nil, "'choices' not in llm/v1/chat response"
end

if response_object.choices then
return response_string, nil
else
@@ -72,6 +86,9 @@
return nil, "'choices' not in llm/v1/completions response"
end
end,

["stream/llm/v1/chat"] = handle_stream_event,
["stream/llm/v1/completions"] = handle_stream_event,
}

function _M.from_format(response_string, model_info, route_type)
@@ -155,13 +172,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
headers[conf.auth.header_name] = conf.auth.header_value
end

local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
if err then
return nil, nil, "request to ai service failed: " .. err
end

if return_res_table then
return res, res.status, nil
return res, res.status, nil, httpc
else
-- At this point, the entire request / response is complete and the connection
-- will be closed or back on the connection pool.
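
OpenAI responses are already in the target format, so its new stream handler only strips the SSE framing and passes the JSON payload through untouched. For example (frame contents illustrative):

local frame = 'data: {"choices":[{"delta":{"content":"Hi"},"index":0}]}'
local lbl, val = frame:match("(%w*): (.*)")
-- lbl == "data", so handle_stream_event returns val, the raw JSON payload;
-- an SSE comment such as ": ping" matches with lbl == "", falls through,
-- and the function returns nil so the frame is discarded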