lib/pgvector/extensions/vector.ex

if Code.ensure_loaded?(Postgrex) do
  defmodule Pgvector.Extensions.Vector do
    import Postgrex.BinaryUtils, warn: false

    def init(opts), do: Keyword.get(opts, :decode_binary, :copy)

    def matching(_), do: [type: "vector"]

    def format(_), do: :binary

    def encode(_) do
      quote do
        vec when is_list(vec) ->
          data = unquote(__MODULE__).encode_vector(vec)
          [<<IO.iodata_length(data)::int32()>> | data]
      end
    end

    def decode(_) do
      quote do
        <<len::int32(), data::binary-size(len)>> ->
          unquote(__MODULE__).decode_vector(data)
      end
    end

    def encode_vector(vec) do
      dim = length(vec)
      bin = Enum.map(vec, fn v -> <<v::float32>> end)
      [<<dim::uint16>>, <<0::uint16>> | bin]
    end

    def decode_vector(<<dim::uint16, 0::uint16, bin::binary-size(dim)-unit(32)>>) do
      for <<v::float32 <- bin>>, do: v
    end
  end
end