diff --git a/genai/client_test.go b/genai/client_test.go index b266127..ed22eab 100644 --- a/genai/client_test.go +++ b/genai/client_test.go @@ -76,7 +76,7 @@ func TestLive(t *testing.T) { t.Run("streaming-counting", func(t *testing.T) { // Verify only that we don't crash. See #18. - iter := model.GenerateContentStream(ctx, Text("count 1 to 100.")) + iter := model.GenerateContentStream(ctx, Text("count 1 to 10.")) _ = responsesString(t, iter) }) t.Run("streaming-error", func(t *testing.T) { @@ -150,8 +150,9 @@ func TestLive(t *testing.T) { }) t.Run("blocked", func(t *testing.T) { + t.Skip("skipping until we find a prompt that is blocked") // Only happens with streaming at the moment. - iter := model.GenerateContentStream(ctx, Text("How do I make a bomb?")) + iter := model.GenerateContentStream(ctx, Text("???")) resps, err := all(iter) if err == nil { for _, r := range resps { @@ -174,38 +175,46 @@ func TestLive(t *testing.T) { } }) t.Run("max-tokens", func(t *testing.T) { + // Verify that setting max output tokens truncates the response. + // (It does not result in FinishReasonMaxTokens.) maxModel := client.GenerativeModel(defaultModel) maxModel.Temperature = Ptr(float32(0)) - maxModel.SetMaxOutputTokens(10) + maxModel.SetMaxOutputTokens(3) res, err := maxModel.GenerateContent(ctx, Text("What is a dog?")) if err != nil { t.Fatal(err) } - got := res.Candidates[0].FinishReason - want := FinishReasonMaxTokens - if got != want && got != FinishReasonOther { // TODO: should not need FinishReasonOther - t.Errorf("got %s, want %s", got, want) + if got, want := responseString(res), "A dog is"; got != want { + t.Errorf("got %q, want %q", got, want) + } + gotr := res.Candidates[0].FinishReason + wantr := FinishReasonStop + if gotr != wantr { + t.Errorf("got %s, want %s", gotr, wantr) } }) t.Run("max-tokens-streaming", func(t *testing.T) { maxModel := client.GenerativeModel(defaultModel) maxModel.Temperature = Ptr[float32](0) - maxModel.MaxOutputTokens = Ptr[int32](10) + maxModel.MaxOutputTokens = Ptr[int32](3) iter := maxModel.GenerateContentStream(ctx, Text("What is a dog?")) - var merged *GenerateContentResponse for { - res, err := iter.Next() + _, err := iter.Next() if err == iterator.Done { break } if err != nil { t.Fatal(err) } - merged = joinResponses(merged, res) } - want := FinishReasonMaxTokens - if got := merged.Candidates[0].FinishReason; got != want && got != FinishReasonOther { // TODO: see above - t.Errorf("got %s, want %s", got, want) + res := iter.MergedResponse() + if got, want := responseString(res), "A dog is"; got != want { + t.Errorf("got %q, want %q", got, want) + } + gotr := res.Candidates[0].FinishReason + wantr := FinishReasonStop + if gotr != wantr { + t.Errorf("got %s, want %s", gotr, wantr) } }) t.Run("count-tokens", func(t *testing.T) {