Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpenAI Streaming Support, CompletionStream protobuf, and ctransformers model backend #129

Merged
merged 5 commits into the base branch on
Jun 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 95 additions & 41 deletions api/backends/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package openai
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"log"
"reflect"
"strings"
Expand Down Expand Up @@ -310,56 +312,108 @@ func (o *OpenAIHandler) complete(c *gin.Context) {
if conn == nil {
return
}

logit := make(map[string]int32)
for k, v := range input.LogitBias {
logit[k] = int32(v)
}

client := generate.NewCompletionServiceClient(conn)

id, _ := uuid.NewRandom()
if input.N == 0 {
input.N = 1
}
resp := openai.CompletionResponse{
ID: id.String(),
Created: time.Now().Unix(),
Model: input.Model,
Choices: make([]openai.CompletionChoice, input.N),
}

for i := 0; i < input.N; i++ {
// Implement the completion logic here, using the data from `input`
response, err := client.Complete(c.Request.Context(), &generate.CompletionRequest{
Prompt: input.Prompt.(string),
Suffix: input.Suffix,
MaxTokens: int32(input.MaxTokens),
Temperature: input.Temperature,
TopP: input.TopP,
Stream: input.Stream,
Logprobs: int32(input.LogProbs),
Echo: input.Echo,
Stop: input.Stop, // Wrong type here...
PresencePenalty: input.PresencePenalty,
FrequencePenalty: input.FrequencyPenalty,
BestOf: int32(input.BestOf),
LogitBias: logit, // Wrong type here
if input.Stream {
chanStream := make(chan *generate.CompletionResponse, 10)
client := generate.NewCompletionStreamServiceClient(conn)
stream, err := client.CompleteStream(context.Background(), &generate.CompletionRequest{
Prompt: input.Prompt.(string),
MaxTokens: int32(input.MaxTokens),
Temperature: input.Temperature,
})

if err != nil {
log.Printf("500: Error completing via backend(%v): %v\n", input.Model, err)
c.JSON(500, err)
return
}
choice := openai.CompletionChoice{
Text: strings.TrimPrefix(response.GetCompletion(), input.Prompt.(string)),
FinishReason: response.GetFinishReason(),
Index: i,

go func() {
defer close(chanStream)
for {
cResp, err := stream.Recv()
if err == io.EOF {
break
}
chanStream <- cResp
}
}()
c.Stream(func(w io.Writer) bool {
if msg, ok := <-chanStream; ok {

// OpenAI places a space in between the data key and payload in HTTP. So, I guess we're bug-for-bug compatible.
res, err := json.Marshal(openai.CompletionResponse{
ID: id.String(),
Created: time.Now().Unix(),
Model: input.Model,
Object: "text_completion",
Choices: []openai.CompletionChoice{
{
Index: 0,
Text: msg.GetCompletion(),
},
},
})
if err != nil {
return false
}
c.SSEvent("", fmt.Sprintf(" %s", res))
return true
}
c.SSEvent("", " [DONE]")
return false
})
} else {

logit := make(map[string]int32)
for k, v := range input.LogitBias {
logit[k] = int32(v)
}

client := generate.NewCompletionServiceClient(conn)

if input.N == 0 {
input.N = 1
}
resp := openai.CompletionResponse{
ID: id.String(),
Created: time.Now().Unix(),
Model: input.Model,
Choices: make([]openai.CompletionChoice, input.N),
}
resp.Choices[i] = choice
}

c.JSON(200, resp)
for i := 0; i < input.N; i++ {
// Implement the completion logic here, using the data from `input`
response, err := client.Complete(c.Request.Context(), &generate.CompletionRequest{
Prompt: input.Prompt.(string),
Suffix: input.Suffix,
MaxTokens: int32(input.MaxTokens),
Temperature: input.Temperature,
TopP: input.TopP,
Stream: input.Stream,
Logprobs: int32(input.LogProbs),
Echo: input.Echo,
Stop: input.Stop, // Wrong type here...
PresencePenalty: input.PresencePenalty,
FrequencePenalty: input.FrequencyPenalty,
BestOf: int32(input.BestOf),
LogitBias: logit, // Wrong type here
})
if err != nil {
log.Printf("500: Error completing via backend(%v): %v\n", input.Model, err)
c.JSON(500, err)
return
}
choice := openai.CompletionChoice{
Text: strings.TrimPrefix(response.GetCompletion(), input.Prompt.(string)),
FinishReason: response.GetFinishReason(),
Index: i,
}
resp.Choices[i] = choice
}

c.JSON(200, resp)
}
// Send the response
}

Expand Down
9 changes: 4 additions & 5 deletions api/models2.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
[stablelm-3b]
[ctransformers]

[stablelm-3b.metadata]
[ctransformers.metadata]
owned_by = 'Defense Unicorns'
permission = []
description = 'Stablelm-3b tuned'
description = 'ctransformers tuned'
tasks = ["completion"]

[stablelm-3b.network]
[ctransformers.network]
url = 'localhost:50051'
type = 'gRPC'

3 changes: 3 additions & 0 deletions leapfrogai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
CompletionService,
CompletionServiceServicer,
CompletionServiceStub,
CompletionStreamService,
CompletionStreamServiceServicer,
CompletionStreamServiceStub,
)
from .name.name_pb2 import NameResponse
from .name.name_pb2_grpc import NameService, NameServiceServicer, NameServiceStub
Expand Down
121 changes: 73 additions & 48 deletions leapfrogai/audio/audio_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ def __init__(self, channel):
channel: A grpc.Channel.
"""
self.Translate = channel.stream_unary(
'/audio.Audio/Translate',
request_serializer=audio_dot_audio__pb2.AudioRequest.SerializeToString,
response_deserializer=audio_dot_audio__pb2.AudioResponse.FromString,
)
"/audio.Audio/Translate",
request_serializer=audio_dot_audio__pb2.AudioRequest.SerializeToString,
response_deserializer=audio_dot_audio__pb2.AudioResponse.FromString,
)
self.Transcribe = channel.stream_unary(
'/audio.Audio/Transcribe',
request_serializer=audio_dot_audio__pb2.AudioRequest.SerializeToString,
response_deserializer=audio_dot_audio__pb2.AudioResponse.FromString,
)
"/audio.Audio/Transcribe",
request_serializer=audio_dot_audio__pb2.AudioRequest.SerializeToString,
response_deserializer=audio_dot_audio__pb2.AudioResponse.FromString,
)


class AudioServicer(object):
Expand All @@ -32,68 +32,93 @@ class AudioServicer(object):
def Translate(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")

def Transcribe(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")


def add_AudioServicer_to_server(servicer, server):
rpc_method_handlers = {
'Translate': grpc.stream_unary_rpc_method_handler(
servicer.Translate,
request_deserializer=audio_dot_audio__pb2.AudioRequest.FromString,
response_serializer=audio_dot_audio__pb2.AudioResponse.SerializeToString,
),
'Transcribe': grpc.stream_unary_rpc_method_handler(
servicer.Transcribe,
request_deserializer=audio_dot_audio__pb2.AudioRequest.FromString,
response_serializer=audio_dot_audio__pb2.AudioResponse.SerializeToString,
),
"Translate": grpc.stream_unary_rpc_method_handler(
servicer.Translate,
request_deserializer=audio_dot_audio__pb2.AudioRequest.FromString,
response_serializer=audio_dot_audio__pb2.AudioResponse.SerializeToString,
),
"Transcribe": grpc.stream_unary_rpc_method_handler(
servicer.Transcribe,
request_deserializer=audio_dot_audio__pb2.AudioRequest.FromString,
response_serializer=audio_dot_audio__pb2.AudioResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'audio.Audio', rpc_method_handlers)
"audio.Audio", rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
# This class is part of an EXPERIMENTAL API.
class Audio(object):
"""Missing associated documentation comment in .proto file."""

@staticmethod
def Translate(request_iterator,
def Translate(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_unary(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(request_iterator, target, '/audio.Audio/Translate',
"/audio.Audio/Translate",
audio_dot_audio__pb2.AudioRequest.SerializeToString,
audio_dot_audio__pb2.AudioResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)

@staticmethod
def Transcribe(request_iterator,
def Transcribe(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_unary(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(request_iterator, target, '/audio.Audio/Transcribe',
"/audio.Audio/Transcribe",
audio_dot_audio__pb2.AudioRequest.SerializeToString,
audio_dot_audio__pb2.AudioResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
)
Loading