%%%-------------------------------------------------------------------
%%% erlrithmetician - linear algebra library for Erlang
%%% Copyright (C) 2026 E. G. Bland
%%%
%%% This program is free software: you can redistribute it and/or modify
%%% it under the terms of the GNU General Public License as published by
%%% the Free Software Foundation, either version 3 of the License, or
%%% (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
%%% GNU General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License
%%% along with this program. If not, see <https://www.gnu.org/licenses/>.
%%%-------------------------------------------------------------------
-module(mat).
-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.
-export([new/1, gen/3, identity/1, get/3, get_col/2, dim_cols/1, dim_rows/1, v0/1, x0/1, y0/1, z0/1, w0/1, v1/1, x1/1, y1/1, z1/1, w1/1, v2/1, x2/1, y2/1, z2/1, w2/1, v3/1, x3/1, y3/1, z3/1, w3/1, map/2, combine/3, neg/1, add/2, sub/2, muls/2, mulv/2, mulm/2, equal/2]).
-export_type([mat/1]).
-record(mat, { dim_rows :: pos_integer(), dim_cols :: pos_integer(), cols :: array:array() }).
-opaque mat(T) :: #mat{ dim_rows :: pos_integer(), dim_cols :: pos_integer(), cols :: array:array(vec:vec(T)) }.
-spec new(array:array(vec:vec(T))) -> mat(T).
-spec gen(fun((non_neg_integer(), non_neg_integer()) -> T), pos_integer(), pos_integer()) -> mat(T).
-spec identity(pos_integer()) -> mat:mat(number()).
-spec get(non_neg_integer(), non_neg_integer(), mat(T)) -> T.
-spec get_col(non_neg_integer(), mat(T)) -> vec:vec(T).
-spec dim_rows(mat(_)) -> pos_integer().
-spec dim_cols(mat(_)) -> pos_integer().
-spec v0(mat(T)) -> vec:vec(T).
-spec v1(mat(T)) -> vec:vec(T).
-spec v2(mat(T)) -> vec:vec(T).
-spec v3(mat(T)) -> vec:vec(T).
-spec x0(mat(T)) -> T.
-spec y0(mat(T)) -> T.
-spec z0(mat(T)) -> T.
-spec w0(mat(T)) -> T.
-spec x1(mat(T)) -> T.
-spec y1(mat(T)) -> T.
-spec z1(mat(T)) -> T.
-spec w1(mat(T)) -> T.
-spec x2(mat(T)) -> T.
-spec y2(mat(T)) -> T.
-spec z2(mat(T)) -> T.
-spec w2(mat(T)) -> T.
-spec x3(mat(T)) -> T.
-spec y3(mat(T)) -> T.
-spec z3(mat(T)) -> T.
-spec w3(mat(T)) -> T.
-spec map(fun((T) -> U), mat(T)) -> mat(U).
-spec combine(fun((T, U) -> V), mat(T), mat(U)) -> mat(V).
-spec neg(mat(number())) -> mat(number()).
-spec add(mat(number()), mat(number())) -> mat(number()).
-spec sub(mat(number()), mat(number())) -> mat(number()).
-spec muls(mat(number()), number()) -> mat(number()).
-spec mulv(mat(number()), vec:vec(number())) -> vec:vec(number()).
-spec mulm(mat(number()), mat(number())) -> mat(number()).
-spec equal(mat(T), mat(T)) -> boolean().
-spec cols(mat(T)) -> array:array(vec:vec(T)).
-spec err_zero_cols() -> none().
-spec err_jagged_mat(non_neg_integer(), non_neg_integer(), pos_integer()) -> none().
-spec err_col_idx_out_of_bounds(non_neg_integer(), pos_integer()) -> none().
-spec err_row_idx_out_of_bounds(non_neg_integer(), pos_integer()) -> none().
-define(cols(M), ((M)#mat.cols)).
-define(dim_cols(M), ((M)#mat.dim_cols)).
-define(dim_rows(M), ((M)#mat.dim_rows)).
%% Build a matrix from a non-empty array of column vectors.
%% The row count is taken from column 0 and every column must match it.
%% Raises {zero_cols} for an empty array, and
%% {jagged, {col_idx, _}, {rows_in_col, _}, {expected_rows, _}} for the first
%% column (in index order) whose dimension differs. The stored array is fixed.
new(Cols) ->
    new_checked(array:size(Cols), Cols).

%% Internal: validate Cols, whose size is DimCols, and wrap it in a #mat{}.
-spec new_checked(non_neg_integer(), array:array(vec:vec(T))) -> mat(T).
new_checked(0, _Cols) ->
    err_zero_cols();
new_checked(DimCols, Cols) ->
    DimRows = vec:dim(array:get(0, Cols)),
    CheckCol =
        fun(ColIdx, Col, Acc) ->
            case vec:dim(Col) of
                DimRows -> Acc;
                Actual -> err_jagged_mat(ColIdx, Actual, DimRows)
            end
        end,
    ok = array:foldl(CheckCol, ok, Cols),
    #mat{dim_rows = DimRows, dim_cols = DimCols, cols = array:fix(Cols)}.
%% Generate a matrix with M rows and N columns whose element at the 0-based
%% position (RowIdx, ColIdx) is F(RowIdx, ColIdx). Each column is built with
%% vec:gen/2 over M rows, in column order, and the result is validated by new/1.
gen(F, M, N) ->
    ColumnAt = fun(ColIdx) -> vec:gen(fun(RowIdx) -> F(RowIdx, ColIdx) end, M) end,
    new(array:from_list([ColumnAt(C) || C <- lists:seq(0, N - 1)])).
%% N x N identity matrix: 1 on the diagonal (row index equals column index),
%% 0 everywhere else.
identity(N) ->
    gen(fun(Same, Same) -> 1; (_, _) -> 0 end, N, N).
%% Return the element of M at row RowIdx, column ColIdx (both 0-based).
%% The row index is checked here: anything outside 0..dim_rows(M)-1 raises
%% {row_index_out_of_bounds, {index, RowIdx}, {dim, DimRows}}, where DimRows is
%% the matrix's row count. The column index is checked by get_col/2.
get(RowIdx, ColIdx, M) when RowIdx >= 0, RowIdx < ?dim_rows(M) ->
    vec:get(RowIdx, get_col(ColIdx, M));
get(RowIdx, _ColIdx, M) ->
    %% Report the row dimension: the index being rejected is a row index.
    err_row_idx_out_of_bounds(RowIdx, dim_rows(M)).
%% Return column Idx (0-based) of M as a vector.
%% Any index outside 0..dim_cols(M)-1 raises
%% {col_index_out_of_bounds, {index, Idx}, {dim, DimCols}}. Negative indices
%% are rejected here too, so they do not reach array:get/2 (which would only
%% raise a bare badarg).
get_col(Idx, M) when Idx >= 0, Idx < ?dim_cols(M) ->
    array:get(Idx, ?cols(M));
get_col(Idx, M) ->
    err_col_idx_out_of_bounds(Idx, dim_cols(M)).
%% Column shortcuts: vN(M) is column N (0-based) of M, for N in 0..3.
v0(Mat) -> get_col(0, Mat).
v1(Mat) -> get_col(1, Mat).
v2(Mat) -> get_col(2, Mat).
v3(Mat) -> get_col(3, Mat).
%% Element shortcuts: the letter selects the row (x = 0, y = 1, z = 2, w = 3)
%% and the digit selects the column, e.g. z1(M) is get(2, 1, M).
%% Row x:
x0(Mat) -> get(0, 0, Mat).
x1(Mat) -> get(0, 1, Mat).
x2(Mat) -> get(0, 2, Mat).
x3(Mat) -> get(0, 3, Mat).
%% Row y:
y0(Mat) -> get(1, 0, Mat).
y1(Mat) -> get(1, 1, Mat).
y2(Mat) -> get(1, 2, Mat).
y3(Mat) -> get(1, 3, Mat).
%% Row z:
z0(Mat) -> get(2, 0, Mat).
z1(Mat) -> get(2, 1, Mat).
z2(Mat) -> get(2, 2, Mat).
z3(Mat) -> get(2, 3, Mat).
%% Row w:
w0(Mat) -> get(3, 0, Mat).
w1(Mat) -> get(3, 1, Mat).
w2(Mat) -> get(3, 2, Mat).
w3(Mat) -> get(3, 3, Mat).
%% Apply F to every element of M, column by column via vec:map/2, and
%% rebuild (and re-validate) the result with new/1.
map(F, M) ->
    MapColumn = fun(_ColIdx, Column) -> vec:map(F, Column) end,
    new(array:map(MapColumn, ?cols(M))).
%% Combine two matrices of identical shape. Columns are paired with
%% helpers:zip_arrays/2 and each pair (U, V) is merged with
%% vec:combine(F, U, V) (presumably element-wise), then rebuilt with new/1.
%% Raises {combine_mismatched_dims, {lhs, ...}, {rhs, ...}} when the shapes
%% differ, instead of leaving that case to the helpers.
combine(F, M, N) when ?dim_rows(M) =:= ?dim_rows(N), ?dim_cols(M) =:= ?dim_cols(N) ->
    Pairs = helpers:zip_arrays(cols(M), cols(N)),
    NewCols = array:map(fun(_, {U, V}) -> vec:combine(F, U, V) end, Pairs),
    new(NewCols);
combine(_F, M, N) ->
    err_combine_mismatched_dims(?dim_rows(M), ?dim_cols(M), ?dim_rows(N), ?dim_cols(N)).

%% Raise the shape-mismatch error for combine/3. The reason has the same
%% layout as the one used for matrix multiplication.
-spec err_combine_mismatched_dims(pos_integer(), pos_integer(), pos_integer(), pos_integer()) -> none().
err_combine_mismatched_dims(MDimRows, MDimCols, NDimRows, NDimCols) ->
    error({combine_mismatched_dims,
           {lhs, {dim_rows, MDimRows}, {dim_cols, MDimCols}},
           {rhs, {dim_rows, NDimRows}, {dim_cols, NDimCols}}}).
%% Element-wise arithmetic built on map/2 and combine/3.
%% neg/1 negates every element; add/2 and sub/2 combine two matrices with + and -;
%% muls/2 scales every element by the scalar L (computed as L * X).
neg(M) -> map(fun erlang:'-'/1, M).
add(M, N) -> combine(fun erlang:'+'/2, M, N).
sub(M, N) -> combine(fun erlang:'-'/2, M, N).
muls(M, L) -> map(fun(Elem) -> L * Elem end, M).
%% Matrix-vector product M * V. V is wrapped as a single-column matrix, which
%% is multiplied with mulm/2; column 0 of the product is returned.
mulv(M, V) ->
    ColumnMat = new(array:from_list([V])),
    get_col(0, mulm(M, ColumnMat)).
%% Matrix product M * N. The result has dim_rows(M) rows and dim_cols(N)
%% columns, and element (i, j) is the sum over k of M[i, k] * N[k, j].
%% Requires dim_cols(M) =:= dim_rows(N); otherwise err_mul_mismatched_dims/4
%% raises with both operands' dimensions.
mulm(M, N) when ?dim_cols(M) =:= ?dim_rows(N) ->
    %% The inner dimension is the same for every product element, so its
    %% index list is built once rather than once per element.
    InnerIdxs = lists:seq(0, dim_cols(M) - 1),
    Dot =
        fun(RowIdx, ColIdx) ->
            lists:sum([get(RowIdx, K, M) * get(K, ColIdx, N) || K <- InnerIdxs])
        end,
    gen(Dot, dim_rows(M), dim_cols(N));
mulm(M, N) ->
    %% err_mul_mismatched_dims/4 raises by itself (it never returns), so it is
    %% called directly instead of being wrapped in error/1.
    err_mul_mismatched_dims(?dim_rows(M), ?dim_cols(M), ?dim_rows(N), ?dim_cols(N)).
%% True when M and N have the same shape and every pair of corresponding
%% elements is exactly equal (=:=, so 1 and 1.0 differ).
equal(M, N) ->
    SameShape = dim_rows(M) =:= dim_rows(N) andalso dim_cols(M) =:= dim_cols(N),
    SameShape andalso
        lists:all(
            fun({I, J}) -> get(I, J, M) =:= get(I, J, N) end,
            [{R, C} || R <- lists:seq(0, dim_rows(M) - 1),
                       C <- lists:seq(0, dim_cols(M) - 1)]).
%% Plain field readers for the #mat{} record.
dim_cols(Mat) -> Mat#mat.dim_cols.
dim_rows(Mat) -> Mat#mat.dim_rows.
cols(Mat) -> Mat#mat.cols.
%% Error raisers. Each one builds a tagged reason tuple and raises it with
%% erlang:error/1, so none of them returns.
err_zero_cols() ->
    erlang:error({zero_cols}).
err_jagged_mat(Col, Got, Want) ->
    Reason = {jagged, {col_idx, Col}, {rows_in_col, Got}, {expected_rows, Want}},
    erlang:error(Reason).
err_col_idx_out_of_bounds(Index, Bound) ->
    Reason = {col_index_out_of_bounds, {index, Index}, {dim, Bound}},
    erlang:error(Reason).
err_row_idx_out_of_bounds(Index, Bound) ->
    Reason = {row_index_out_of_bounds, {index, Index}, {dim, Bound}},
    erlang:error(Reason).
%% Raise the error for a matrix product whose left operand's column count
%% does not match the right operand's row count. The reason is a 3-tuple:
%% {mul_mismatched_dims, {lhs, {dim_rows, _}, {dim_cols, _}},
%%                       {rhs, {dim_rows, _}, {dim_cols, _}}}
%% so both operand descriptions have the same shape and can be matched alike.
-spec err_mul_mismatched_dims(pos_integer(), pos_integer(), pos_integer(), pos_integer()) -> none().
err_mul_mismatched_dims(MDimRows, MDimCols, NDimRows, NDimCols) ->
    error({mul_mismatched_dims,
           {lhs, {dim_rows, MDimRows}, {dim_cols, MDimCols}},
           {rhs, {dim_rows, NDimRows}, {dim_cols, NDimCols}}}).
% tests
-ifdef(TEST).
%% An array with no columns is rejected with {zero_cols}.
new_empty_fails_test() ->
    NoCols = array:new(),
    ?assertError({zero_cols}, new(NoCols)).
%% A 2-row column after a 3-row column is reported as jagged at column 1.
new_jagged_fails_test() ->
    ThreeRows = vec:new3(2, 3, 5),
    TwoRows = vec:new2(7, 11),
    ?assertError({jagged, {col_idx, 1}, {rows_in_col, 2}, {expected_rows, 3}},
                 new(array:from_list([ThreeRows, TwoRows]))).
%% Builds a 5x6 matrix and checks every accessor: the named column shortcuts
%% (v0..v3), get_col/2 for all columns, the named element shortcuts (x0..w3)
%% and get/3 for all elements.
new_vxyzw_get_test() ->
    %% One inner list per column: 6 columns of 5 rows each.
    Table = [[13, 17, 19, 23, 29],
             [31, 37, 41, 43, 47],
             [53, 59, 61, 67, 71],
             [73, 79, 83, 89, 97],
             [101, 103, 107, 109, 113],
             [127, 131, 137, 139, 149]],
    Cols = [vec:new(array:from_list(Vals)) || Vals <- Table],
    M = new(array:from_list(Cols)),
    %% The named column shortcuts exist only for columns 0..3.
    [C0, C1, C2, C3 | _] = Cols,
    ?assert(vec:equal(C0, v0(M))),
    ?assert(vec:equal(C1, v1(M))),
    ?assert(vec:equal(C2, v2(M))),
    ?assert(vec:equal(C3, v3(M))),
    %% get_col/2 must return every column.
    lists:foreach(
        fun({ColIdx, Col}) -> ?assert(vec:equal(Col, get_col(ColIdx, M))) end,
        lists:zip(lists:seq(0, length(Cols) - 1), Cols)),
    %% Named element shortcuts: letter = row (x, y, z, w), digit = column.
    NamedAccessors =
        [{13, fun x0/1}, {17, fun y0/1}, {19, fun z0/1}, {23, fun w0/1},
         {31, fun x1/1}, {37, fun y1/1}, {41, fun z1/1}, {43, fun w1/1},
         {53, fun x2/1}, {59, fun y2/1}, {61, fun z2/1}, {67, fun w2/1},
         {73, fun x3/1}, {79, fun y3/1}, {83, fun z3/1}, {89, fun w3/1}],
    lists:foreach(fun({Want, Accessor}) -> ?assertEqual(Want, Accessor(M)) end,
                  NamedAccessors),
    %% get/3 must return every element.
    lists:foreach(
        fun({ColIdx, Vals}) ->
            lists:foreach(
                fun({RowIdx, Want}) -> ?assertEqual(Want, get(RowIdx, ColIdx, M)) end,
                lists:zip(lists:seq(0, length(Vals) - 1), Vals))
        end,
        lists:zip(lists:seq(0, length(Table) - 1), Table)).
%% Negating a 2x3 matrix flips the sign of every element.
neg_test() ->
    FromPairs = fun(Pairs) -> new(array:from_list([vec:new2(A, B) || {A, B} <- Pairs])) end,
    Input = FromPairs([{151, -157}, {-163, 167}, {173, 179}]),
    Expected = FromPairs([{-151, 157}, {163, -167}, {-173, -179}]),
    ?assert(equal(Expected, neg(Input))).
%% Element-wise sum of two 2x2 matrices.
add_test() ->
    FromPairs = fun(Pairs) -> new(array:from_list([vec:new2(A, B) || {A, B} <- Pairs])) end,
    Lhs = FromPairs([{181, 191}, {193, 197}]),
    Rhs = FromPairs([{199, 211}, {223, 227}]),
    Expected = FromPairs([{380, 402}, {416, 424}]),
    ?assert(equal(Expected, add(Lhs, Rhs))).
%% Element-wise difference of two 2x2 matrices with mixed signs.
sub_test() ->
    FromPairs = fun(Pairs) -> new(array:from_list([vec:new2(A, B) || {A, B} <- Pairs])) end,
    Lhs = FromPairs([{229, -233}, {-239, 241}]),
    Rhs = FromPairs([{-251, -257}, {263, 269}]),
    Expected = FromPairs([{480, 24}, {-502, -28}]),
    ?assert(equal(Expected, sub(Lhs, Rhs))).
%% A 4x2 matrix times a 2x3 matrix gives a 4x3 product.
mulm_test() ->
    FromCols = fun(Vecs) -> new(array:from_list(Vecs)) end,
    Lhs = FromCols([vec:new4(271, 277, 281, 283), vec:new4(293, 307, 311, 313)]),
    Rhs = FromCols([vec:new2(317, 331), vec:new2(337, 347), vec:new2(349, 353)]),
    Expected = FromCols([vec:new4(182890, 189426, 192018, 193314),
                         vec:new4(192998, 199878, 202614, 203982),
                         vec:new4(198008, 205044, 207852, 209256)]),
    ?assert(equal(Expected, mulm(Lhs, Rhs))).
-endif.