Skip to content

Commit

Permalink
Added support for sparsevec type
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 18, 2024
1 parent 851d9d2 commit 3f73e41
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## 0.2.0 (unreleased)

- Added support for `halfvec` type
- Added support for `halfvec` and `sparsevec` types
- Added `L1`, `Hamming`, and `Jaccard` distances
- Changed `Distance` to enum
- Dropped support for PHP < 8.1
Expand Down
64 changes: 64 additions & 0 deletions src/SparseVector.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
<?php

namespace Pgvector;

class SparseVector
{
protected $dimensions;
protected $indices;
protected $values;

public function __construct($dimensions, $indices, $values)
{
$this->dimensions = $dimensions;
$this->indices = $indices;
$this->values = $values;
}

public static function fromDense($value)
{
$dimensions = count($value);
$indices = [];
$values = [];
foreach ($value as $i => $v) {
if ($v != 0) {
$indices[] = $i;
$values[] = floatval($v);
}
}
return new SparseVector($dimensions, $indices, $values);
}

public static function fromString($value)
{
$parts = explode('/', $value, 2);
$dimensions = intval($parts[1]);
$indices = [];
$values = [];
$elements = explode(',', substr($parts[0], 1, -1));
foreach ($elements as $e) {
$ep = explode(':', $e, 2);
$indices[] = intval($ep[0]) - 1;
$values[] = floatval($ep[1]);
}
return new SparseVector($dimensions, $indices, $values);
}

public function __toString()
{
$elements = [];
for ($i = 0; $i < count($this->indices); $i++) {
$elements[] = ($this->indices[$i] + 1) . ':' . $this->values[$i];
}
return '{' . implode(',', $elements) . '}/' . $this->dimensions;
}

public function toArray()
{
$result = array_fill(0, $this->dimensions, 0.0);
for ($i = 0; $i < count($this->indices); $i++) {
$result[$this->indices[$i]] = $this->values[$i];
}
return $result;
}
}
42 changes: 42 additions & 0 deletions src/laravel/SparseVector.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
<?php

namespace Pgvector\Laravel;

use Illuminate\Contracts\Database\Eloquent\Castable;
use Illuminate\Contracts\Database\Eloquent\CastsAttributes;
use Illuminate\Database\Eloquent\Model;

class SparseVector extends \Pgvector\SparseVector implements Castable
{
public static function castUsing(array $arguments): CastsAttributes
{
return new class ($arguments) implements CastsAttributes {
public function __construct(array $arguments)
{
// no need for dimensions
}

public function get(mixed $model, string $key, mixed $value, array $attributes): ?\Pgvector\SparseVector
{
if (is_null($value)) {
return null;
}

return SparseVector::fromString($value);
}

public function set(mixed $model, string $key, mixed $value, array $attributes): ?string
{
if (is_null($value)) {
return null;
}

if (is_array($value)) {
$value = SparseVector::fromDense($value);
}

return (string) $value;
}
};
}
}
47 changes: 45 additions & 2 deletions tests/LaravelTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
use Pgvector\Laravel\HasNeighbors;
use Pgvector\Laravel\Vector;
use Pgvector\Laravel\HalfVector;
use Pgvector\Laravel\SparseVector;

$capsule = new Capsule();
$capsule->addConnection([
Expand Down Expand Up @@ -36,8 +37,8 @@ class Item extends Model
use HasNeighbors;

public $timestamps = false;
protected $fillable = ['id', 'embedding', 'half_embedding', 'binary_embedding'];
protected $casts = ['embedding' => Vector::class, 'half_embedding' => HalfVector::class];
protected $fillable = ['id', 'embedding', 'half_embedding', 'binary_embedding', 'sparse_embedding'];
protected $casts = ['embedding' => Vector::class, 'half_embedding' => HalfVector::class, 'sparse_embedding' => SparseVector::class];
}

final class LaravelTest extends TestCase
Expand Down Expand Up @@ -172,6 +173,48 @@ public function testBitJaccardDistance()
$this->assertEqualsWithDelta([0, 1/3, 1], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001);
}

public function testSparsevecL2Distance()
{
$this->createItems('sparse_embedding');
$neighbors = Item::orderByRaw('sparse_embedding <-> ?', [SparseVector::fromDense([1, 1, 1])])->take(5)->get();
$this->assertEquals([1, 3, 2], $neighbors->pluck('id')->toArray());
$this->assertEquals([[1, 1, 1], [1, 1, 2], [2, 2, 2]], array_map(fn ($v) => $v->toArray(), $neighbors->pluck('sparse_embedding')->toArray()));
}

public function testSparsevecScopeL2Distance()
{
$this->createItems('sparse_embedding');
$neighbors = Item::query()->nearestNeighbors('sparse_embedding', '{1:1,2:1,3:1}/3', Distance::L2)->take(5)->get();
$this->assertEquals([1, 3, 2], $neighbors->pluck('id')->toArray());
$this->assertEqualsWithDelta([0, 1, sqrt(3)], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001);
}

public function testSparsevecScopeMaxInnerProduct()
{
$this->createItems('sparse_embedding');
$neighbors = Item::query()->nearestNeighbors('sparse_embedding', '{1:1,2:1,3:1}/3', Distance::InnerProduct)->take(5)->get();
$this->assertEquals([2, 3, 1], $neighbors->pluck('id')->toArray());
$this->assertEqualsWithDelta([6, 4, 3], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001);
}

public function testSparsevecInstance()
{
$this->createItems('sparse_embedding');
$item = Item::find(1);
$neighbors = $item->nearestNeighbors('sparse_embedding', Distance::L2)->take(5)->get();
$this->assertEquals([3, 2], $neighbors->pluck('id')->toArray());
$this->assertEqualsWithDelta([1, sqrt(3)], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001);
}

public function testSparsevecInstanceL1()
{
$this->createItems('sparse_embedding');
$item = Item::find(1);
$neighbors = $item->nearestNeighbors('sparse_embedding', Distance::L1)->take(5)->get();
$this->assertEquals([3, 2], $neighbors->pluck('id')->toArray());
$this->assertEqualsWithDelta([1, 3], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001);
}

public function testMissingAttribute()
{
$this->expectException(MissingAttributeException::class);
Expand Down

0 comments on commit 3f73e41

Please sign in to comment.