Skip to main content

src/vec.erl

%%%-------------------------------------------------------------------
%%% erlrithmetician - linear algebra library for Erlang
%%% Copyright (C) 2026 E. G. Bland
%%%
%%% This program is free software: you can redistribute it and/or modify
%%% it under the terms of the GNU General Public License as published by
%%% the Free Software Foundation, either version 3 of the License, or
%%% (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%%% GNU General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License
%%% along with this program.  If not, see <https://www.gnu.org/licenses/>.
%%%-------------------------------------------------------------------

-module(vec).

-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.

%% Construction
-export([new/1, new2/2, new3/3, new4/4, gen/2]).
%% Element access
-export([dim/1, get/2, x/1, y/1, z/1, w/1]).
%% Element-wise transformation
-export([map/2, combine/3]).
%% Arithmetic and comparison
-export([dot/2, mag/1, normalise/1, neg/1, add/2, sub/2, muls/2, mulv/2, equal/2]).
-export_type([vec/1]).

%% Internal representation of a vector: `dim' caches the number of
%% components and `data' holds the components in an array.
%% Callers must treat vec(T) as opaque and use the accessors above.
-record(vec, { dim :: pos_integer(), data :: array:array() }).
-opaque vec(T) :: #vec{ dim :: pos_integer(), data :: array:array(T) }.

%% ---- Public API specs ----
%% Construction: new/1 rejects empty arrays; new2..new4 build fixed-size vectors.
-spec new(array:array(T)) -> vec(T).
-spec new2(T, T) -> vec(T).
-spec new3(T, T, T) -> vec(T).
-spec new4(T, T, T, T) -> vec(T).
-spec gen(fun((non_neg_integer()) -> T), pos_integer()) -> vec(T).
%% Element access (indices are zero-based).
-spec dim(vec(_)) -> pos_integer().
-spec get(non_neg_integer(), vec(T)) -> T.
-spec x(vec(T)) -> T.
-spec y(vec(T)) -> T.
-spec z(vec(T)) -> T.
-spec w(vec(T)) -> T.
%% Element-wise transformation.
-spec map(fun((T) -> U), vec(T)) -> vec(U).
-spec combine(fun((T, U) -> V), vec(T), vec(U)) -> vec(V).
%% Numeric operations.
-spec dot(vec(number()), vec(number())) -> number().
-spec mag(vec(number())) -> float().
-spec normalise(vec(number())) -> vec(number()).
-spec neg(vec(number())) -> vec(number()).
-spec add(vec(number()), vec(number())) -> vec(number()).
-spec sub(vec(number()), vec(number())) -> vec(number()).
-spec muls(number(), vec(number())) -> vec(number()).
-spec mulv(vec(number()), vec(number())) -> vec(number()).
%% Exact (=:=) component-wise equality.
-spec equal(vec(T), vec(T)) -> boolean().

%% ---- Internal helper specs (not exported) ----
-spec new_unchecked(array:array(T)) -> vec(T).
-spec data(vec(T)) -> array:array(T).
%% The err_* helpers never return; they raise via error/1.
-spec err_zero_dim_vec() -> none().
-spec err_idx_out_of_bounds(non_neg_integer(), pos_integer()) -> none().
-spec err_dim_mismatch(pos_integer(), pos_integer()) -> none().

%% Reads a vector's dimension field. This is a macro rather than a function
%% call because it is used inside guards (as in dot/2 and get/2), where
%% ordinary function calls are not allowed. The extra parentheses keep the
%% expansion safe when V is an arbitrary expression.
-define(dim(V), ((V)#vec.dim)).

%% Wraps the array Data as a vector.
%% Raises {zero_dim} when Data has no elements, because a vector must have
%% at least one component.
new(Data) ->
  IsEmpty = array:size(Data) =:= 0,
  case IsEmpty of
    true -> err_zero_dim_vec();
    false -> new_unchecked(Data)
  end.

%% Builds the two-component vector (X, Y).
new2(X, Y) ->
  Components = [X, Y],
  new_unchecked(array:from_list(Components)).

%% Builds the three-component vector (X, Y, Z).
new3(X, Y, Z) ->
  Components = [X, Y, Z],
  new_unchecked(array:from_list(Components)).

%% Builds the four-component vector (X, Y, Z, W).
new4(X, Y, Z, W) ->
  Components = [X, Y, Z, W],
  new_unchecked(array:from_list(Components)).

%% Builds an N-component vector whose component at index I is F(I), for
%% I in 0..N-1. F is called once per index, in ascending order.
%% Construction goes through new/1 so that N = 0 raises {zero_dim} instead
%% of silently producing a zero-dimensional vector. A negative N makes
%% lists:seq/2 fail.
gen(F, N) ->
  Components = [F(Idx) || Idx <- lists:seq(0, N - 1)],
  new(array:from_list(Components)).

%% Returns the number of components of Vec.
dim(Vec) ->
  Vec#vec.dim.

%% Returns the component of V at zero-based index Idx.
%% Raises {index_out_of_bounds, {index, Idx}, {dim, Dim}} for any index
%% outside 0..Dim-1. Negative and non-integer indices are included; before,
%% they fell through to a bare badarg from array:get/2.
get(Idx, V) when is_integer(Idx), Idx >= 0, Idx < ?dim(V) ->
  array:get(Idx, data(V));
get(Idx, V) ->
  err_idx_out_of_bounds(Idx, dim(V)).

%% Named accessors for the first four components (indices 0, 1, 2, 3).
%% Each delegates to get/2, so it raises index_out_of_bounds when Vec has
%% too few components.
x(Vec) ->
  get(0, Vec).

y(Vec) ->
  get(1, Vec).

z(Vec) ->
  get(2, Vec).

w(Vec) ->
  get(3, Vec).

%% Returns a new vector with F applied to every component of V.
map(F, V) ->
  % array:map/2 passes (Index, Value); only the value is needed here.
  ApplyToElem = fun(_Idx, Elem) -> F(Elem) end,
  new_unchecked(array:map(ApplyToElem, data(V))).

%% Combines U and V component-wise with F, giving F(Ui, Vi) at each index.
%% Pairs come from helpers:zip_arrays/2, presumably index-aligned.
%% U and V must have the same dimension. Otherwise this raises
%% {dim_mismatch, {expected, dim(U)}, {actual, dim(V)}}, the same error
%% dot/2 uses. The check is made here rather than relying on how
%% helpers:zip_arrays/2 treats unequal sizes, which is not visible from
%% this module.
combine(F, U, V) when ?dim(U) =:= ?dim(V) ->
  Pairs = helpers:zip_arrays(data(U), data(V)),
  NewData = array:map(fun(_, { X, Y }) -> F(X, Y) end, Pairs),
  new_unchecked(NewData);
combine(_F, U, V) ->
  err_dim_mismatch(dim(U), dim(V)).

%% Dot product of U and V: the sum of the pairwise products of their
%% components, accumulated from index 0 upwards, starting at 0.
%% Raises {dim_mismatch, {expected, dim(U)}, {actual, dim(V)}} when the
%% dimensions differ.
dot(U, V) when ?dim(U) =:= ?dim(V) ->
  MulAdd = fun(_Idx, { A, B }, Sum) -> Sum + A * B end,
  array:foldl(MulAdd, 0, helpers:zip_arrays(data(U), data(V)));
dot(U, V) ->
  err_dim_mismatch(dim(U), dim(V)).

%% Euclidean length of V: the square root of V dotted with itself.
%% Always a float, because math:sqrt/1 returns one.
mag(V) ->
  SquaredLength = dot(V, V),
  math:sqrt(SquaredLength).

%% Returns V scaled to unit length by dividing each component by mag(V).
%% A zero vector makes the division fail with badarith.
normalise(V) ->
  Length = mag(V),
  DivideByLength = fun(Component) -> Component / Length end,
  map(DivideByLength, V).

%% Returns V with every component negated.
neg(V) ->
  Negate = fun(Component) -> -Component end,
  map(Negate, V).

%% Component-wise sum U + V.
add(U, V) ->
  combine(fun erlang:'+'/2, U, V).

%% Component-wise difference U - V.
sub(U, V) ->
  combine(fun erlang:'-'/2, U, V).

%% Scalar multiplication: every component of V multiplied by L.
muls(L, V) ->
  ScaleBy = fun(Component) -> L * Component end,
  map(ScaleBy, V).

%% Component-wise (Hadamard) product of U and V.
mulv(U, V) ->
  combine(fun erlang:'*'/2, U, V).

%% True when U and V have the same dimension and every pair of
%% corresponding components is exactly equal (=:=). For example, 1 and 1.0
%% count as different.
equal(U, V) ->
  SameDim = dim(U) =:= dim(V),
  SameDim andalso array:to_list(data(U)) =:= array:to_list(data(V)).

%% Builds a vector from Data without checking that it is non-empty.
%% The array is fixed (array:fix/1) so its size cannot grow afterwards,
%% and the dimension is cached from that size.
new_unchecked(Data) ->
  Fixed = array:fix(Data),
  #vec{ dim = array:size(Fixed), data = Fixed }.

%% Returns the array backing Vec (internal accessor).
data(Vec) ->
  Vec#vec.data.

%% Raises the error used when a vector would have no components.
err_zero_dim_vec() ->
  Reason = { zero_dim },
  error(Reason).

%% Raises the error for an index Idx that is out of range for a vector of
%% dimension Dim.
err_idx_out_of_bounds(Idx, Dim) ->
  Reason = { index_out_of_bounds, { index, Idx }, { dim, Dim } },
  error(Reason).

%% Raises the error for two vectors whose dimensions differ: DimExpected
%% comes from the first operand and DimActual from the second.
err_dim_mismatch(DimExpected, DimActual) ->
  Reason = { dim_mismatch, { expected, DimExpected }, { actual, DimActual } },
  error(Reason).

%% EUnit tests; compiled only when the TEST macro is defined.
-ifdef(TEST).

new2_test() ->
  V = new2(2, 3),
  ?assertEqual(2, dim(V)),
  ?assertEqual(2, x(V)),
  ?assertEqual(3, y(V)),
  ?assertError({ index_out_of_bounds, { index, 2 }, { dim, 2 } }, z(V)),
  ?assertError({ index_out_of_bounds, { index, 3 }, { dim, 2 } }, w(V)).

new3_test() ->
  V = new3(5, 7, 11),
  ?assertEqual(3, dim(V)),
  ?assertEqual(5, x(V)),
  ?assertEqual(7, y(V)),
  ?assertEqual(11, z(V)),
  ?assertError({ index_out_of_bounds, { index, 3 }, { dim, 3 } }, w(V)).

new4_test() ->
  V = new4(13, 17, 19, 23),
  ?assertEqual(4, dim(V)),
  ?assertEqual(13, x(V)),
  ?assertEqual(17, y(V)),
  ?assertEqual(19, z(V)),
  ?assertEqual(23, w(V)).

new_empty_fails_test() ->
  ?assertError({ zero_dim }, new(array:from_list([]))).

new_dim_xyzw_get_test() ->
  V = new(array:from_list([29, 31, 37, 41, 43])),
  ?assertEqual(5, dim(V)),
  ?assertEqual(29, x(V)),
  ?assertEqual(29, get(0, V)),
  ?assertEqual(31, y(V)),
  ?assertEqual(31, get(1, V)),
  ?assertEqual(37, z(V)),
  ?assertEqual(37, get(2, V)),
  ?assertEqual(41, w(V)),
  ?assertEqual(41, get(3, V)),
  ?assertEqual(43, get(4, V)),
  ?assertError({ index_out_of_bounds, { index, 5 }, { dim, 5 } }, get(5, V)).

gen_test() ->
  V = gen(fun(I) -> I * I end, 4),
  ?assertEqual(4, dim(V)),
  ?assert(equal(new4(0, 1, 4, 9), V)).

map_test() ->
  V = new3(47, 53, 59),
  F = fun(X) -> 2 * X - 1 end,
  VMappedExpected = new3(93, 105, 117),
  VMappedActual = map(F, V),
  ?assert(equal(VMappedExpected, VMappedActual)).

combine_test() ->
  U = new3(61, 67, 71),
  V = new3(73, 79, 83),
  F = fun(X, Y) -> 2 * X + 3 * Y end,
  CombinedExpected = new3(341, 371, 391),
  CombinedActual = combine(F, U, V),
  ?assert(equal(CombinedExpected, CombinedActual)).

neg_test() ->
  V = new3(-89, 97, -101),
  VNegExpected = new3(89, -97, 101),
  VNegActual = neg(V),
  ?assert(equal(VNegExpected, VNegActual)).

add_test() ->
  U = new3(103, 107, 109),
  V = new3(113, 127, 131),
  SumExpected = new3(216, 234, 240),
  SumActual = add(U, V),
  ?assert(equal(SumExpected, SumActual)).

sub_test() ->
  U = new3(151, 139, 163),
  V = new3(137, 157, 149),
  DiffExpected = new3(14, -18, 14),
  DiffActual = sub(U, V),
  ?assert(equal(DiffExpected, DiffActual)).

muls_test() ->
  V = new3(167, 173, 179),
  L = 181,
  ProdExpected = new3(30227, 31313, 32399),
  ProdActual = muls(L, V),
  ?assert(equal(ProdExpected, ProdActual)).

mulv_test() ->
  U = new3(191, 193, 197),
  V = new3(199, 211, 223),
  ProdExpected = new3(38009, 40723, 43931),
  ProdActual = mulv(U, V),
  ?assert(equal(ProdExpected, ProdActual)).

dot_test() ->
  ?assertEqual(32, dot(new3(1, 2, 3), new3(4, 5, 6))).

dot_dim_mismatch_test() ->
  ?assertError({ dim_mismatch, { expected, 2 }, { actual, 3 } }, dot(new2(1, 2), new3(4, 5, 6))).

mag_test() ->
  ?assertEqual(5.0, mag(new2(3, 4))).

normalise_test() ->
  % 3 / 5.0 and 4 / 5.0 round to exactly the same floats as the literals.
  ?assert(equal(new2(0.6, 0.8), normalise(new2(3, 4)))).

equal_test() ->
  ?assert(equal(new2(1, 2), new2(1, 2))),
  ?assertNot(equal(new2(1, 2), new2(1, 3))),
  ?assertNot(equal(new2(1, 2), new3(1, 2, 3))),
  % equal/2 uses =:=, so an integer and a float component are different.
  ?assertNot(equal(new2(1, 2), new2(1.0, 2))).

-endif.