lib/slip39/share.ex

defmodule Slip39.Share do
  @moduledoc """
  Represents a single SLIP-39 mnemonic share and converts it to and from its
  mnemonic word representation.
  """

  alias Slip39.Rs1024
  alias Slip39.Constants
  alias Slip39.Share
  alias Slip39.Utils

  defstruct [
    :identifier,
    :iteration_exponent,
    :group_index,
    :group_threshold,
    :group_count,
    :member_index,
    :member_threshold,
    :data
  ]

  @typedoc """
  A decoded share. `group_threshold`, `group_count` and `member_threshold` hold
  their actual values (they are stored minus one inside the mnemonic).
  """
  @type t :: %__MODULE__{
          identifier: non_neg_integer(),
          iteration_exponent: non_neg_integer(),
          group_index: non_neg_integer(),
          group_threshold: pos_integer(),
          group_count: pos_integer(),
          member_index: non_neg_integer(),
          member_threshold: pos_integer(),
          data: binary()
        }

  @doc """
  Splits a mnemonic on whitespace and maps every word to its word-list index.

  Words that are not found in the word list map to `nil`.
  """
  @spec mnemonic_to_indices(String.t()) :: [non_neg_integer() | nil]
  def mnemonic_to_indices(mnemonic) do
    # Looked up once instead of once per word.
    index_by_word = Constants.get!(:word_list_map_by_string)

    mnemonic
    |> String.split()
    |> Enum.map(fn word -> index_by_word[word] end)
  end

  @doc """
  Converts a mnemonic into a bitstring made of one `RADIX_BITS`-wide integer per word.

  Raises `ArgumentError` if the mnemonic contains a word that is not in the word list.
  """
  @spec mnemonic_to_binary(String.t()) :: bitstring()
  def mnemonic_to_binary(mnemonic) do
    radix_bits = Constants.get!(:RADIX_BITS)

    mnemonic
    |> mnemonic_to_indices()
    |> Enum.reduce(<<>>, fn
      # An unknown word would otherwise surface as a cryptic binary-construction error.
      nil, _acc ->
        raise ArgumentError, "invalid mnemonic: unknown word"

      index, acc ->
        <<acc::bitstring, index::size(radix_bits)>>
    end)
  end

  @doc """
  Converts a bitstring into a list of mnemonic words.

  The bitstring is split with `Utils.chunk_bits/2`; every chunk must be exactly
  `RADIX_BITS` bits wide, otherwise a `MatchError` is raised.
  """
  @spec binary_to_mnemonic(bitstring()) :: [String.t()]
  def binary_to_mnemonic(data) do
    radix_bits = Constants.get!(:RADIX_BITS)
    word_by_index = Constants.get!(:word_list_map_by_idx)

    data
    |> Utils.chunk_bits(radix_bits)
    |> Enum.map(fn chunk ->
      <<index::size(radix_bits)>> = chunk
      word_by_index[index]
    end)
  end

  @doc """
  Encodes a `Slip39.Share` struct into its mnemonic.

  The mnemonic is returned as a list of words; join it with spaces to obtain the
  mnemonic string.
  """
  @spec encode!(t()) :: [String.t()]
  def encode!(%Share{} = share) do
    radix_bits = Constants.get!(:RADIX_BITS)

    # Thresholds and group count are stored minus one, so values 1..16 fit in 4 bits.
    header =
      <<share.identifier::15, share.iteration_exponent::5, share.group_index::4,
        share.group_threshold - 1::4, share.group_count - 1::4, share.member_index::4,
        share.member_threshold - 1::4>>

    # The padded share value is the share value, left-padded with zero bits so
    # that its length becomes the nearest multiple of the word size.
    data_size = bit_size(share.data)
    padding_size = Utils.nearest_upper_multiple(data_size, radix_bits) - data_size
    padded_value = <<0::size(padding_size), share.data::bitstring>>

    body = <<header::bitstring, padded_value::bitstring>>
    checksum = Rs1024.create_checksum(body)

    binary_to_mnemonic(<<body::bitstring, checksum::bitstring>>)
  end

  @doc """
  Decodes a mnemonic string into a `Slip39.Share` struct.

  Raises `ArgumentError` for unknown words and `RuntimeError` when the mnemonic is
  too short, has an invalid padding length or non-zero padding bits, has a bad
  checksum, or declares a group threshold greater than the group count.

  ## Example

      iex(1)> Slip39.Share.decode!("kernel leader acrobat romp camera unusual fawn engage revenue total blimp quiet muscle clinic slush mouse watch estimate custody glimpse")
      %Slip39.Share{
        identifier: 15856,
        iteration_exponent: 0,
        group_index: 0,
        group_threshold: 2,
        group_count: 4,
        member_index: 0,
        member_threshold: 1,
        data: <<116, 239, 149, 20, 122, 245, 231, 69, 107, 62, 90, 37, 51, 169, 87,
          227>>
      }
  """
  @spec decode!(String.t()) :: t()
  def decode!(mnemonic) do
    raw_bits = mnemonic_to_binary(mnemonic)

    # Validate the overall length before destructuring: a too-short input would
    # otherwise fail the binary matches below with an opaque MatchError.
    if bit_size(raw_bits) < Constants.get!(:MIN_MNEMONIC_LENGTH_BITS) do
      raise "Invalid mnemonic length (too short)."
    end

    # The share value is a whole, even number of bytes, so the padding length is
    # the padded value length modulo 16; more than 8 bits is never produced by encode!.
    padding_size = Integer.mod(bit_size(raw_bits) - Constants.get!(:METADATA_LENGTH_BITS), 16)

    if padding_size > 8 do
      raise "Invalid mnemonic length (invalid padding)."
    end

    if !Rs1024.is_checksum_valid?(raw_bits) do
      raise "bad checksum"
    end

    <<identifier::15, iteration_exponent::5, group_index::4, group_threshold::4,
      group_count::4, member_index::4, member_threshold::4, rest::bitstring>> = raw_bits

    # Only the last segment of a binary pattern may have an unknown size, so the
    # padded share value is split from the fixed-size trailing checksum by hand.
    padded_value_size = bit_size(rest) - Constants.get!(:CHECKSUM_SIZE)
    <<padded_value::bitstring-size(padded_value_size), _checksum::bitstring>> = rest
    <<padding::size(padding_size), data::bitstring>> = padded_value

    # encode! always pads with zero bits; anything else is not a valid share.
    if padding != 0 do
      raise "Invalid mnemonic padding."
    end

    # Both fields are stored minus one, so comparing the raw values is equivalent.
    if group_threshold > group_count do
      raise "Invalid mnemonic (group threshold cannot be greater than group count)."
    end

    %Share{
      identifier: identifier,
      iteration_exponent: iteration_exponent,
      group_index: group_index,
      group_threshold: group_threshold + 1,
      group_count: group_count + 1,
      member_index: member_index,
      member_threshold: member_threshold + 1,
      data: data
    }
  end

  @doc """
  Extracts the parameters shared by every member of the share's group.
  """
  @spec get_group_parameters(t()) :: %Share.GroupParameters{}
  def get_group_parameters(%Share{} = share) do
    %Share.GroupParameters{
      identifier: share.identifier,
      iteration_exponent: share.iteration_exponent,
      group_index: share.group_index,
      group_threshold: share.group_threshold,
      group_count: share.group_count,
      member_threshold: share.member_threshold
    }
  end

  @doc """
  Extracts the parameters common to the share (identifier, iteration exponent and
  group layout), excluding member-level fields.
  """
  @spec get_common_parameters(t()) :: %Share.CommonParameters{}
  def get_common_parameters(%Share{} = share) do
    %Share.CommonParameters{
      identifier: share.identifier,
      iteration_exponent: share.iteration_exponent,
      group_index: share.group_index,
      group_threshold: share.group_threshold,
      group_count: share.group_count
    }
  end

  @doc """
  Converts shares into `Share.RawShare` points, using the member index as `x`
  and the share value as `data`.
  """
  @spec get_raw_shares([t()]) :: [%Share.RawShare{}]
  def get_raw_shares(shares) do
    Enum.map(shares, fn share ->
      %Share.RawShare{x: share.member_index, data: share.data}
    end)
  end
end