Skip to content

Commit

Permalink
Deprecate and refactor
Browse files Browse the repository at this point in the history
- Deprecate `ErrO1BetaLimitationsLogprobs` and `ErrO1BetaLimitationsOther`

- Implement `validationRequestForReasoningModels`, which works on both o1 & o3, and has per-model-type restrictions on functionality (eg, o3 class are allowed function calls and system messages, o1 isn't)
  • Loading branch information
rorymalcolm committed Feb 1, 2025
1 parent 24e00aa commit 73c6ae6
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 61 deletions.
6 changes: 1 addition & 5 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -392,11 +392,7 @@ func (c *Client) CreateChatCompletion(
return
}

if err = validateRequestForO1Models(request); err != nil {
return
}

if err = validateRequestForO3Models(request); err != nil {
if err = valideOSeriesRequest(request); err != nil {
return
}

Expand Down
6 changes: 1 addition & 5 deletions chat_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,7 @@ func (c *Client) CreateChatCompletionStream(
}

request.Stream = true
if err = validateRequestForO1Models(request); err != nil {
return
}

if err = validateRequestForO3Models(request); err != nil {
if err = valideOSeriesRequest(request); err != nil {
return
}

Expand Down
12 changes: 6 additions & 6 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
LogProbs: true,
Model: openai.O1Preview,
},
expectedError: openai.ErrO1BetaLimitationsLogprobs,
expectedError: openai.ErrO3BetaLimitationsLogprobs,
},
{
name: "message_type_unsupported",
Expand Down Expand Up @@ -155,7 +155,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
Temperature: float32(2),
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrO3BetaLimitationsOther,
},
{
name: "set_top_unsupported",
Expand All @@ -173,7 +173,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
Temperature: float32(1),
TopP: float32(0.1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrO3BetaLimitationsOther,
},
{
name: "set_n_unsupported",
Expand All @@ -192,7 +192,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
TopP: float32(1),
N: 2,
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrO3BetaLimitationsOther,
},
{
name: "set_presence_penalty_unsupported",
Expand All @@ -209,7 +209,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
PresencePenalty: float32(1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrO3BetaLimitationsOther,
},
{
name: "set_frequency_penalty_unsupported",
Expand All @@ -226,7 +226,7 @@ func TestO1ModelsChatCompletionsBetaLimitations(t *testing.T) {
},
FrequencyPenalty: float32(0.1),
},
expectedError: openai.ErrO1BetaLimitationsOther,
expectedError: openai.ErrO3BetaLimitationsOther,
},
}

Expand Down
68 changes: 23 additions & 45 deletions completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@ var (
)

var (
ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
ErrO1BetaLimitationsMessageTypes = errors.New("this model has beta-limitations, user and assistant messages only, system messages are not supported") //nolint:lll
ErrO1BetaLimitationsTools = errors.New("this model has beta-limitations, tools, function calling, and response format parameters are not supported") //nolint:lll
// Deprecated: use Err03BetaLimitations* instead

Check failure on line 19 in completion.go

View workflow job for this annotation

GitHub Actions / Sanity check

Comment should end in a period (godot)
ErrO1BetaLimitationsLogprobs = errors.New("this model has beta-limitations, logprobs not supported") //nolint:lll
ErrO1BetaLimitationsOther = errors.New("this model has beta-limitations, temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0") //nolint:lll
)

var (
Expand Down Expand Up @@ -207,9 +208,13 @@ var availableMessageRoleForO1Models = map[string]struct{}{
ChatMessageRoleAssistant: {},
}

// validateRequestForO1Models checks for deprecated fields of OpenAI models.
func validateRequestForO1Models(request ChatCompletionRequest) error {
if _, found := O1SeriesModels[request.Model]; !found {
// valideOSeriesRequest checks for fields in requests which are not allowed for o-series models

Check failure on line 211 in completion.go

View workflow job for this annotation

GitHub Actions / Sanity check

Comment should end in a period (godot)
// this includes: o1 does not support function calling, or system messages
func valideOSeriesRequest(request ChatCompletionRequest) error {

Check failure on line 213 in completion.go

View workflow job for this annotation

GitHub Actions / Sanity check

cognitive complexity 24 of func `valideOSeriesRequest` is high (> 20) (gocognit)
_, o1Series := O1SeriesModels[request.Model]
_, o3Series := O3SeriesModels[request.Model]

if !o1Series && !o3Series {
return nil
}

Expand All @@ -219,54 +224,27 @@ func validateRequestForO1Models(request ChatCompletionRequest) error {

// Logprobs: not supported.
if request.LogProbs {
return ErrO1BetaLimitationsLogprobs
return ErrO3BetaLimitationsLogprobs
}

// Message types: user and assistant messages only, system messages are not supported.
for _, m := range request.Messages {
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
return ErrO1BetaLimitationsMessageTypes
if o1Series {
for _, m := range request.Messages {
if _, found := availableMessageRoleForO1Models[m.Role]; !found {
return ErrO1BetaLimitationsMessageTypes
}
}
}

// Tools: tools, function calling, and response format parameters are not supported
for _, t := range request.Tools {
if _, found := unsupportedToolsForO1Models[t.Type]; found {
return ErrO1BetaLimitationsTools
if o1Series {
for _, t := range request.Tools {
if _, found := unsupportedToolsForO1Models[t.Type]; found {
return ErrO1BetaLimitationsTools
}
}
}

// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
if request.Temperature > 0 && request.Temperature != 1 {
return ErrO1BetaLimitationsOther
}
if request.TopP > 0 && request.TopP != 1 {
return ErrO1BetaLimitationsOther
}
if request.N > 0 && request.N != 1 {
return ErrO1BetaLimitationsOther
}
if request.PresencePenalty > 0 {
return ErrO1BetaLimitationsOther
}
if request.FrequencyPenalty > 0 {
return ErrO1BetaLimitationsOther
}

return nil
}

// validateRequestForO3Models checks for deprecated fields of OpenAI models.
func validateRequestForO3Models(request ChatCompletionRequest) error {
if _, found := O3SeriesModels[request.Model]; !found {
return nil
}

// Logprobs: not supported.
if request.LogProbs {
return ErrO3BetaLimitationsLogprobs
}

// Other: temperature, top_p and n are fixed at 1, while presence_penalty and frequency_penalty are fixed at 0.
if request.Temperature > 0 && request.Temperature != 1 {
return ErrO3BetaLimitationsOther
Expand Down

0 comments on commit 73c6ae6

Please sign in to comment.