%%%-------------------------------------------------------------------
%%% erlrithmetician - linear algebra library for Erlang
%%% Copyright (C) 2026 E. G. Bland
%%%
%%% This program is free software: you can redistribute it and/or modify
%%% it under the terms of the GNU General Public License as published by
%%% the Free Software Foundation, either version 3 of the License, or
%%% (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%%% GNU General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License
%%% along with this program. If not, see <https://www.gnu.org/licenses/>.
%%%-------------------------------------------------------------------
-module(vec).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-export([new/1, new2/2, new3/3, new4/4, gen/2, dim/1, get/2, x/1, y/1, z/1, w/1, map/2, combine/3, dot/2, mag/1, normalise/1, neg/1, add/2, sub/2, muls/2, mulv/2, equal/2]).
-export_type([vec/1]).
%% Internal representation of a vector.
%%   dim  - number of components; new_unchecked/1 sets it from array:size/1.
%%          It is stored as a field so it can be read in guard expressions
%%          (array:size/1 is not a guard BIF).
%%   data - the components, held in an array that new_unchecked/1 fixes
%%          with array:fix/1.
-record(vec, { dim :: pos_integer(), data :: array:array() }).
%% Exported as opaque: callers must use this module's API rather than
%% depend on the record layout.
-opaque vec(T) :: #vec{ dim :: pos_integer(), data :: array:array(T) }.
%% Function specs are collected here rather than above each definition.
%% Constructors.
-spec new(array:array(T)) -> vec(T).
-spec new2(T, T) -> vec(T).
-spec new3(T, T, T) -> vec(T).
-spec new4(T, T, T, T) -> vec(T).
-spec gen(fun((non_neg_integer()) -> T), pos_integer()) -> vec(T).
%% Accessors (component indices are zero-based).
-spec dim(vec(_)) -> pos_integer().
-spec get(non_neg_integer(), vec(T)) -> T.
-spec x(vec(T)) -> T.
-spec y(vec(T)) -> T.
-spec z(vec(T)) -> T.
-spec w(vec(T)) -> T.
%% Generic element-wise transforms.
-spec map(fun((T) -> U), vec(T)) -> vec(U).
-spec combine(fun((T, U) -> V), vec(T), vec(U)) -> vec(V).
%% Numeric operations.
-spec dot(vec(number()), vec(number())) -> number().
-spec mag(vec(number())) -> float().
-spec normalise(vec(number())) -> vec(number()).
-spec neg(vec(number())) -> vec(number()).
-spec add(vec(number()), vec(number())) -> vec(number()).
-spec sub(vec(number()), vec(number())) -> vec(number()).
-spec muls(number(), vec(number())) -> vec(number()).
-spec mulv(vec(number()), vec(number())) -> vec(number()).
%% Comparison.
-spec equal(vec(T), vec(T)) -> boolean().
%% Internal helpers (not exported).
-spec new_unchecked(array:array(T)) -> vec(T).
-spec data(vec(T)) -> array:array(T).
%% Error raisers: each calls error/1 and never returns (hence none()).
-spec err_zero_dim_vec() -> none().
-spec err_idx_out_of_bounds(non_neg_integer(), pos_integer()) -> none().
-spec err_dim_mismatch(pos_integer(), pos_integer()) -> none().
%% Dimension accessor as a macro rather than a function so it can be used
%% inside guard expressions (e.g. the dimension check in dot/2).
-define(dim(V), ((V)#vec.dim)).
%% @doc Build a vector from an array of components.
%% An empty array raises `{zero_dim}', since a vector needs at least one
%% dimension.
new(Data) ->
    case array:size(Data) > 0 of
        true -> new_unchecked(Data);
        false -> err_zero_dim_vec()
    end.
%% @doc Build the two-dimensional vector (X, Y).
new2(X, Y) ->
    Components = array:from_list([X, Y]),
    new_unchecked(Components).
%% @doc Build the three-dimensional vector (X, Y, Z).
new3(X, Y, Z) ->
    Components = array:from_list([X, Y, Z]),
    new_unchecked(Components).
%% @doc Build the four-dimensional vector (X, Y, Z, W).
new4(X, Y, Z, W) ->
    Components = array:from_list([X, Y, Z, W]),
    new_unchecked(Components).
%% @doc Generate an N-dimensional vector whose component at zero-based
%% index I is F(I).
%%
%% N must be a positive integer. N = 0 raises `{zero_dim}', the same error
%% new/1 raises for an empty array, so gen/2 can never produce a
%% zero-dimensional vector. A negative or non-integer N fails with
%% `function_clause'.
gen(_F, 0) ->
    err_zero_dim_vec();
gen(F, N) when is_integer(N), N > 0 ->
    Data = array:from_list([F(Idx) || Idx <- lists:seq(0, N - 1)]),
    new_unchecked(Data).
%% @doc Number of components in the vector.
dim(Vec) ->
    Vec#vec.dim.
%% @doc Return the component of V at zero-based index Idx.
%%
%% Any integer index outside 0..dim(V)-1, including a negative one, raises
%% `{index_out_of_bounds, {index, Idx}, {dim, Dim}}'. The explicit lower
%% bound stops negative indices from reaching array:get/2, where they would
%% surface as a bare `badarg'. A non-integer index fails with
%% `function_clause'.
get(Idx, V) when is_integer(Idx), Idx >= 0, Idx < ?dim(V) ->
    array:get(Idx, data(V));
get(Idx, V) when is_integer(Idx) ->
    err_idx_out_of_bounds(Idx, ?dim(V)).
%% @doc First component (index 0); raises index_out_of_bounds via get/2 if
%% absent.
x(Vec) ->
    get(0, Vec).
%% @doc Second component (index 1); raises index_out_of_bounds via get/2 if
%% absent.
y(Vec) ->
    get(1, Vec).
%% @doc Third component (index 2); raises index_out_of_bounds via get/2 if
%% absent.
z(Vec) ->
    get(2, Vec).
%% @doc Fourth component (index 3); raises index_out_of_bounds via get/2 if
%% absent.
w(Vec) ->
    get(3, Vec).
%% @doc Apply F to every component of V and return the resulting vector.
%% The result has the same dimension as V.
map(F, V) ->
    % array:map/2 passes (Index, Value); only the value is relevant here.
    ApplyToValue = fun(_Idx, Value) -> F(Value) end,
    new_unchecked(array:map(ApplyToValue, data(V))).
%% @doc Combine U and V component-wise: the result's component I is
%% F(get(I, U), get(I, V)).
%%
%% U and V must have the same dimension; otherwise this raises
%% `{dim_mismatch, {expected, dim(U)}, {actual, dim(V)}}'. This matches the
%% check in dot/2, so add/2, sub/2 and mulv/2, which are built on combine/3,
%% reject mismatched vectors instead of depending on how
%% helpers:zip_arrays/2 handles arrays of different lengths.
combine(F, U, V) when ?dim(U) =:= ?dim(V) ->
    Pairs = helpers:zip_arrays(data(U), data(V)),
    NewData = array:map(fun(_, { X, Y }) -> F(X, Y) end, Pairs),
    new_unchecked(NewData);
combine(_F, U, V) ->
    err_dim_mismatch(dim(U), dim(V)).
%% @doc Dot product of U and V: the sum of the pairwise component products,
%% accumulated from index 0 upwards starting at integer 0.
%% Vectors of different dimension raise
%% `{dim_mismatch, {expected, dim(U)}, {actual, dim(V)}}'.
dot(U, V) when ?dim(U) =:= ?dim(V) ->
    AddProduct = fun(_Idx, { A, B }, Total) -> Total + A * B end,
    array:foldl(AddProduct, 0, helpers:zip_arrays(data(U), data(V)));
dot(U, V) ->
    err_dim_mismatch(dim(U), dim(V)).
%% @doc Euclidean length of V, sqrt(V . V); always a float.
mag(V) ->
    SquaredLength = dot(V, V),
    math:sqrt(SquaredLength).
%% @doc Scale V to unit length by dividing every component by mag(V).
%% A zero vector has magnitude 0.0, so the division raises `badarith'.
normalise(V) ->
    Length = mag(V),
    map(fun(Component) -> Component / Length end, V).
%% @doc Negate every component of V.
neg(V) ->
    map(fun erlang:'-'/1, V).
%% @doc Component-wise sum U + V.
add(U, V) ->
    combine(fun erlang:'+'/2, U, V).
%% @doc Component-wise difference U - V.
sub(U, V) ->
    combine(fun erlang:'-'/2, U, V).
%% @doc Multiply every component of V by the scalar L.
muls(L, V) ->
    Scale = fun(Component) -> L * Component end,
    map(Scale, V).
%% @doc Component-wise (Hadamard) product of U and V.
mulv(U, V) ->
    combine(fun erlang:'*'/2, U, V).
%% @doc True when U and V have the same dimension and every pair of
%% corresponding components is exactly equal (`=:='), so 1 and 1.0 differ.
equal(U, V) ->
    dim(U) =:= dim(V) andalso
        array:to_list(data(U)) =:= array:to_list(data(V)).
%% @doc Wrap Data in a vec without checking that it is non-empty.
%% The stored array is fixed so its size cannot grow later.
new_unchecked(Data) ->
    Fixed = array:fix(Data),
    #vec{ dim = array:size(Data), data = Fixed }.
%% @doc The underlying component array of a vector.
data(Vec) ->
    Vec#vec.data.
%% @doc Raise the error used when a vector would have no components.
err_zero_dim_vec() ->
    Reason = { zero_dim },
    error(Reason).
%% @doc Raise the error for an index that is not a valid component position
%% in a vector of dimension Dim.
err_idx_out_of_bounds(Idx, Dim) ->
    IndexInfo = { index, Idx },
    DimInfo = { dim, Dim },
    error({ index_out_of_bounds, IndexInfo, DimInfo }).
%% @doc Raise the error for two vectors whose dimensions differ.
err_dim_mismatch(DimExpected, DimActual) ->
    ExpectedInfo = { expected, DimExpected },
    ActualInfo = { actual, DimActual },
    error({ dim_mismatch, ExpectedInfo, ActualInfo }).
%% EUnit tests, compiled only when the TEST macro is defined.
-ifdef(TEST).
new2_test() ->
    V = new2(2, 3),
    ?assertEqual(2, dim(V)),
    ?assertEqual(2, x(V)),
    ?assertEqual(3, y(V)),
    ?assertError({ index_out_of_bounds, { index, 2 }, { dim, 2 } }, z(V)),
    ?assertError({ index_out_of_bounds, { index, 3 }, { dim, 2 } }, w(V)).
new3_test() ->
    V = new3(5, 7, 11),
    ?assertEqual(3, dim(V)),
    ?assertEqual(5, x(V)),
    ?assertEqual(7, y(V)),
    ?assertEqual(11, z(V)),
    ?assertError({ index_out_of_bounds, { index, 3 }, { dim, 3 } }, w(V)).
new4_test() ->
    V = new4(13, 17, 19, 23),
    ?assertEqual(4, dim(V)),
    ?assertEqual(13, x(V)),
    ?assertEqual(17, y(V)),
    ?assertEqual(19, z(V)),
    ?assertEqual(23, w(V)).
new_empty_fails_test() ->
    ?assertError({ zero_dim }, new(array:from_list([]))).
new_dim_xyzw_get_test() ->
    V = new(array:from_list([29, 31, 37, 41, 43])),
    ?assertEqual(5, dim(V)),
    ?assertEqual(29, x(V)),
    ?assertEqual(29, get(0, V)),
    ?assertEqual(31, y(V)),
    ?assertEqual(31, get(1, V)),
    ?assertEqual(37, z(V)),
    ?assertEqual(37, get(2, V)),
    ?assertEqual(41, w(V)),
    ?assertEqual(41, get(3, V)),
    ?assertEqual(43, get(4, V)),
    ?assertError({ index_out_of_bounds, { index, 5 }, { dim, 5 } }, get(5, V)).
map_test() ->
    V = new3(47, 53, 59),
    F = fun(X) -> 2 * X - 1 end,
    VMappedExpected = new3(93, 105, 117),
    VMappedActual = map(F, V),
    ?assert(equal(VMappedExpected, VMappedActual)).
combine_test() ->
    U = new3(61, 67, 71),
    V = new3(73, 79, 83),
    F = fun(X, Y) -> 2 * X + 3 * Y end,
    CombinedExpected = new3(341, 371, 391),
    CombinedActual = combine(F, U, V),
    ?assert(equal(CombinedExpected, CombinedActual)).
neg_test() ->
    V = new3(-89, 97, -101),
    VNegExpected = new3(89, -97, 101),
    VNegActual = neg(V),
    ?assert(equal(VNegExpected, VNegActual)).
add_test() ->
    U = new3(103, 107, 109),
    V = new3(113, 127, 131),
    SumExpected = new3(216, 234, 240),
    SumActual = add(U, V),
    ?assert(equal(SumExpected, SumActual)).
sub_test() ->
    U = new3(151, 139, 163),
    V = new3(137, 157, 149),
    DiffExpected = new3(14, -18, 14),
    DiffActual = sub(U, V),
    ?assert(equal(DiffExpected, DiffActual)).
muls_test() ->
    V = new3(167, 173, 179),
    L = 181,
    ProdExpected = new3(30227, 31313, 32399),
    ProdActual = muls(L, V),
    ?assert(equal(ProdExpected, ProdActual)).
mulv_test() ->
    U = new3(191, 193, 197),
    V = new3(199, 211, 223),
    ProdExpected = new3(38009, 40723, 43931),
    ProdActual = mulv(U, V),
    ?assert(equal(ProdExpected, ProdActual)).
%% Coverage for gen/2, dot/2, mag/1, normalise/1 and equal/2.
gen_test() ->
    V = gen(fun(I) -> I * I end, 4),
    ?assertEqual(4, dim(V)),
    ?assert(equal(new4(0, 1, 4, 9), V)).
dot_test() ->
    U = new3(1, 2, 3),
    V = new3(4, 5, 6),
    ?assertEqual(32, dot(U, V)),
    ?assertEqual(32, dot(V, U)).
dot_dim_mismatch_test() ->
    ?assertError({ dim_mismatch, { expected, 2 }, { actual, 3 } },
                 dot(new2(1, 2), new3(1, 2, 3))).
mag_test() ->
    ?assertEqual(5.0, mag(new2(3, 4))),
    ?assertEqual(13.0, mag(new3(3, 4, 12))).
normalise_test() ->
    % 3 / 5.0 and 4 / 5.0 round to exactly the doubles 0.6 and 0.8.
    VNormExpected = new2(0.6, 0.8),
    VNormActual = normalise(new2(3, 4)),
    ?assert(equal(VNormExpected, VNormActual)).
equal_test() ->
    ?assert(equal(new2(1, 2), new2(1, 2))),
    ?assertNot(equal(new2(1, 2), new2(2, 1))),
    ?assertNot(equal(new2(1, 2), new3(1, 2, 0))),
    % equal/2 uses exact comparison, so integer and float components differ.
    ?assertNot(equal(new2(1, 2), new2(1.0, 2.0))).
-endif.