Skip to content

Commit

Permalink
feat(gemini): support structured outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
mattmcdev committed Feb 10, 2025
1 parent 7827377 commit 10ed5fb
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 7 deletions.
2 changes: 1 addition & 1 deletion docs/components/ProviderSupport.vue
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ export default {
{
name: "Gemini",
text: Supported,
structured: Planned,
structured: Supported,
embeddings: Supported,
image: Supported,
tools: Supported,
Expand Down
6 changes: 1 addition & 5 deletions docs/providers/gemini.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,4 @@
'api_key' => env('GEMINI_API_KEY', ''),
'url' => env('GEMINI_URL', 'https://generativelanguage.googleapis.com/v1beta/models'),
],
```

## Limitations

- The structured output is not supported.
```
8 changes: 7 additions & 1 deletion src/Providers/Gemini/Gemini.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
use EchoLabs\Prism\Embeddings\Request as EmbeddingRequest;
use EchoLabs\Prism\Embeddings\Response as EmbeddingResponse;
use EchoLabs\Prism\Providers\Gemini\Handlers\Embeddings;
use EchoLabs\Prism\Providers\Gemini\Handlers\Structured;
use EchoLabs\Prism\Providers\Gemini\Handlers\Text;
use EchoLabs\Prism\Structured\Request as StructuredRequest;
use EchoLabs\Prism\Text\Request as TextRequest;
Expand Down Expand Up @@ -36,7 +37,12 @@ public function text(TextRequest $request): ProviderResponse
#[\Override]
public function structured(StructuredRequest $request): ProviderResponse
{
throw new \Exception(sprintf('%s does not support structured mode', class_basename($this)));
$handler = new Structured($this->client(
$request->clientOptions,
$request->clientRetry
));

return $handler->handle($request);
}

#[\Override]
Expand Down
82 changes: 82 additions & 0 deletions src/Providers/Gemini/Handlers/Structured.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
<?php

namespace EchoLabs\Prism\Providers\Gemini\Handlers;

use EchoLabs\Prism\Exceptions\PrismException;
use EchoLabs\Prism\Providers\Gemini\Maps\FinishReasonMap;
use EchoLabs\Prism\Providers\Gemini\Maps\MessageMap;
use EchoLabs\Prism\Providers\Gemini\Maps\SchemaMap;
use EchoLabs\Prism\Structured\Request;
use EchoLabs\Prism\ValueObjects\ProviderResponse;
use EchoLabs\Prism\ValueObjects\ResponseMeta;
use EchoLabs\Prism\ValueObjects\Usage;
use Illuminate\Http\Client\PendingRequest;
use Illuminate\Http\Client\Response;
use Throwable;

class Structured
{
public function __construct(protected PendingRequest $client) {}

public function handle(Request $request): ProviderResponse
{
try {
$response = $this->sendRequest($request);
} catch (Throwable $e) {
throw PrismException::providerRequestError($request->model, $e);
}

$data = $response->json();

if (! $data || data_get($data, 'error')) {
throw PrismException::providerResponseError(vsprintf(
'Gemini Error: [%s] %s',
[
data_get($data, 'error.code', 'unknown'),
data_get($data, 'error.message', 'unknown'),
]
));
}

return new ProviderResponse(
text: data_get($data, 'candidates.0.content.parts.0.text') ?? '',
toolCalls: [],
usage: new Usage(
data_get($data, 'usageMetadata.promptTokenCount', 0),
data_get($data, 'usageMetadata.candidatesTokenCount', 0)
),
finishReason: FinishReasonMap::map(
data_get($data, 'candidates.0.finishReason'),
),
responseMeta: new ResponseMeta(
id: data_get($data, 'id', ''),
model: data_get($data, 'modelVersion'),
)
);
}

public function sendRequest(Request $request): Response
{
$endpoint = "{$request->model}:generateContent";

$payload = (new MessageMap($request->messages, $request->systemPrompt))();

$responseSchema = new SchemaMap($request->schema);

$payload['generationConfig'] = array_merge([
'response_mime_type' => 'application/json',
'response_schema' => $responseSchema->toArray(),
], array_filter([
'temperature' => $request->temperature,
'topP' => $request->topP,
'maxOutputTokens' => $request->maxTokens,
]));

$safetySettings = data_get($request->providerMeta, 'safetySettings');
if (! empty($safetySettings)) {
$payload['safetySettings'] = $safetySettings;
}

return $this->client->post($endpoint, $payload);
}
}
53 changes: 53 additions & 0 deletions src/Providers/Gemini/Maps/SchemaMap.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
<?php

namespace EchoLabs\Prism\Providers\Gemini\Maps;

use EchoLabs\Prism\Contracts\Schema;
use EchoLabs\Prism\Schema\ArraySchema;
use EchoLabs\Prism\Schema\BooleanSchema;
use EchoLabs\Prism\Schema\NumberSchema;
use EchoLabs\Prism\Schema\ObjectSchema;

class SchemaMap
{
public function __construct(
private Schema $schema,
) {}

public function toArray(): array

Check failure on line 17 in src/Providers/Gemini/Maps/SchemaMap.php

View workflow job for this annotation

GitHub Actions / phpstan

Method EchoLabs\Prism\Providers\Gemini\Maps\SchemaMap::toArray() return type has no value type specified in iterable type array.
{
return array_merge([
'type' => $this->mapType(),
...array_filter([
...$this->schema->toArray(),
'additionalProperties' => null,
]),
], array_filter([
'items' => property_exists($this->schema, 'items') ?
(new self($this->schema->items))->toArray() :
null,
'properties' => property_exists($this->schema, 'properties') ?
array_reduce($this->schema->properties, fn (array $carry, Schema $property) => [
...$carry,
$property->name() => (new self($property))->toArray(),
], []) :
null,
]));
}

protected function mapType(): string
{
switch ($this->schema::class) {
case ArraySchema::class:
return 'array';
case BooleanSchema::class:
return 'boolean';
case NumberSchema::class:
return 'number';
case ObjectSchema::class:
return 'object';
default:
return 'string';
}
}
}

0 comments on commit 10ed5fb

Please sign in to comment.