Skip to content

Commit

Permalink
Add GetCompletionsSSE and supporting methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mikekistler committed Apr 23, 2023
1 parent 7445e2a commit 5508952
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions sdk/cognitiveservices/azopenai/custom_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,18 @@ package azopenai
// this file contains handwritten additions to the generated code

import (
"bufio"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"strings"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)

const (
Expand Down Expand Up @@ -58,3 +67,69 @@ func NewClientWithKeyCredential(endpoint string, credential KeyCredential, optio
}
return &Client{endpoint: endpoint + "/openai", internal: azcoreClient}, nil
}

// Support for SSE

// Generated from API version 2022-12-01
// - options - ClientGetCompletionsOptions contains the optional parameters for the Client.GetCompletions method.
func (client *Client) GetCompletionsSSE(ctx context.Context, deploymentID string, body CompletionRequest, options *ClientGetCompletionsOptions) (*http.Response, error) {
body.Stream = to.Ptr(true)
req, err := client.getCompletionsCreateRequest(ctx, deploymentID, body, options)
if err != nil {
return nil, err
}
resp, err := client.internal.Pipeline().Do(req)
if err != nil {
return nil, err
}
if !runtime.HasStatusCode(resp, http.StatusOK) {
return nil, runtime.NewResponseError(resp)
}
return resp, nil
}

type EventReader[T any] struct {
reader io.Reader // Required for Closing
scanner *bufio.Scanner
}

func NewEventReader[T any](r io.Reader) *EventReader[T] {
return &EventReader[T]{reader: r, scanner: bufio.NewScanner(r)}
}

func (er *EventReader[T]) Read() (T, error) {
// https://html.spec.whatwg.org/multipage/server-sent-events.html
for er.scanner.Scan() { // Scan while no error
line := er.scanner.Text() // Get the line & interpret the event stream:

if line == "" || line[0] == ':' { // If the line is blank or is a comment, skip it
continue
}

if strings.Contains(line, ":") { // If the line contains a U+003A COLON character (:), process the field
tokens := strings.SplitN(line, ":", 2)
tokens[0], tokens[1] = strings.TrimSpace(tokens[0]), strings.TrimSpace(tokens[1])
var data T
switch tokens[0] {
case "data": // return the deserialized JSON object
if tokens[1] == "[DONE]" { // If data is [DONE], end of stream was reached
return data, io.EOF
}
//fmt.Println(tokens[1])
err := json.Unmarshal([]byte(tokens[1]), &data)
return data, err

default: // Any other event type is an unexpected
return data, errors.New("Unexpected event type: " + tokens[0])
}
// Unreachable
}
}
return *new(T), er.scanner.Err()
}

func (er *EventReader[T]) Close() {
if closer, ok := er.reader.(io.Closer); ok {
closer.Close()
}
}

0 comments on commit 5508952

Please sign in to comment.