Skip to content

Commit

Permalink
Moved utils to separate module
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed May 19, 2024
1 parent 7acad90 commit a717213
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 26 deletions.
20 changes: 0 additions & 20 deletions lib/pgvector.ex
Original file line number Diff line number Diff line change
Expand Up @@ -129,26 +129,6 @@ defmodule Pgvector do
Pgvector.Extensions.Sparsevec
]
end

# TODO move / improve pattern
@doc false
def to_sql(vector) when is_struct(vector, Pgvector.HalfVector) do
vector
end

def to_sql(vector) when is_struct(vector, Pgvector.SparseVector) do
vector
end

def to_sql(vector) do
vector |> Pgvector.new()
end

# TODO move / improve pattern
@doc false
def to_bit_sql(vector) when is_bitstring(vector) do
vector
end
end

defimpl Inspect, for: Pgvector do
Expand Down
12 changes: 6 additions & 6 deletions lib/pgvector/ecto/query.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro l2_distance(column, value) do
quote do
fragment("(? <-> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
fragment("(? <-> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
end
end

Expand All @@ -18,7 +18,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro max_inner_product(column, value) do
quote do
fragment("(? <#> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
fragment("(? <#> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
end
end

Expand All @@ -27,7 +27,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro cosine_distance(column, value) do
quote do
fragment("(? <=> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
fragment("(? <=> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
end
end

Expand All @@ -36,7 +36,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro l1_distance(column, value) do
quote do
fragment("(? <+> ?)", unquote(column), ^Pgvector.to_sql(unquote(value)))
fragment("(? <+> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_sql(unquote(value)))
end
end

Expand All @@ -45,7 +45,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro hamming_distance(column, value) do
quote do
fragment("(? <~> ?)", unquote(column), ^Pgvector.to_bit_sql(unquote(value)))
fragment("(? <~> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_bit_sql(unquote(value)))
end
end

Expand All @@ -54,7 +54,7 @@ if Code.ensure_loaded?(Ecto) do
"""
defmacro jaccard_distance(column, value) do
quote do
fragment("(? <%> ?)", unquote(column), ^Pgvector.to_bit_sql(unquote(value)))
fragment("(? <%> ?)", unquote(column), ^Pgvector.Ecto.Utils.to_bit_sql(unquote(value)))
end
end
end
Expand Down
20 changes: 20 additions & 0 deletions lib/pgvector/ecto/utils.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# TODO improve pattern
defmodule Pgvector.Ecto.Utils do
@moduledoc false

def to_sql(vector) when is_struct(vector, Pgvector.HalfVector) do
vector
end

def to_sql(vector) when is_struct(vector, Pgvector.SparseVector) do
vector
end

def to_sql(vector) do
vector |> Pgvector.new()
end

def to_bit_sql(vector) when is_bitstring(vector) do
vector
end
end

0 comments on commit a717213

Please sign in to comment.