Skip to content

Commit

Permalink
vertexai(test): Run corpora test in go coroutine to reduce test runti…
Browse files Browse the repository at this point in the history
…me (#10841)
  • Loading branch information
happy-qiao authored Sep 17, 2024
1 parent 37866ce commit 4e3e3ec
Showing 1 changed file with 46 additions and 34 deletions.
80 changes: 46 additions & 34 deletions vertexai/genai/tokenizer/corpora_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@ import (
"net/http"
"os"
"strings"
"sync"
"testing"

"cloud.google.com/go/vertexai/genai"

"golang.org/x/text/encoding"
"golang.org/x/text/encoding/charmap"
"golang.org/x/text/encoding/japanese"
Expand Down Expand Up @@ -241,47 +243,57 @@ func TestCountTokensWithCorpora(t *testing.T) {
model := client.GenerativeModel(defaultModel)
ucr := newUdhrCorpus()

tok, err := New(defaultModel)
if err != nil {
log.Fatal(err)
}

corporaURL := "https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/packages/corpora/udhr.zip"
files, err := corporaGenerator(corporaURL)
corporaFiles, err := corporaGenerator(corporaURL)
if err != nil {
t.Fatalf("Failed to generate corpora: %v", err)
}

// Iterate over files generated by the generator function
for _, fileInfo := range files {
if ucr.shouldSkip(fileInfo.Name) {
fmt.Printf("Skipping file: %s\n", fileInfo.Name)
continue
}

enc, found := ucr.getEncoding(fileInfo.Name)
if !found {
fmt.Printf("No encoding found for file: %s\n", fileInfo.Name)
continue
}

decodedContent, err := decodeBytes(enc, fileInfo.Content)
if err != nil {
log.Fatalf("Failed to decode bytes: %v", err)
}
// Manage up to 10 corpora run simultaneously
workLimiter := make(chan struct{}, 10)
defer close(workLimiter)
var wg sync.WaitGroup
for _, corpora := range corporaFiles {
wg.Add(1)
go func(corpora corporaInfo) {
workLimiter <- struct{}{}
defer func() {
<-workLimiter
wg.Done()
}()
if ucr.shouldSkip(corpora.Name) {
log.Printf("Skipping file: %s\n", corpora.Name)
return
}

tok, err := New(defaultModel)
if err != nil {
log.Fatal(err)
}
enc, found := ucr.getEncoding(corpora.Name)
if !found {
log.Printf("No encoding found for file: %s\n", corpora.Name)
return
}

localNtoks, err := tok.CountTokens(genai.Text(decodedContent))
if err != nil {
log.Fatal(err)
}
remoteNtoks, err := model.CountTokens(ctx, genai.Text(decodedContent))
if err != nil {
log.Fatal(fileInfo.Name, err)
}
if localNtoks.TotalTokens != remoteNtoks.TotalTokens {
t.Errorf("expected %d(remote count-token results), but got %d(local count-token results)", remoteNtoks, localNtoks)
}
decodedContent, err := decodeBytes(enc, corpora.Content)
if err != nil {
log.Fatalf("Failed to decode bytes: %v", err)
}

localNtoks, err := tok.CountTokens(genai.Text(decodedContent))
if err != nil {
log.Fatal(err)
}
remoteNtoks, err := model.CountTokens(ctx, genai.Text(decodedContent))
if err != nil {
log.Fatal(corpora.Name, err)
}
if localNtoks.TotalTokens != remoteNtoks.TotalTokens {
t.Errorf("expected %d(remote count-token results), but got %d(local count-token results)", remoteNtoks, localNtoks)
}
}(corpora)
}

wg.Wait()
}

0 comments on commit 4e3e3ec

Please sign in to comment.