# lib/ash_weight.ex

defmodule AshWeight do
  @moduledoc """
  A metric weight type for [Ash](https://hexdocs.pm/ash).

  Stored canonically as integer milligrams (`BIGINT` in Postgres), with
  ergonomic conversions to grams and kilograms via `Decimal` to avoid float
  drift.

  ## Usage

      attribute :weight, AshWeight, constraints: [min: 0]

      AshWeight.new(1.5, :g)                   # => %AshWeight{mg: 1500}
      AshWeight.from_kg("2.25")                # => %AshWeight{mg: 2_250_000}
      AshWeight.to_kg(AshWeight.new(1.5, :g))  # => #Decimal<0.0015>
      to_string(AshWeight.from_kg(1.5))        # => "1.5 kg"

  Accepted input shapes for `Ash.Type.cast_input/3`:

    * `%AshWeight{}`
    * integer — treated as milligrams
    * `{value, unit}` — explicit unit, either `:mg | :g | :kg` or
      `"mg" | "g" | "kg"`
    * `%{value: _, unit: _}` / `%{"value" => _, "unit" => _}` — form data
    * `nil`

  Non-integer values are converted through `Decimal` and rounded to whole
  milligrams with `Decimal.round/2`'s default mode. Converted values whose
  magnitude is 10^19 mg or more (already beyond the signed 64-bit `BIGINT`
  range) are rejected. The constructors raise `ArgumentError` and `cast_input`
  returns `{:error, _}`. As a result, inputs such as `"1e1000000000"` cannot
  force the allocation of an enormous integer.
  """

  use Ash.Type

  @enforce_keys [:mg]
  defstruct [:mg]

  @type unit :: :mg | :g | :kg
  @type t :: %__MODULE__{mg: integer()}

  @units [:mg, :g, :kg]
  @unit_strings ["mg", "g", "kg"]

  # Number of milligrams in one gram / one kilogram.
  @mg_per_g 1_000
  @mg_per_kg 1_000_000

  # Largest accepted exponent of the leading significant digit when rounding a
  # Decimal to integer milligrams: anything >= 10^19 cannot fit a signed 64-bit
  # BIGINT (max ~9.22e18), and materialising such integers is unbounded work.
  @max_adjusted_exponent 18

  # ---------------------------------------------------------------------------
  # Constructors
  # ---------------------------------------------------------------------------

  @doc """
  Builds a weight from `value` in the given metric `unit` (default `:mg`).

  `value` may be an integer, float, `Decimal`, or numeric string. Internally
  normalized through `Decimal` arithmetic and rounded to integer milligrams,
  so `new(0.1, :g)` is exactly `100` mg with no float drift.

  Raises `ArgumentError` when the value is not finite or its magnitude is
  10^19 mg or more, and `Decimal.Error` when a string cannot be parsed.
  """
  @spec new(value :: integer() | float() | Decimal.t() | String.t(), unit) :: t()
  def new(value, unit \\ :mg)
  def new(value, :mg), do: from_mg(value)
  def new(value, :g), do: from_g(value)
  def new(value, :kg), do: from_kg(value)

  @doc "Builds a weight from a milligram value; non-integers are rounded."
  @spec from_mg(integer() | float() | Decimal.t() | String.t()) :: t()
  def from_mg(value) when is_integer(value), do: %__MODULE__{mg: value}
  def from_mg(value), do: %__MODULE__{mg: to_int_mg(value, 1)}

  @doc "Builds a weight from a gram value (1 g = 1_000 mg)."
  @spec from_g(integer() | float() | Decimal.t() | String.t()) :: t()
  def from_g(value), do: %__MODULE__{mg: to_int_mg(value, @mg_per_g)}

  @doc "Builds a weight from a kilogram value (1 kg = 1_000_000 mg)."
  @spec from_kg(integer() | float() | Decimal.t() | String.t()) :: t()
  def from_kg(value), do: %__MODULE__{mg: to_int_mg(value, @mg_per_kg)}

  # ---------------------------------------------------------------------------
  # Conversions
  # ---------------------------------------------------------------------------

  @doc "Returns the weight in integer milligrams."
  @spec to_mg(t()) :: integer()
  def to_mg(%__MODULE__{mg: mg}), do: mg

  @doc "Returns the weight in grams as a `Decimal`."
  @spec to_g(t()) :: Decimal.t()
  def to_g(%__MODULE__{mg: mg}), do: Decimal.div(Decimal.new(mg), @mg_per_g)

  @doc "Returns the weight in kilograms as a `Decimal`."
  @spec to_kg(t()) :: Decimal.t()
  def to_kg(%__MODULE__{mg: mg}), do: Decimal.div(Decimal.new(mg), @mg_per_kg)

  # ---------------------------------------------------------------------------
  # Arithmetic / comparison
  # ---------------------------------------------------------------------------

  @doc "Adds two weights."
  @spec add(t(), t()) :: t()
  def add(%__MODULE__{mg: a}, %__MODULE__{mg: b}), do: %__MODULE__{mg: a + b}

  @doc "Subtracts `b` from `a` (the result may be negative)."
  @spec subtract(t(), t()) :: t()
  def subtract(%__MODULE__{mg: a}, %__MODULE__{mg: b}), do: %__MODULE__{mg: a - b}

  @doc """
  Multiplies a weight by a scalar.

  Integer scalars use exact integer arithmetic. Other scalars go through
  `Decimal` and are rounded to whole milligrams under the same range rules
  as `new/2`.
  """
  @spec multiply(t(), integer() | float() | Decimal.t() | String.t()) :: t()
  def multiply(%__MODULE__{mg: mg}, scalar) when is_integer(scalar),
    do: %__MODULE__{mg: mg * scalar}

  def multiply(%__MODULE__{mg: mg}, scalar) do
    %__MODULE__{
      mg:
        mg
        |> Decimal.new()
        |> Decimal.mult(to_decimal(scalar))
        |> round_to_integer()
    }
  end

  @doc """
  Compares two weights. Returns `:lt`, `:eq`, or `:gt`.

  Because it follows the `compare/2` convention, the module can be passed
  as a sorter, e.g. `Enum.sort(weights, AshWeight)`.
  """
  @spec compare(t(), t()) :: :lt | :eq | :gt
  def compare(%__MODULE__{mg: a}, %__MODULE__{mg: b}) when a < b, do: :lt
  def compare(%__MODULE__{mg: a}, %__MODULE__{mg: b}) when a > b, do: :gt
  def compare(%__MODULE__{}, %__MODULE__{}), do: :eq

  # ---------------------------------------------------------------------------
  # Ash.Type behaviour
  # ---------------------------------------------------------------------------

  @impl Ash.Type
  def storage_type(_constraints), do: :integer

  @impl Ash.Type
  def constraints do
    [
      min: [
        type: :integer,
        doc: "Minimum weight in milligrams (inclusive)."
      ],
      max: [
        type: :integer,
        doc: "Maximum weight in milligrams (inclusive)."
      ]
    ]
  end

  @impl Ash.Type
  def cast_input(nil, _), do: {:ok, nil}
  def cast_input(%__MODULE__{} = weight, _), do: {:ok, weight}
  def cast_input(mg, _) when is_integer(mg), do: {:ok, %__MODULE__{mg: mg}}

  def cast_input({value, unit_raw}, _) when unit_raw in @units or unit_raw in @unit_strings do
    # The guard only admits known units, so normalization cannot fail here.
    {:ok, unit} = normalize_unit(unit_raw)
    cast_value(value, unit)
  end

  def cast_input(map, _) when is_map(map) and not is_struct(map) do
    with {:ok, value} <- fetch_field(map, :value),
         {:ok, unit_raw} <- fetch_field(map, :unit),
         {:ok, unit} <- normalize_unit(unit_raw) do
      cast_value(value, unit)
    end
  end

  def cast_input(other, _), do: {:error, "invalid weight: #{inspect(other)}"}

  @impl Ash.Type
  def cast_stored(nil, _), do: {:ok, nil}
  def cast_stored(mg, _) when is_integer(mg), do: {:ok, %__MODULE__{mg: mg}}
  def cast_stored(other, _), do: {:error, "invalid stored weight: #{inspect(other)}"}

  @impl Ash.Type
  def dump_to_native(nil, _), do: {:ok, nil}
  def dump_to_native(%__MODULE__{mg: mg}, _), do: {:ok, mg}
  def dump_to_native(_, _), do: :error

  @impl Ash.Type
  def matches_type?(%__MODULE__{}, _), do: true
  def matches_type?(_, _), do: false

  @impl Ash.Type
  def apply_constraints(nil, _), do: {:ok, nil}

  def apply_constraints(%__MODULE__{mg: mg} = weight, constraints) do
    with :ok <- check_min(mg, Keyword.get(constraints, :min)),
         :ok <- check_max(mg, Keyword.get(constraints, :max)) do
      {:ok, weight}
    end
  end

  @impl Ash.Type
  # Two absent weights are equal; without this clause `nil` vs `nil` would fall
  # through to the catch-all and be reported as different.
  def equal?(nil, nil), do: true
  def equal?(%__MODULE__{mg: a}, %__MODULE__{mg: b}), do: a == b
  def equal?(_, _), do: false

  # ---------------------------------------------------------------------------
  # Internals
  # ---------------------------------------------------------------------------

  # Builds a weight from (possibly untrusted) cast input. This is the input
  # boundary, so any construction failure (unparseable string, unsupported
  # value type, non-finite or out-of-range number) becomes an error tuple
  # instead of raising out of cast_input/2.
  defp cast_value(value, unit) do
    {:ok, new(value, unit)}
  rescue
    _ -> {:error, "could not cast weight #{inspect(value)} #{unit}"}
  end

  # Converts `value` expressed in a unit worth `factor` milligrams into
  # integer milligrams.
  defp to_int_mg(value, factor) do
    value
    |> to_decimal()
    |> Decimal.mult(factor)
    |> round_to_integer()
  end

  # Rounds a Decimal to an integer, bounding the work first. Materialising an
  # integer costs time/memory proportional to its number of digits. So the
  # magnitude is read from the struct's coefficient and exponent before
  # Decimal.round/2 and Decimal.to_integer/1 run.
  defp round_to_integer(%Decimal{coef: 0}), do: 0

  defp round_to_integer(%Decimal{coef: coef}) when not is_integer(coef) do
    # :inf / :NaN coefficients have no integer representation.
    raise ArgumentError, "weight must be a finite number"
  end

  defp round_to_integer(%Decimal{coef: coef, exp: exp} = decimal) do
    # Exponent of the leading significant digit, i.e. floor(log10(|decimal|)).
    adjusted = exp + byte_size(Integer.to_string(coef)) - 1

    cond do
      adjusted > @max_adjusted_exponent ->
        raise ArgumentError, "weight out of range: magnitude must be below 10^19 mg"

      adjusted < -1 ->
        # |decimal| < 0.1, which rounds to zero under any round-to-nearest mode.
        0

      true ->
        decimal
        |> Decimal.round(0)
        |> Decimal.to_integer()
    end
  end

  defp to_decimal(v) when is_integer(v), do: Decimal.new(v)
  defp to_decimal(v) when is_float(v), do: Decimal.from_float(v)
  defp to_decimal(v) when is_binary(v), do: Decimal.new(v)
  defp to_decimal(%Decimal{} = v), do: v

  # Looks `key` up as an atom first, then as a string (form params).
  defp fetch_field(map, key) do
    with :error <- Map.fetch(map, key),
         :error <- Map.fetch(map, Atom.to_string(key)) do
      {:error, "missing field: #{key}"}
    end
  end

  defp normalize_unit(unit) when unit in @units, do: {:ok, unit}
  defp normalize_unit("mg"), do: {:ok, :mg}
  defp normalize_unit("g"), do: {:ok, :g}
  defp normalize_unit("kg"), do: {:ok, :kg}
  defp normalize_unit(other), do: {:error, "invalid unit: #{inspect(other)}"}

  defp check_min(_mg, nil), do: :ok
  defp check_min(mg, min) when mg >= min, do: :ok
  defp check_min(mg, min), do: {:error, "must be at least #{min} mg, got #{mg}"}

  defp check_max(_mg, nil), do: :ok
  defp check_max(mg, max) when mg <= max, do: :ok
  defp check_max(mg, max), do: {:error, "must be at most #{max} mg, got #{mg}"}
end

defimpl String.Chars, for: AshWeight do
  # Renders a weight in the largest unit its magnitude reaches: kilograms from
  # 1_000_000 mg, grams from 1_000 mg, and plain milligrams otherwise (which
  # also covers zero). The Decimal division is delegated to AshWeight.to_kg/1
  # and AshWeight.to_g/1, so the unit factors are defined in one place.
  def to_string(%AshWeight{mg: mg} = weight) when abs(mg) >= 1_000_000,
    do: "#{AshWeight.to_kg(weight)} kg"

  def to_string(%AshWeight{mg: mg} = weight) when abs(mg) >= 1_000,
    do: "#{AshWeight.to_g(weight)} g"

  def to_string(%AshWeight{mg: mg}), do: "#{mg} mg"
end

defimpl Inspect, for: AshWeight do
  # Inspects as `#AshWeight<...>`, reusing the human-readable String.Chars
  # rendering; inspect options are not consulted.
  def inspect(weight, _opts) do
    "#AshWeight<" <> String.Chars.to_string(weight) <> ">"
  end
end