Skip to main content

src/mat.erl

%%%-------------------------------------------------------------------
%%% erlrithmetician - linear algebra library for Erlang
%%% Copyright (C) 2026 E. G. Bland
%%%
%%% This program is free software: you can redistribute it and/or modify
%%% it under the terms of the GNU General Public License as published by
%%% the Free Software Foundation, either version 3 of the License, or
%%% (at your option) any later version.
%%%
%%% This program is distributed in the hope that it will be useful,
%%% but WITHOUT ANY WARRANTY; without even the implied warranty of
%%% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
%%% GNU General Public License for more details.
%%%
%%% You should have received a copy of the GNU General Public License
%%% along with this program.  If not, see <https://www.gnu.org/licenses/>.
%%%-------------------------------------------------------------------

-module(mat).

-ifdef(TEST).
-include_lib("eunit/include/eunit.hrl").
-endif.

-export([new/1, gen/3, identity/1, get/3, get_col/2, dim_cols/1, dim_rows/1, v0/1, x0/1, y0/1, z0/1, w0/1, v1/1, x1/1, y1/1, z1/1, w1/1, v2/1, x2/1, y2/1, z2/1, w2/1, v3/1, x3/1, y3/1, z3/1, w3/1, map/2, combine/3, neg/1, add/2, sub/2, muls/2, mulv/2, mulm/2, equal/2]).
-export_type([mat/1]).

%% Internal representation of a matrix: the row and column counts plus an
%% array holding one entry per column.
-record(mat, { dim_rows :: pos_integer(), dim_cols :: pos_integer(), cols :: array:array() }).
%% Public opaque type; refines the record so that every column is a
%% vec:vec(T).
-opaque mat(T) :: #mat{ dim_rows :: pos_integer(), dim_cols :: pos_integer(), cols :: array:array(vec:vec(T)) }.

-spec new(array:array(vec:vec(T))) -> mat(T).
-spec gen(fun((non_neg_integer(), non_neg_integer()) -> T), pos_integer(), pos_integer()) -> mat(T).
-spec identity(pos_integer()) -> mat:mat(number()).
-spec get(non_neg_integer(), non_neg_integer(), mat(T)) -> T.
-spec get_col(non_neg_integer(), mat(T)) -> vec:vec(T).
-spec dim_rows(mat(_)) -> pos_integer().
-spec dim_cols(mat(_)) -> pos_integer().
-spec v0(mat(T)) -> vec:vec(T).
-spec v1(mat(T)) -> vec:vec(T).
-spec v2(mat(T)) -> vec:vec(T).
-spec v3(mat(T)) -> vec:vec(T).
-spec x0(mat(T)) -> T.
-spec y0(mat(T)) -> T.
-spec z0(mat(T)) -> T.
-spec w0(mat(T)) -> T.
-spec x1(mat(T)) -> T.
-spec y1(mat(T)) -> T.
-spec z1(mat(T)) -> T.
-spec w1(mat(T)) -> T.
-spec x2(mat(T)) -> T.
-spec y2(mat(T)) -> T.
-spec z2(mat(T)) -> T.
-spec w2(mat(T)) -> T.
-spec x3(mat(T)) -> T.
-spec y3(mat(T)) -> T.
-spec z3(mat(T)) -> T.
-spec w3(mat(T)) -> T.
-spec map(fun((T) -> U), mat(T)) -> mat(U).
-spec combine(fun((T, U) -> V), mat(T), mat(U)) -> mat(V).
-spec neg(mat(number())) -> mat(number()).
-spec add(mat(number()), mat(number())) -> mat(number()).
-spec sub(mat(number()), mat(number())) -> mat(number()).
-spec muls(mat(number()), number()) -> mat(number()).
-spec mulv(mat(number()), vec:vec(number())) -> vec:vec(number()).
-spec mulm(mat(number()), mat(number())) -> mat(number()).
-spec equal(mat(T), mat(T)) -> boolean().

-spec cols(mat(T)) -> array:array(vec:vec(T)).
-spec err_zero_cols() -> none().
-spec err_jagged_mat(non_neg_integer(), non_neg_integer(), pos_integer()) -> none().
-spec err_col_idx_out_of_bounds(non_neg_integer(), pos_integer()) -> none().
-spec err_row_idx_out_of_bounds(non_neg_integer(), pos_integer()) -> none().

%% Direct record-field access for a #mat{} value. Being plain record
%% expressions, these can be used inside guards (as get/3, get_col/2 and
%% mulm/2 do), which the function equivalents cannot.
-define(cols(M), ((M)#mat.cols)).
-define(dim_cols(M), ((M)#mat.dim_cols)).
-define(dim_rows(M), ((M)#mat.dim_rows)).

%% @doc Builds a matrix from an array of column vectors.
%%
%% Raises `{zero_cols}' when `Cols' has size 0, and a `jagged' error when the
%% dimension of any column differs from the dimension of column 0. The column
%% array stored in the result has been passed through array:fix/1.
new(Cols) ->
  NumCols = array:size(Cols),
  if
    NumCols =:= 0 ->
      err_zero_cols();
    true ->
      % Column 0 sets the row count that every other column must match.
      NumRows = vec:dim(array:get(0, Cols)),
      CheckCol =
        fun(ColIdx, Col, ok) ->
          case vec:dim(Col) of
            NumRows -> ok;
            Other -> err_jagged_mat(ColIdx, Other, NumRows)
          end
        end,
      array:foldl(CheckCol, ok, Cols),
      #mat { dim_cols = NumCols, dim_rows = NumRows, cols = array:fix(Cols) }
  end.

%% @doc Generates a matrix whose element at (RowIdx, ColIdx) is
%% `F(RowIdx, ColIdx)'. One column is built per index in 0..N-1, each through
%% vec:gen/2 with `M' as its size argument; the result is validated by new/1.
gen(F, M, N) ->
  MakeCol = fun(ColIdx) -> vec:gen(fun(RowIdx) -> F(RowIdx, ColIdx) end, M) end,
  ColList = lists:map(MakeCol, lists:seq(0, N - 1)),
  new(array:from_list(ColList)).

%% @doc Returns the N-by-N matrix with 1 where the row and column indices are
%% equal and 0 everywhere else.
identity(N) ->
  Kronecker = fun
    (Idx, Idx) -> 1;
    (_RowIdx, _ColIdx) -> 0
  end,
  gen(Kronecker, N, N).

%% @doc Returns the element at row `RowIdx', column `ColIdx' of `M'
%% (both indices are zero-based).
%%
%% Raises `{row_index_out_of_bounds, {index, RowIdx}, {dim, Rows}}' when
%% `RowIdx' is not below the row count; the reported `dim' is the number of
%% rows. The column index is checked by get_col/2, which raises its own
%% `col_index_out_of_bounds' error.
get(RowIdx, ColIdx, M) when RowIdx < ?dim_rows(M) ->
  vec:get(RowIdx, get_col(ColIdx, M));
get(RowIdx, _ColIdx, M) ->
  % Report the bound that was actually violated: the row count.
  err_row_idx_out_of_bounds(RowIdx, dim_rows(M)).

%% @doc Returns column `Idx' (zero-based) of `M' as stored in the column
%% array. Raises a `col_index_out_of_bounds' error, carrying the index and the
%% column count, when `Idx' is not below the number of columns.
get_col(Idx, M) when Idx < ?dim_cols(M) ->
  array:get(Idx, cols(M));
get_col(Idx, M) ->
  err_col_idx_out_of_bounds(Idx, dim_cols(M)).

%% Named shortcuts for the first four columns and rows. vN/1 returns column N;
%% xN/yN/zN/wN return rows 0/1/2/3 of column N. Each one delegates to
%% get_col/2 or get/3, so out-of-range access raises the same errors.

%% Column 0.
v0(Mat) -> get_col(0, Mat).
x0(Mat) -> get(0, 0, Mat).
y0(Mat) -> get(1, 0, Mat).
z0(Mat) -> get(2, 0, Mat).
w0(Mat) -> get(3, 0, Mat).

%% Column 1.
v1(Mat) -> get_col(1, Mat).
x1(Mat) -> get(0, 1, Mat).
y1(Mat) -> get(1, 1, Mat).
z1(Mat) -> get(2, 1, Mat).
w1(Mat) -> get(3, 1, Mat).

%% Column 2.
v2(Mat) -> get_col(2, Mat).
x2(Mat) -> get(0, 2, Mat).
y2(Mat) -> get(1, 2, Mat).
z2(Mat) -> get(2, 2, Mat).
w2(Mat) -> get(3, 2, Mat).

%% Column 3.
v3(Mat) -> get_col(3, Mat).
x3(Mat) -> get(0, 3, Mat).
y3(Mat) -> get(1, 3, Mat).
z3(Mat) -> get(2, 3, Mat).
w3(Mat) -> get(3, 3, Mat).

%% @doc Applies `F' to the elements of `M' by running vec:map/2 over each
%% column, then rebuilds the matrix through new/1.
map(F, M) ->
  MapCol = fun(_ColIdx, Col) -> vec:map(F, Col) end,
  new(array:map(MapCol, cols(M))).

%% @doc Combines `M' and `N' column by column: the column arrays are paired
%% with helpers:zip_arrays/2 and each pair is merged with vec:combine/3 using
%% `F'. The result is rebuilt through new/1.
combine(F, M, N) ->
  CombinePair = fun(_ColIdx, { ColM, ColN }) -> vec:combine(F, ColM, ColN) end,
  ColPairs = helpers:zip_arrays(cols(M), cols(N)),
  new(array:map(CombinePair, ColPairs)).

%% @doc Negates every element of `M'.
neg(M) ->
  map(fun erlang:'-'/1, M).

%% @doc Element-wise sum of `M' and `N'.
add(M, N) ->
  combine(fun erlang:'+'/2, M, N).

%% @doc Element-wise difference `M - N'.
sub(M, N) ->
  combine(fun erlang:'-'/2, M, N).

%% @doc Multiplies every element of `M' by the scalar `L'.
muls(M, L) ->
  Scale = fun(Elem) -> L * Elem end,
  map(Scale, M).

%% @doc Multiplies matrix `M' by vector `V'. `V' is wrapped as a
%% single-column matrix, multiplied with mulm/2, and column 0 of the product
%% is returned.
mulv(M, V) ->
  AsColumn = new(array:from_list([V])),
  v0(mulm(M, AsColumn)).

%% @doc Matrix product `M * N'.
%%
%% Requires the column count of `M' to equal the row count of `N'; the result
%% has the row count of `M' and the column count of `N'. Each element is the
%% sum over the shared dimension of `get(Row, I, M) * get(I, Col, N)'.
%% When the dimensions do not agree, the error raised by
%% err_mul_mismatched_dims/4 is propagated.
mulm(M, N) when ?dim_cols(M) =:= ?dim_rows(N) ->
  % The shared-dimension indices are the same for every element, so build
  % the list once instead of once per generated element.
  CommonIdxs = lists:seq(0, dim_cols(M) - 1),
  gen(
    fun(RowIdx, ColIdx) ->
      lists:sum([get(RowIdx, I, M) * get(I, ColIdx, N) || I <- CommonIdxs])
    end,
    dim_rows(M),
    dim_cols(N)
  );
mulm(M, N) ->
  % The helper raises by itself; no additional error/1 wrapper is needed.
  err_mul_mismatched_dims(?dim_rows(M), ?dim_cols(M), ?dim_rows(N), ?dim_cols(N)).

%% @doc Returns `true' when `M' and `N' have the same dimensions and every
%% pair of corresponding elements is exactly equal (`=:='). Elements are only
%% compared once both dimensions have been found equal.
equal(M, N) ->
  SameElem =
    fun(RowIdx, ColIdx) -> get(RowIdx, ColIdx, M) =:= get(RowIdx, ColIdx, N) end,
  SameRow =
    fun(RowIdx) ->
      lists:all(
        fun(ColIdx) -> SameElem(RowIdx, ColIdx) end,
        lists:seq(0, dim_cols(M) - 1)
      )
    end,
  dim_rows(M) =:= dim_rows(N)
    andalso dim_cols(M) =:= dim_cols(N)
    andalso lists:all(SameRow, lists:seq(0, dim_rows(M) - 1)).

%% @doc Number of columns recorded in the matrix.
dim_cols(Mat) ->
  Mat#mat.dim_cols.

%% @doc Number of rows recorded in the matrix.
dim_rows(Mat) ->
  Mat#mat.dim_rows.

%% Internal: the array of column vectors held by the matrix.
cols(Mat) ->
  Mat#mat.cols.

%% Raises the error used when a matrix is built from an empty column array.
err_zero_cols() ->
  erlang:error({ zero_cols }).

%% Raises the error used when column `ColIdx' holds `RowsInCol' rows while
%% `ExpectedRows' were expected.
err_jagged_mat(ColIdx, RowsInCol, ExpectedRows) ->
  Reason = { jagged, { col_idx, ColIdx }, { rows_in_col, RowsInCol }, { expected_rows, ExpectedRows } },
  erlang:error(Reason).

%% Raises the error used for a column index `Idx' outside the bound `Dim'.
err_col_idx_out_of_bounds(Idx, Dim) ->
  Reason = { col_index_out_of_bounds, { index, Idx }, { dim, Dim } },
  erlang:error(Reason).

%% Raises the error used for a row index `Idx' outside the bound `Dim'.
err_row_idx_out_of_bounds(Idx, Dim) ->
  Reason = { row_index_out_of_bounds, { index, Idx }, { dim, Dim } },
  erlang:error(Reason).

%% Raises the error used when two matrices cannot be multiplied because the
%% column count of the left operand differs from the row count of the right.
%%
%% The reason is a 3-tuple with one tagged sub-tuple per operand:
%%   { mul_mismatched_dims,
%%     { lhs, { dim_rows, R1 }, { dim_cols, C1 } },
%%     { rhs, { dim_rows, R2 }, { dim_cols, C2 } } }
%% Both operands are wrapped in the same way, so the `rhs' dimensions are
%% grouped under their tag exactly like the `lhs' ones.
-spec err_mul_mismatched_dims(pos_integer(), pos_integer(), pos_integer(), pos_integer()) -> none().
err_mul_mismatched_dims(MDimRows, MDimCols, NDimRows, NDimCols) ->
  error(
    { mul_mismatched_dims,
      { lhs, { dim_rows, MDimRows }, { dim_cols, MDimCols } },
      { rhs, { dim_rows, NDimRows }, { dim_cols, NDimCols } }
    }
  ).

% tests
-ifdef(TEST).

% new/1 must reject an array that holds no columns.
new_empty_fails_test() ->
  NoCols = array:from_list([]),
  ?assertError({ zero_cols }, new(NoCols)).

% new/1 must reject columns of differing dimension: column 0 has three rows
% here, column 1 only two.
new_jagged_fails_test() ->
  Jagged = array:from_list([vec:new3(2, 3, 5), vec:new2(7, 11)]),
  ?assertError(
    { jagged, { col_idx, 1 }, { rows_in_col, 2 }, { expected_rows, 3 } },
    new(Jagged)
  ).

% Builds a 5-row, 6-column matrix and checks every accessor against it.
new_vxyzw_get_test() ->
  % Expected contents, one inner list per column.
  ColValues = [
    [13, 17, 19, 23, 29],
    [31, 37, 41, 43, 47],
    [53, 59, 61, 67, 71],
    [73, 79, 83, 89, 97],
    [101, 103, 107, 109, 113],
    [127, 131, 137, 139, 149]
  ],
  ColVecs = [vec:new(array:from_list(Values)) || Values <- ColValues],
  M = new(array:from_list(ColVecs)),

  % Generic accessors: every column through get_col/2, every element
  % through get/3.
  lists:foreach(
    fun({ ColIdx, Values, ColVec }) ->
      ?assert(vec:equal(ColVec, get_col(ColIdx, M))),
      lists:foreach(
        fun({ RowIdx, Value }) -> ?assertEqual(Value, get(RowIdx, ColIdx, M)) end,
        lists:zip(lists:seq(0, 4), Values)
      )
    end,
    lists:zip3(lists:seq(0, 5), ColValues, ColVecs)
  ),

  % Named shortcuts exist only for the first four columns and rows.
  [C0, C1, C2, C3 | _] = ColVecs,
  ?assert(vec:equal(C0, v0(M))),
  ?assert(vec:equal(C1, v1(M))),
  ?assert(vec:equal(C2, v2(M))),
  ?assert(vec:equal(C3, v3(M))),
  ?assertEqual([13, 17, 19, 23], [x0(M), y0(M), z0(M), w0(M)]),
  ?assertEqual([31, 37, 41, 43], [x1(M), y1(M), z1(M), w1(M)]),
  ?assertEqual([53, 59, 61, 67], [x2(M), y2(M), z2(M), w2(M)]),
  ?assertEqual([73, 79, 83, 89], [x3(M), y3(M), z3(M), w3(M)]).

% neg/1 flips the sign of every element.
neg_test() ->
  Build = fun(Pairs) -> new(array:from_list([vec:new2(X, Y) || { X, Y } <- Pairs])) end,
  Input = Build([{ 151, -157 }, { -163, 167 }, { 173, 179 }]),
  Expected = Build([{ -151, 157 }, { 163, -167 }, { -173, -179 }]),
  ?assert(equal(Expected, neg(Input))).

% add/2 sums two 2x2 matrices element by element.
add_test() ->
  Build = fun(Pairs) -> new(array:from_list([vec:new2(X, Y) || { X, Y } <- Pairs])) end,
  Lhs = Build([{ 181, 191 }, { 193, 197 }]),
  Rhs = Build([{ 199, 211 }, { 223, 227 }]),
  Expected = Build([{ 380, 402 }, { 416, 424 }]),
  ?assert(equal(Expected, add(Lhs, Rhs))).

% sub/2 subtracts the second 2x2 matrix from the first, element by element.
sub_test() ->
  Build = fun(Pairs) -> new(array:from_list([vec:new2(X, Y) || { X, Y } <- Pairs])) end,
  Lhs = Build([{ 229, -233 }, { -239, 241 }]),
  Rhs = Build([{ -251, -257 }, { 263, 269 }]),
  Expected = Build([{ 480, 24 }, { -502, -28 }]),
  ?assert(equal(Expected, sub(Lhs, Rhs))).

% mulm/2 on a 4-row, 2-column matrix times a 2-row, 3-column matrix gives a
% 4-row, 3-column product.
mulm_test() ->
  FromCols = fun(ColVecs) -> new(array:from_list(ColVecs)) end,
  Lhs = FromCols([
    vec:new4(271, 277, 281, 283),
    vec:new4(293, 307, 311, 313)
  ]),
  Rhs = FromCols([
    vec:new2(317, 331),
    vec:new2(337, 347),
    vec:new2(349, 353)
  ]),
  Expected = FromCols([
    vec:new4(182890, 189426, 192018, 193314),
    vec:new4(192998, 199878, 202614, 203982),
    vec:new4(198008, 205044, 207852, 209256)
  ]),
  ?assert(equal(Expected, mulm(Lhs, Rhs))).

-endif.