Skip to content

Commit

Permalink
Merge pull request #10 from matlab-deep-learning/dev-fix-streaming-bug
Browse files Browse the repository at this point in the history
Support function calls in streaming
  • Loading branch information
debymf authored Feb 28, 2024
2 parents ab154ef + 0b4b546 commit efa591b
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 7 deletions.
16 changes: 14 additions & 2 deletions +llms/+internal/callOpenAIChatAPI.m
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,20 @@
if isempty(nvp.StreamFun)
message = response.Body.Data.choices(1).message;
else
message = struct("role", "assistant", ...
"content", streamedText);
pat = '{"' + wildcardPattern + '":';
if contains(streamedText,pat)
s = jsondecode(streamedText);
if contains(s.function.arguments,pat)
prompt = jsondecode(s.function.arguments);
s.function.arguments = prompt;
end
message = struct("role", "assistant", ...
"content",[], ...
"tool_calls",jsondecode(streamedText));
else
message = struct("role", "assistant", ...
"content", streamedText);
end
end
if isfield(message, "tool_choice")
text = "";
Expand Down
40 changes: 35 additions & 5 deletions +llms/+stream/responseStreamer.m
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,44 @@
str = erase(str,"data: ");

for i = 1:length(str)
json = jsondecode(str{i});
if strcmp(json.choices.finish_reason,'stop')
if strcmp(str{i},'[DONE]')
stop = true;
return
else
txt = json.choices.delta.content;
this.StreamFun(txt);
this.ResponseText = [this.ResponseText txt];
try
json = jsondecode(str{i});
catch ME
errID = 'llms:stream:responseStreamer:InvalidInput';
msg = "Input does not have the expected json format. " + str{i};
causeException = MException(errID,msg);
ME = addCause(ME,causeException);
rethrow(ME)
end
if ischar(json.choices.finish_reason) && ismember(json.choices.finish_reason,["stop","tool_calls"])
stop = true;
return
else
if isfield(json.choices.delta,"tool_calls")
if isfield(json.choices.delta.tool_calls,"id")
id = json.choices.delta.tool_calls.id;
type = json.choices.delta.tool_calls.type;
fcn = json.choices.delta.tool_calls.function;
s = struct('id',id,'type',type,'function',fcn);
txt = jsonencode(s);
else
s = jsondecode(this.ResponseText);
args = json.choices.delta.tool_calls.function.arguments;
s.function.arguments = [s.function.arguments args];
txt = jsonencode(s);
end
this.StreamFun('');
this.ResponseText = txt;
else
txt = json.choices.delta.content;
this.StreamFun(txt);
this.ResponseText = [this.ResponseText txt];
end
end
end
end
end
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
*.asv
*.mat
startup.m
papers_to_read.csv
data/*
Binary file modified examples/ExampleParallelFunctionCalls.mlx
Binary file not shown.

0 comments on commit efa591b

Please sign in to comment.