lib/beaver/mlir/attribute.ex

defmodule Beaver.MLIR.Attribute do
  @moduledoc """
  This module defines functions parsing and creating attributes in MLIR.
  """
  import Beaver.MLIR.CAPI
  @behaviour Access

  alias Beaver.MLIR
  import Beaver.Sigils

  use Kinda.ResourceKind, forward_module: Beaver.Native

  defp raise_with_diagnostics(attr_str, diagnostics) do
    case diagnostics do
      [] ->
        raise ArgumentError, "fail to parse attribute: #{attr_str}"

      diagnostics when is_list(diagnostics) ->
        raise ArgumentError, MLIR.Diagnostic.format(diagnostics, "fail to parse attribute")
    end
  end

  def get(attr_str, opts \\ []) when is_binary(attr_str) do
    attr = MLIR.StringRef.create(attr_str)

    Beaver.Deferred.from_opts(opts, fn ctx ->
      {attr, diagnostics} = mlirAttributeParseGetWithDiagnostics(ctx, ctx, attr)

      if MLIR.null?(attr) do
        raise_with_diagnostics(attr_str, diagnostics)
      else
        attr
      end
    end)
  end

  defp truthy_elements(elements) do
    Enum.map(elements, &if(&1, do: 1, else: 0))
  end

  defp dense_elements_splat_get(shaped_type, value, ctx) do
    element_type = MLIR.Type.element_type(shaped_type)

    cond do
      match?(%MLIR.Attribute{}, value) ->
        mlirDenseElementsAttrSplatGet(shaped_type, value)

      is_boolean(value) ->
        mlirDenseElementsAttrBoolSplatGet(shaped_type, value)

      MLIR.Type.integer?(element_type) ->
        width = MLIR.Type.width(element_type)

        apply(
          MLIR.CAPI,
          if MLIR.Type.unsigned?(element_type) do
            :"mlirDenseElementsAttrUInt#{width}SplatGet"
          else
            :"mlirDenseElementsAttrInt#{width}SplatGet"
          end,
          [shaped_type, value]
        )

      MLIR.equal?(element_type, MLIR.Type.f(32, ctx: ctx)) ->
        mlirDenseElementsAttrFloatSplatGet(shaped_type, value)

      MLIR.equal?(element_type, MLIR.Type.f(64, ctx: ctx)) ->
        mlirDenseElementsAttrDoubleSplatGet(shaped_type, value)

      true ->
        raise ArgumentError, "unsupported splat value type"
    end
  end

  defp string_elements?(elements) do
    (Enum.all?(elements, &is_binary/1) or
       Enum.all?(elements, &match?(%MLIR.StringRef{}, &1))) and
      length(elements) > 0
  end

  defp attr_elements?(elements) do
    Enum.all?(elements, &match?(%MLIR.Attribute{}, &1)) and length(elements) > 0
  end

  defp int_dense_elements_attr_getter(element_type) do
    width = MLIR.Type.width(element_type)

    {f, t} =
      if MLIR.Type.unsigned?(element_type) do
        {:"mlirDenseElementsAttrUInt#{width}Get", :"Elixir.Beaver.Native.U#{width}"}
      else
        {:"mlirDenseElementsAttrInt#{width}Get", :"Elixir.Beaver.Native.I#{width}"}
      end

    &apply(MLIR.CAPI, f, [&1, &2, Beaver.Native.array(&3, t)])
  end

  defp float_dense_elements_attr_getter(element_type) do
    {f, t} =
      case MLIR.Type.width(element_type) do
        32 ->
          {&mlirDenseElementsAttrFloatGet/3, Beaver.Native.F32}

        64 ->
          {&mlirDenseElementsAttrDoubleGet/3, Beaver.Native.F64}
      end

    &apply(f, [&1, &2, Beaver.Native.array(&3, t)])
  end

  defp dense_elements_get(shaped_type, num_elements, elements) do
    element_type = MLIR.Type.element_type(shaped_type)

    cond do
      MLIR.Type.opaque?(element_type) or string_elements?(elements) ->
        strings =
          elements
          |> Enum.map(&MLIR.StringRef.create/1)
          |> Beaver.Native.array(MLIR.StringRef, mut: true)

        mlirDenseElementsAttrStringGet(shaped_type, num_elements, strings)

      attr_elements?(elements) ->
        attrs = Beaver.Native.array(elements, MLIR.Attribute)
        mlirDenseElementsAttrGet(shaped_type, num_elements, attrs)

      MLIR.equal?(element_type, MLIR.Type.i1()) ->
        elements = truthy_elements(elements)

        mlirDenseElementsAttrBoolGet(
          shaped_type,
          num_elements,
          Beaver.Native.array(elements, Beaver.Native.CInt)
        )

      MLIR.Type.integer?(element_type) ->
        int_dense_elements_attr_getter(element_type).(shaped_type, num_elements, elements)

      MLIR.equal?(element_type, MLIR.Type.index()) ->
        mlirDenseElementsAttrUInt64Get(
          shaped_type,
          num_elements,
          Beaver.Native.array(elements, Beaver.Native.U64)
        )

      MLIR.Type.float?(element_type) ->
        float_dense_elements_attr_getter(element_type).(shaped_type, num_elements, elements)

      true ->
        raise ArgumentError,
              "unsupported element type #{MLIR.to_string(element_type)}"
    end
  end

  def dense_elements(elements_or_value, shaped_type \\ {:i, 8}, opts \\ [])

  def dense_elements(elements, {t, width}, opts)
      when is_binary(elements) and t in [:i, :f] and is_integer(width) do
    Beaver.Deferred.from_opts(
      opts,
      fn ctx ->
        et = apply(MLIR.Type, t, [width, [ctx: ctx]])
        str = MLIR.StringRef.create(elements)

        MLIR.Type.ranked_tensor!([byte_size(elements)], et)
        |> mlirDenseElementsAttrRawBufferGet(
          MLIR.StringRef.length(str),
          MLIR.StringRef.data(str) |> Beaver.Native.Array.as_opaque()
        )
      end
    )
  end

  def dense_elements(elements_or_value, shaped_type, opts) when is_list(elements_or_value) do
    Beaver.Deferred.from_opts(
      opts,
      fn ctx ->
        shaped_type =
          case Beaver.Deferred.create(shaped_type, ctx) do
            {:ok, %Beaver.MLIR.Type{} = t} -> t
            %Beaver.MLIR.Type{} = t -> t
          end

        elements = elements_or_value |> Enum.map(&Beaver.Deferred.create(&1, ctx))
        num_elements = length(elements)

        shaped_type_num_elements =
          beaverShapedTypeGetNumElements(shaped_type) |> Beaver.Native.to_term()

        if shaped_type_num_elements != num_elements and num_elements != 1 do
          raise ArgumentError,
                "number of elements #{num_elements} does not match shaped type #{shaped_type_num_elements}"
        end

        dense_elements_get(shaped_type, num_elements, elements)
      end
    )
  end

  def dense_elements(value, shaped_type, opts) do
    Beaver.Deferred.from_opts(
      opts,
      fn ctx ->
        value = Beaver.Deferred.create(value, ctx)
        shaped_type = Beaver.Deferred.create(shaped_type, ctx)
        dense_elements_splat_get(shaped_type, value, ctx)
      end
    )
  end

  def array(elements, opts \\ []) when is_list(elements) do
    Beaver.Deferred.from_opts(
      opts,
      fn ctx ->
        mlirArrayAttrGet(
          ctx,
          length(elements),
          elements
          |> Enum.map(&Beaver.Deferred.create(&1, ctx))
          |> Beaver.Native.array(MLIR.Attribute)
        )
      end
    )
  end

  defp dense_array_getter(Beaver.Native.Bool), do: &mlirDenseBoolArrayGet/3
  defp dense_array_getter(Beaver.Native.I8), do: &mlirDenseI8ArrayGet/3
  defp dense_array_getter(Beaver.Native.I16), do: &mlirDenseI16ArrayGet/3
  defp dense_array_getter(Beaver.Native.I32), do: &mlirDenseI32ArrayGet/3
  defp dense_array_getter(Beaver.Native.I64), do: &mlirDenseI64ArrayGet/3
  defp dense_array_getter(Beaver.Native.F32), do: &mlirDenseF32ArrayGet/3
  defp dense_array_getter(Beaver.Native.F64), do: &mlirDenseF64ArrayGet/3

  def dense_array(elements, type, opts \\ []) when is_list(elements) do
    getter = dense_array_getter(type)

    {elements, type} =
      case type do
        Beaver.Native.Bool ->
          {Enum.map(elements, &if(&1, do: 1, else: 0)), Beaver.Native.CInt}

        _ ->
          {elements, type}
      end

    Beaver.Deferred.from_opts(
      opts,
      fn ctx ->
        getter.(
          ctx,
          length(elements),
          Beaver.Native.array(elements, type)
        )
      end
    )
  end

  def dictionary(elements, opts \\ []) when is_list(elements) do
    elements =
      elements
      |> Enum.map(fn
        %MLIR.NamedAttribute{} = na -> na
        {name, attr} -> MLIR.NamedAttribute.get(name, attr)
      end)

    Beaver.Deferred.from_opts(
      opts,
      fn ctx ->
        mlirDictionaryAttrGet(
          ctx,
          length(elements),
          elements
          |> Enum.map(&Beaver.Deferred.create(&1, ctx))
          |> Beaver.Native.array(MLIR.NamedAttribute)
        )
      end
    )
  end

  def string(str, opts \\ []) do
    Beaver.Deferred.from_opts(
      opts,
      &mlirStringAttrGet(&1, MLIR.StringRef.create(str))
    )
  end

  def type(t)
      when is_function(t, 1) do
    &type(t.(&1))
  end

  def type(%MLIR.Type{} = t) do
    mlirTypeAttrGet(t)
  end

  defp composite(%MLIR.Type{} = t, validate, get)
       when is_function(validate, 1) and is_function(get, 1) do
    if validate.(t), do: get.(t), else: raise(ArgumentError, "incompatible type #{to_string(t)}")
  end

  defp composite(t, validate, get) when is_function(t, 1), do: &composite(t.(&1), validate, get)

  def integer(t, value) when is_integer(value) do
    composite(
      t,
      &(MLIR.Type.integer?(&1) or MLIR.Type.index?(&1)),
      &mlirIntegerAttrGet(&1, value)
    )
  end

  def float(t, value) when is_float(value) do
    composite(t, &MLIR.Type.float?/1, &mlirFloatAttrDoubleGet(MLIR.context(&1), &1, value))
  end

  def bool(value, opts \\ []) do
    Beaver.Deferred.from_opts(
      opts,
      fn ctx ->
        value =
          case value do
            true -> 1
            false -> 0
          end

        mlirBoolAttrGet(ctx, value)
      end
    )
  end

  def affine_map(map) when is_function(map, 1) do
    &affine_map(map.(&1))
  end

  def affine_map(%MLIR.AffineMap{} = map) do
    mlirAffineMapAttrGet(map)
  end

  def unit(opts \\ []) do
    Beaver.Deferred.from_opts(
      opts,
      &mlirUnitAttrGet(&1)
    )
  end

  def flat_symbol_ref(symbol, opts \\ []) do
    symbol = MLIR.StringRef.create(symbol)

    Beaver.Deferred.from_opts(
      opts,
      &mlirFlatSymbolRefAttrGet(&1, symbol)
    )
  end

  def symbol_ref(symbol, nested_symbols \\ [], opts \\ []) do
    symbol = MLIR.StringRef.create(symbol)

    Beaver.Deferred.from_opts(
      opts,
      fn ctx ->
        nested_arr =
          nested_symbols
          |> Enum.map(fn
            s when is_binary(s) -> flat_symbol_ref(s, ctx: ctx)
            %MLIR.Attribute{} = s -> s
          end)
          |> Beaver.Native.array(MLIR.Attribute)

        mlirSymbolRefAttrGet(ctx, symbol, length(nested_symbols), nested_arr)
      end
    )
  end

  def index(value, opts \\ []) when is_integer(value) do
    Beaver.Deferred.from_opts(opts, ~a{#{value} : index})
  end

  def null() do
    mlirAttributeGetNull()
  end

  for {f, "mlirAttributeIsA" <> type_name, 1} <-
        Beaver.MLIR.CAPI.__info__(:functions)
        |> Enum.map(fn {f, a} -> {f, Atom.to_string(f), a} end) do
    helper_name = type_name |> Macro.underscore()

    @doc """
    calls `Beaver.MLIR.CAPI.#{f}/1` to check if it is #{type_name} attribute
    """
    def unquote(:"#{helper_name}?")(%__MODULE__{} = t) do
      unquote(f)(t) |> Beaver.Native.to_term()
    end
  end

  defguardp is_index_or_name(key) when is_integer(key) or is_binary(key) or is_atom(key)

  @impl Access
  def fetch(attr, key) when is_index_or_name(key) do
    __MODULE__.Accessor.new(attr) |> __MODULE__.Accessor.fetch(attr, key)
  end

  @impl Access
  def get_and_update(attr, key, fun) when is_index_or_name(key) do
    __MODULE__.Accessor.new(attr) |> __MODULE__.Accessor.get_and_update(attr, key, fun)
  end

  @impl Access
  def pop(attr, key) when is_index_or_name(key) do
    if MLIR.Attribute.dense_elements?(attr) do
      raise ArgumentError, "cannot pop from dense elements attribute"
    end

    __MODULE__.Accessor.new(attr) |> __MODULE__.Accessor.pop(attr, key)
  end

  defdelegate unwrap_type(type_attr), to: MLIR.CAPI, as: :mlirTypeAttrGetValue

  defdelegate unwrap_string(str_attr), to: MLIR.CAPI, as: :mlirStringAttrGetValue

  def unwrap(%__MODULE__{} = attribute) do
    cond do
      Beaver.Native.to_term(mlirAttributeIsAType(attribute)) ->
        unwrap_type(attribute)

      Beaver.Native.to_term(mlirAttributeIsAString(attribute)) ->
        unwrap_string(attribute)
    end
  end
end