lib/smartcell/ml_dtrees.ex

if !Code.ensure_loaded?(Kino.SmartCell) do
  # Kino.SmartCell cannot be loaded: define an empty placeholder module so the
  # name still exists without pulling in any Kino-dependent code.
  defmodule Evision.SmartCell.ML.DTrees do
    @moduledoc false
  end
else
  defmodule Evision.SmartCell.ML.DTrees do
    @moduledoc """
    Kino smart cell ("Evision: Decision Tree") that generates Elixir code which
    creates an `Evision.ML.DTrees` model, applies the configured parameters
    (`max_depth`, `max_categories`, `cv_folds`, `min_sample_count`), trains it on
    either an existing train-data variable or a nested `TrainData` cell, and then
    emits the code returned by `TrainData.get_calc_error/3`.
    """

    use Kino.JS, assets_path: "lib/assets"
    use Kino.JS.Live
    use Kino.SmartCell, name: "Evision: Decision Tree"

    alias Evision.SmartCell.Helper, as: ESCH
    alias Evision.SmartCell.ML.TrainData

    @smartcell_id "evision.ml.dtrees"

    @properties %{
      "data_from" => %{
        :type => :string,
        :opts => [must_in: ["traindata_var", "traindata"]],
        :default => "traindata_var"
      },
      "traindata_var" => %{
        :type => :string,
        :default => "dataset"
      },

      # DTrees
      "max_depth" => %{
        :type => :integer,
        :opts => [minimum: 1],
        :default => 4
      },
      "max_categories" => %{
        :type => :integer,
        :opts => [minimum: 2],
        :default => 2
      },
      "min_sample_count" => %{
        :type => :integer,
        :opts => [minimum: 1],
        :default => 10
      },
      "cv_folds" => %{
        :type => :integer,
        :opts => [minimum: 0],
        :default => 0
      },
      "to_variable" => %{
        :type => :string,
        :default => "dtree"
      }
    }
    # Nested sub-cells: a field named "traindata.<x>" is forwarded to this module.
    @inner_to_module %{
      "traindata" => TrainData
    }

    @doc "Returns the smart cell id."
    @spec id :: String.t()
    def id, do: @smartcell_id

    @doc "Returns the top-level field specifications (type, opts, default)."
    @spec properties :: map()
    def properties, do: @properties

    @doc "Returns a map of every top-level field to its default value."
    @spec defaults :: map()
    def defaults do
      Map.new(@properties, fn {field, field_specs} -> {field, field_specs[:default]} end)
    end

    @impl true
    def init(attrs, ctx) do
      # Load saved values from the notebook; fields that are missing (nil) fall
      # back to their declared defaults.
      fields =
        Map.new(@properties, fn {field, field_specs} ->
          {field, attrs[field] || field_specs[:default]}
        end)

      fields = update_traindata_fields(fields)

      {:ok, assign(ctx, id: @smartcell_id, fields: fields)}
    end

    @impl true
    def handle_connect(ctx) do
      {:ok, %{id: ctx.assigns.id, fields: ctx.assigns.fields}, ctx}
    end

    @impl true
    def handle_event("update_field", %{"field" => field, "value" => value}, ctx) do
      # "traindata.<name>" targets the nested TrainData cell; a bare name is a
      # top-level field of this cell.
      updated_fields =
        case String.split(field, ".", parts: 2) do
          [inner, forward] ->
            ESCH.to_inner_updates(inner, @inner_to_module[inner], forward, value, ctx)

          [top_level] ->
            to_updates(ctx.assigns.fields, top_level, value)
        end

      ctx = update(ctx, :fields, &Map.merge(&1, updated_fields))
      broadcast_event(ctx, "update", %{"fields" => updated_fields})
      {:noreply, ctx}
    end

    @doc """
    Converts the raw `value` received from the UI for the top-level field `name`
    into a map of field updates.

    The value is cast with `ESCH.to_update/3`, using the field's `:type` and
    `:opts` from `properties/0`. For `"data_from"`, the resulting map is also
    passed through the same "traindata" update that `init/2` applies. The first
    argument (the current fields) is not used.
    """
    @spec to_updates(map(), String.t(), term()) :: map()
    def to_updates(_fields, name = "data_from", value) do
      name
      |> cast_field(value)
      |> update_traindata_fields()
    end

    def to_updates(_fields, name, value), do: cast_field(name, value)

    @impl true
    def to_attrs(%{assigns: %{fields: fields}}) do
      fields
    end

    @impl true
    def to_source(attrs) do
      attrs
      |> get_quoted_code()
      |> Kino.SmartCell.quoted_to_string()
    end

    @doc """
    Builds the quoted code for this cell from `attrs`.

    The generated code creates and configures an `Evision.ML.DTrees` model bound
    to `attrs["to_variable"]`, followed by the training code.
    """
    @spec get_quoted_code(map()) :: Macro.t()
    def get_quoted_code(attrs) do
      quote do
        unquote(ESCH.quoted_var(attrs["to_variable"])) =
          Evision.ML.DTrees.create()
          |> Evision.ML.DTrees.setMaxDepth(unquote(attrs["max_depth"]))
          |> Evision.ML.DTrees.setMaxCategories(unquote(attrs["max_categories"]))
          |> Evision.ML.DTrees.setCVFolds(unquote(attrs["cv_folds"]))
          |> Evision.ML.DTrees.setMinSampleCount(unquote(attrs["min_sample_count"]))

        unquote(train_on_dataset(attrs))
      end
    end

    # Casts one top-level field's UI value according to its property spec.
    defp cast_field(name, value) do
      property = @properties[name]
      %{name => ESCH.to_update(value, property[:type], Access.get(property, :opts))}
    end

    # Runs the nested "traindata" entry through ESCH.update_key_with_module/4 with
    # the TrainData module. The predicate it receives is true when "data_from"
    # selects "traindata". NOTE(review): the exact merge semantics live in ESCH.
    defp update_traindata_fields(fields) do
      key = "traindata"

      ESCH.update_key_with_module(fields, key, @inner_to_module[key], fn fields, key ->
        fields["data_from"] == key
      end)
    end

    # Training code when the dataset is an existing variable named by "traindata_var".
    defp train_on_dataset(%{
           "data_from" => "traindata_var",
           "traindata_var" => traindata_var,
           "to_variable" => to_variable
         }) do
      quote do
        Evision.ML.DTrees.train(
          unquote(ESCH.quoted_var(to_variable)),
          unquote(ESCH.quoted_var(traindata_var))
        )

        # The model is a DTrees, so the error calculation must target
        # Evision.ML.DTrees (this previously pointed at Evision.ML.SVM).
        unquote(TrainData.get_calc_error(Evision.ML.DTrees, traindata_var, to_variable))
      end
    end

    # Training code when the dataset is built inline by the nested TrainData cell.
    defp train_on_dataset(%{
           "data_from" => "traindata",
           "traindata" => traindata_attrs,
           "to_variable" => to_variable
         }) do
      dataset_variable = traindata_attrs["to_variable"]

      quote do
        unquote(TrainData.get_quoted_code(traindata_attrs))

        Evision.ML.DTrees.train(
          unquote(ESCH.quoted_var(to_variable)),
          unquote(ESCH.quoted_var(dataset_variable))
        )

        # Same fix as above: compute the error against the DTrees model module.
        unquote(TrainData.get_calc_error(Evision.ML.DTrees, dataset_variable, to_variable))
      end
    end
  end
end