From 3f73e4192b43c0f6753ee97e72421013f4428d13 Mon Sep 17 00:00:00 2001 From: Andrew Kane Date: Sat, 18 May 2024 11:47:19 -0400 Subject: [PATCH] Added support for sparsevec type --- CHANGELOG.md | 2 +- src/SparseVector.php | 64 ++++++++++++++++++++++++++++++++++++ src/laravel/SparseVector.php | 42 +++++++++++++++++++++++ tests/LaravelTest.php | 47 ++++++++++++++++++++++++-- 4 files changed, 152 insertions(+), 3 deletions(-) create mode 100644 src/SparseVector.php create mode 100644 src/laravel/SparseVector.php diff --git a/CHANGELOG.md b/CHANGELOG.md index c534953..dd50c2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ ## 0.2.0 (unreleased) -- Added support for `halfvec` type +- Added support for `halfvec` and `sparsevec` types - Added `L1`, `Hamming`, and `Jaccard` distances - Changed `Distance` to enum - Dropped support for PHP < 8.1 diff --git a/src/SparseVector.php b/src/SparseVector.php new file mode 100644 index 0000000..3d73b7d --- /dev/null +++ b/src/SparseVector.php @@ -0,0 +1,64 @@ +dimensions = $dimensions; + $this->indices = $indices; + $this->values = $values; + } + + public static function fromDense($value) + { + $dimensions = count($value); + $indices = []; + $values = []; + foreach ($value as $i => $v) { + if ($v != 0) { + $indices[] = $i; + $values[] = floatval($v); + } + } + return new SparseVector($dimensions, $indices, $values); + } + + public static function fromString($value) + { + $parts = explode('/', $value, 2); + $dimensions = intval($parts[1]); + $indices = []; + $values = []; + $elements = explode(',', substr($parts[0], 1, -1)); + foreach ($elements as $e) { + $ep = explode(':', $e, 2); + $indices[] = intval($ep[0]) - 1; + $values[] = floatval($ep[1]); + } + return new SparseVector($dimensions, $indices, $values); + } + + public function __toString() + { + $elements = []; + for ($i = 0; $i < count($this->indices); $i++) { + $elements[] = ($this->indices[$i] + 1) . ':' . $this->values[$i]; + } + return '{' . implode(',', $elements) . '}/' . $this->dimensions; + } + + public function toArray() + { + $result = array_fill(0, $this->dimensions, 0.0); + for ($i = 0; $i < count($this->indices); $i++) { + $result[$this->indices[$i]] = $this->values[$i]; + } + return $result; + } +} diff --git a/src/laravel/SparseVector.php b/src/laravel/SparseVector.php new file mode 100644 index 0000000..7c29d25 --- /dev/null +++ b/src/laravel/SparseVector.php @@ -0,0 +1,42 @@ +addConnection([ @@ -36,8 +37,8 @@ class Item extends Model use HasNeighbors; public $timestamps = false; - protected $fillable = ['id', 'embedding', 'half_embedding', 'binary_embedding']; - protected $casts = ['embedding' => Vector::class, 'half_embedding' => HalfVector::class]; + protected $fillable = ['id', 'embedding', 'half_embedding', 'binary_embedding', 'sparse_embedding']; + protected $casts = ['embedding' => Vector::class, 'half_embedding' => HalfVector::class, 'sparse_embedding' => SparseVector::class]; } final class LaravelTest extends TestCase @@ -172,6 +173,48 @@ public function testBitJaccardDistance() $this->assertEqualsWithDelta([0, 1/3, 1], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001); } + public function testSparsevecL2Distance() + { + $this->createItems('sparse_embedding'); + $neighbors = Item::orderByRaw('sparse_embedding <-> ?', [SparseVector::fromDense([1, 1, 1])])->take(5)->get(); + $this->assertEquals([1, 3, 2], $neighbors->pluck('id')->toArray()); + $this->assertEquals([[1, 1, 1], [1, 1, 2], [2, 2, 2]], array_map(fn ($v) => $v->toArray(), $neighbors->pluck('sparse_embedding')->toArray())); + } + + public function testSparsevecScopeL2Distance() + { + $this->createItems('sparse_embedding'); + $neighbors = Item::query()->nearestNeighbors('sparse_embedding', '{1:1,2:1,3:1}/3', Distance::L2)->take(5)->get(); + $this->assertEquals([1, 3, 2], $neighbors->pluck('id')->toArray()); + $this->assertEqualsWithDelta([0, 1, sqrt(3)], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001); + } + + public function testSparsevecScopeMaxInnerProduct() + { + $this->createItems('sparse_embedding'); + $neighbors = Item::query()->nearestNeighbors('sparse_embedding', '{1:1,2:1,3:1}/3', Distance::InnerProduct)->take(5)->get(); + $this->assertEquals([2, 3, 1], $neighbors->pluck('id')->toArray()); + $this->assertEqualsWithDelta([6, 4, 3], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001); + } + + public function testSparsevecInstance() + { + $this->createItems('sparse_embedding'); + $item = Item::find(1); + $neighbors = $item->nearestNeighbors('sparse_embedding', Distance::L2)->take(5)->get(); + $this->assertEquals([3, 2], $neighbors->pluck('id')->toArray()); + $this->assertEqualsWithDelta([1, sqrt(3)], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001); + } + + public function testSparsevecInstanceL1() + { + $this->createItems('sparse_embedding'); + $item = Item::find(1); + $neighbors = $item->nearestNeighbors('sparse_embedding', Distance::L1)->take(5)->get(); + $this->assertEquals([3, 2], $neighbors->pluck('id')->toArray()); + $this->assertEqualsWithDelta([1, 3], $neighbors->pluck('neighbor_distance')->toArray(), 0.00001); + } + public function testMissingAttribute() { $this->expectException(MissingAttributeException::class);