Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: untangle pkg/grpc and core/schema for Transcription #3419

Merged
merged 5 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ endif

backend-assets/grpc/whisper: sources/whisper.cpp sources/whisper.cpp/libwhisper.a backend-assets/grpc
CGO_LDFLAGS="$(CGO_LDFLAGS) $(CGO_LDFLAGS_WHISPER)" C_INCLUDE_PATH="$(CURDIR)/sources/whisper.cpp/include:$(CURDIR)/sources/whisper.cpp/ggml/include" LIBRARY_PATH=$(CURDIR)/sources/whisper.cpp \
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/
$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o backend-assets/grpc/whisper ./backend/go/transcribe/whisper
ifneq ($(UPX),)
$(UPX) backend-assets/grpc/whisper
endif
Expand Down
104 changes: 0 additions & 104 deletions backend/go/transcribe/transcript.go

This file was deleted.

26 changes: 0 additions & 26 deletions backend/go/transcribe/whisper.go

This file was deleted.

File renamed without changes.
105 changes: 105 additions & 0 deletions backend/go/transcribe/whisper/whisper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package main
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is actually a rename of backend/go/transcribe/whisper.go - it's been modified enough that git considers it a new file.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

transcript.go was also merged into this file to avoid copying locks - it was very whisper specific already


// This is a wrapper to statisfy the GRPC service interface
// It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc)
import (
"os"
"path/filepath"

"github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper"
"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
)

type Whisper struct {
base.SingleThread
whisper whisper.Model
}

func (sd *Whisper) Load(opts *pb.ModelOptions) error {
// Note: the Model here is a path to a directory containing the model files
w, err := whisper.New(opts.ModelFile)
sd.whisper = w
return err
}

func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {

dir, err := os.MkdirTemp("", "whisper")
if err != nil {
return pb.TranscriptResult{}, err
}
defer os.RemoveAll(dir)

convertedPath := filepath.Join(dir, "converted.wav")

if err := utils.AudioToWav(opts.Dst, convertedPath); err != nil {
return pb.TranscriptResult{}, err
}

// Open samples
fh, err := os.Open(convertedPath)
if err != nil {
return pb.TranscriptResult{}, err
}
defer fh.Close()

// Read samples
d := wav.NewDecoder(fh)
buf, err := d.FullPCMBuffer()
if err != nil {
return pb.TranscriptResult{}, err
}

data := buf.AsFloat32Buffer().Data

// Process samples
context, err := sd.whisper.NewContext()
if err != nil {
return pb.TranscriptResult{}, err

}

context.SetThreads(uint(opts.Threads))

if opts.Language != "" {
context.SetLanguage(opts.Language)
} else {
context.SetLanguage("auto")
}

if opts.Translate {
context.SetTranslate(true)
}

if err := context.Process(data, nil, nil); err != nil {
return pb.TranscriptResult{}, err
}

segments := []*pb.TranscriptSegment{}
text := ""
for {
s, err := context.NextSegment()
if err != nil {
break
}

var tokens []int32
for _, t := range s.Tokens {
tokens = append(tokens, int32(t.Id))
}

segment := &pb.TranscriptSegment{Id: int32(s.Num), Text: s.Text, Start: int64(s.Start), End: int64(s.End), Tokens: tokens}
segments = append(segments, segment)

text += s.Text
}

return pb.TranscriptResult{
Segments: segments,
Text: text,
}, nil

}
24 changes: 23 additions & 1 deletion core/backend/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package backend
import (
"context"
"fmt"
"time"

"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/core/schema"
Expand Down Expand Up @@ -30,10 +31,31 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
return nil, fmt.Errorf("could not load whisper model")
}

return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
r, err := whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio,
Language: language,
Translate: translate,
Threads: uint32(*backendConfig.Threads),
})
if err != nil {
return nil, err
}
tr := &schema.TranscriptionResult{
Text: r.Text,
}
for _, s := range r.Segments {
var tks []int
for _, t := range s.Tokens {
tks = append(tks, int(t))
}
tr.Segments = append(tr.Segments,
schema.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),
End: time.Duration(s.End),
Tokens: tks,
})
}
return tr, err
}
3 changes: 1 addition & 2 deletions pkg/grpc/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package grpc
import (
"context"

"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc"
)
Expand Down Expand Up @@ -42,7 +41,7 @@ type Backend interface {
GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...grpc.CallOption) (*pb.Result, error)
TTS(ctx context.Context, in *pb.TTSRequest, opts ...grpc.CallOption) (*pb.Result, error)
SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...grpc.CallOption) (*pb.Result, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error)
AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error)
TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error)
Status(ctx context.Context) (*pb.StatusResponse, error)

Expand Down
5 changes: 2 additions & 3 deletions pkg/grpc/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"os"

"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
gopsutil "github.com/shirou/gopsutil/v3/process"
)
Expand Down Expand Up @@ -53,8 +52,8 @@ func (llm *Base) GenerateImage(*pb.GenerateImageRequest) error {
return fmt.Errorf("unimplemented")
}

func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (schema.TranscriptionResult, error) {
return schema.TranscriptionResult{}, fmt.Errorf("unimplemented")
func (llm *Base) AudioTranscription(*pb.TranscriptRequest) (pb.TranscriptResult, error) {
return pb.TranscriptResult{}, fmt.Errorf("unimplemented")
}

func (llm *Base) TTS(*pb.TTSRequest) error {
Expand Down
25 changes: 2 additions & 23 deletions pkg/grpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"sync"
"time"

"github.com/mudler/LocalAI/core/schema"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
Expand Down Expand Up @@ -228,7 +227,7 @@ func (c *Client) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequ
return client.SoundGeneration(ctx, in, opts...)
}

func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*schema.TranscriptionResult, error) {
func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...grpc.CallOption) (*pb.TranscriptResult, error) {
if !c.parallel {
c.opMutex.Lock()
defer c.opMutex.Unlock()
Expand All @@ -243,27 +242,7 @@ func (c *Client) AudioTranscription(ctx context.Context, in *pb.TranscriptReques
}
defer conn.Close()
client := pb.NewBackendClient(conn)
res, err := client.AudioTranscription(ctx, in, opts...)
if err != nil {
return nil, err
}
tresult := &schema.TranscriptionResult{}
for _, s := range res.Segments {
tks := []int{}
for _, t := range s.Tokens {
tks = append(tks, int(t))
}
tresult.Segments = append(tresult.Segments,
schema.Segment{
Text: s.Text,
Id: int(s.Id),
Start: time.Duration(s.Start),
End: time.Duration(s.End),
Tokens: tks,
})
}
tresult.Text = res.Text
return tresult, err
return client.AudioTranscription(ctx, in, opts...)
}

func (c *Client) TokenizeString(ctx context.Context, in *pb.PredictOptions, opts ...grpc.CallOption) (*pb.TokenizationResponse, error) {
Expand Down
Loading