Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move model capability verification out of openAIChat.m, for maintain… #37

Merged
merged 2 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions +llms/+openai/models.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
function models = models
%MODELS - supported OpenAI models

% Copyright 2024 The MathWorks, Inc.

% Grouped by model family; order matches the published list of
% chat-completion models this package supports.
gpt4o      = ["gpt-4o", "gpt-4o-2024-05-13"];
gpt4Turbo  = ["gpt-4-turbo", "gpt-4-turbo-2024-04-09"];
gpt4       = ["gpt-4", "gpt-4-0613"];
gpt35Turbo = ["gpt-3.5-turbo", "gpt-3.5-turbo-0125", "gpt-3.5-turbo-1106"];

models = [gpt4o, gpt4Turbo, gpt4, gpt35Turbo];
end
13 changes: 13 additions & 0 deletions +llms/+openai/validateMessageSupported.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
function validateMessageSupported(message, model)
%validateMessageSupported - check that message is supported by model
%   validateMessageSupported(message, model) errors with identifier
%   "llms:invalidContentTypeForModel" if MESSAGE contains image content
%   and MODEL does not support image input.

% Copyright 2024 The MathWorks, Inc.

% Image content arrives as a cell array of structs where at least one
% element carries an "image_url" field; only certain models accept it.
if iscell(message.content) && any(cellfun(@(x) isfield(x,"image_url"), message.content))
    if ~ismember(model,["gpt-4-turbo","gpt-4-turbo-2024-04-09","gpt-4o","gpt-4o-2024-05-13"])
        error("llms:invalidContentTypeForModel", ...
            llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", "Image content", model));
    end
end
end
16 changes: 16 additions & 0 deletions +llms/+openai/validateResponseFormat.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
function validateResponseFormat(format,model)
%validateResponseFormat - validate requested response format is available for selected model
%   Not all OpenAI models support JSON output: older GPT-4 snapshots
%   error, all other models get a warning that a JSON instruction is
%   injected into the prompt.

% Copyright 2024 The MathWorks, Inc.

% Nothing to check unless JSON mode was requested.
if format ~= "json"
    return
end

if ismember(model,["gpt-4","gpt-4-0613"])
    error("llms:invalidOptionAndValueForModel", ...
        llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", model));
else
    warning("llms:warningJsonInstruction", ...
        llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
end
end
Binary file not shown.
31 changes: 8 additions & 23 deletions openAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,7 @@
arguments
systemPrompt {llms.utils.mustBeTextOrEmpty} = []
nvp.Tools (1,:) {mustBeA(nvp.Tools, "openAIFunction")} = openAIFunction.empty
nvp.ModelName (1,1) string {mustBeMember(nvp.ModelName,[...
"gpt-4o","gpt-4o-2024-05-13",...
"gpt-4-turbo","gpt-4-turbo-2024-04-09",...
"gpt-4","gpt-4-0613", ...
"gpt-3.5-turbo","gpt-3.5-turbo-0125", ...
"gpt-3.5-turbo-1106",...
])} = "gpt-3.5-turbo"
nvp.ModelName (1,1) string {mustBeModel} = "gpt-3.5-turbo"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still get tab completion?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to create a json file for tab completion. That should then also enable completion for other NVPs.

nvp.Temperature {mustBeValidTemperature} = 1
nvp.TopProbabilityMass {mustBeValidTopP} = 1
nvp.StopSequences {mustBeValidStop} = {}
Expand Down Expand Up @@ -160,16 +154,8 @@
this.StopSequences = nvp.StopSequences;

% ResponseFormat is only supported in the latest models
if nvp.ResponseFormat == "json"
if ismember(this.ModelName,["gpt-4","gpt-4-0613"])
error("llms:invalidOptionAndValueForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidOptionAndValueForModel", "ResponseFormat", "json", this.ModelName));
else
warning("llms:warningJsonInstruction", ...
llms.utils.errorMessageCatalog.getMessage("llms:warningJsonInstruction"))
end

end
llms.openai.validateResponseFormat(nvp.ResponseFormat, this.ModelName);
this.ResponseFormat = nvp.ResponseFormat;

this.PresencePenalty = nvp.PresencePenalty;
this.FrequencyPenalty = nvp.FrequencyPenalty;
Expand Down Expand Up @@ -219,12 +205,7 @@
messagesStruct = messages.Messages;
end

if iscell(messagesStruct{end}.content) && any(cellfun(@(x) isfield(x,"image_url"), messagesStruct{end}.content))
if ~ismember(this.ModelName,["gpt-4-turbo","gpt-4-turbo-2024-04-09","gpt-4o","gpt-4o-2024-05-13"])
error("llms:invalidContentTypeForModel", ...
llms.utils.errorMessageCatalog.getMessage("llms:invalidContentTypeForModel", "Image content", this.ModelName));
end
end
llms.openai.validateMessageSupported(messagesStruct{end}, this.ModelName);

if ~isempty(this.SystemPrompt)
messagesStruct = horzcat(this.SystemPrompt, messagesStruct);
Expand Down Expand Up @@ -334,3 +315,7 @@ function mustBeIntegerOrEmpty(value)
mustBeInteger(value)
end
end

function mustBeModel(model)
% mustBeModel - argument validator: model must be a supported OpenAI model name.
supportedModels = llms.openai.models;
mustBeMember(model, supportedModels);
end
169 changes: 168 additions & 1 deletion tests/topenAIChat.m
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ function saveEnvVar(testCase)
end

properties(TestParameter)
ValidConstructorInput = iGetValidConstructorInput();
InvalidConstructorInput = iGetInvalidConstructorInput();
InvalidGenerateInput = iGetInvalidGenerateInput();
InvalidValuesSetters = iGetInvalidValuesSetters();
Expand Down Expand Up @@ -65,6 +66,21 @@ function constructChatWithAllNVP(testCase)
testCase.verifyEqual(chat.PresencePenalty, presenceP);
end

function validConstructorCalls(testCase,ValidConstructorInput)
% Construct openAIChat from each parameterized valid input, check
% the expected warning (or its absence), then verify every property
% value listed in the test parameter.
if isempty(ValidConstructorInput.ExpectedWarning)
    chat = testCase.verifyWarningFree(...
        @() openAIChat(ValidConstructorInput.Input{:}));
else
    chat = testCase.verifyWarning(...
        @() openAIChat(ValidConstructorInput.Input{:}), ...
        ValidConstructorInput.ExpectedWarning);
end
% Renamed from "properties" to avoid shadowing the built-in
% PROPERTIES function inside this classdef file.
expectedProperties = ValidConstructorInput.VerifyProperties;
for prop=string(fieldnames(expectedProperties)).'
    testCase.verifyEqual(chat.(prop),expectedProperties.(prop),"Property " + prop);
end
end

function verySmallTimeOutErrors(testCase)
chat = openAIChat(TimeOut=0.0001, ApiKey="false-key");

Expand Down Expand Up @@ -126,7 +142,6 @@ function noStopSequencesNoMaxNumTokens(testCase)
end

function createOpenAIChatWithStreamFunc(testCase)

function seen = sf(str)
persistent data;
if isempty(data)
Expand Down Expand Up @@ -275,6 +290,158 @@ function createOpenAIChatWithOpenAIKeyLatestModel(testCase)
"Error", "MATLAB:notGreaterEqual"));
end

function validConstructorInput = iGetValidConstructorInput()
% iGetValidConstructorInput - TestParameter data for valid openAIChat
% constructor calls: each field pairs constructor Input arguments with
% the warning expected (if any) and the property values to verify.
%
% While it is valid to provide the key via an environment variable,
% this test set does not use that, for easier setup.

% Each entry overrides only the properties that differ from the
% constructor defaults; iVerifyProperties fills in the rest.
validConstructorInput = struct( ...
    "JustKey", struct( ...
        "Input",{{"ApiKey","this-is-not-a-real-key"}}, ...
        "ExpectedWarning", '', ...
        "VerifyProperties", iVerifyProperties() ...
    ), ...
    "SystemPrompt", struct( ...
        "Input",{{"system prompt","ApiKey","this-is-not-a-real-key"}}, ...
        "ExpectedWarning", '', ...
        "VerifyProperties", iVerifyProperties( ...
            "SystemPrompt", {struct("role","system","content","system prompt")}) ...
    ), ...
    "Temperature", struct( ...
        "Input",{{"ApiKey","this-is-not-a-real-key","Temperature",2}}, ...
        "ExpectedWarning", '', ...
        "VerifyProperties", iVerifyProperties("Temperature", 2) ...
    ), ...
    "TopProbabilityMass", struct( ...
        "Input",{{"ApiKey","this-is-not-a-real-key","TopProbabilityMass",0.2}}, ...
        "ExpectedWarning", '', ...
        "VerifyProperties", iVerifyProperties("TopProbabilityMass", 0.2) ...
    ), ...
    "StopSequences", struct( ...
        "Input",{{"ApiKey","this-is-not-a-real-key","StopSequences",["foo","bar"]}}, ...
        "ExpectedWarning", '', ...
        "VerifyProperties", iVerifyProperties("StopSequences", ["foo","bar"]) ...
    ), ...
    "PresencePenalty", struct( ...
        "Input",{{"ApiKey","this-is-not-a-real-key","PresencePenalty",0.1}}, ...
        "ExpectedWarning", '', ...
        "VerifyProperties", iVerifyProperties("PresencePenalty", 0.1) ...
    ), ...
    "FrequencyPenalty", struct( ...
        "Input",{{"ApiKey","this-is-not-a-real-key","FrequencyPenalty",0.1}}, ...
        "ExpectedWarning", '', ...
        "VerifyProperties", iVerifyProperties("FrequencyPenalty", 0.1) ...
    ), ...
    "TimeOut", struct( ...
        "Input",{{"ApiKey","this-is-not-a-real-key","TimeOut",0.1}}, ...
        "ExpectedWarning", '', ...
        "VerifyProperties", iVerifyProperties("TimeOut", 0.1) ...
    ), ...
    "ResponseFormat", struct( ...
        "Input",{{"ApiKey","this-is-not-a-real-key","ResponseFormat","json"}}, ...
        "ExpectedWarning", "llms:warningJsonInstruction", ...
        "VerifyProperties", iVerifyProperties("ResponseFormat", "json") ...
    ) ...
    );
end

function s = iVerifyProperties(varargin)
% iVerifyProperties - default expected openAIChat property values, with
% name/value overrides applied on top. VARARGIN is name,value pairs.
s = struct( ...
    "Temperature", 1, ...
    "TopProbabilityMass", 1, ...
    "StopSequences", {{}}, ...          % wrapped in {} so struct() stores an empty cell
    "PresencePenalty", 0, ...
    "FrequencyPenalty", 0, ...
    "TimeOut", 10, ...
    "FunctionNames", [], ...
    "ModelName", "gpt-3.5-turbo", ...
    "SystemPrompt", [], ...
    "ResponseFormat", "text");
for k = 1:2:numel(varargin)
    s.(varargin{k}) = varargin{k+1};
end
end

function invalidConstructorInput = iGetInvalidConstructorInput()
validFunction = openAIFunction("funName");
invalidConstructorInput = struct( ...
Expand Down