if !Code.ensure_loaded?(Kino.SmartCell) do
defmodule Evision.SmartCell.ML.TrainData do
end
else
defmodule Evision.SmartCell.ML.TrainData do
use Kino.JS, assets_path: "lib/assets"
use Kino.JS.Live
use Kino.SmartCell, name: "Evision: Train Data"
alias Evision.SmartCell.Helper, as: ESCH
@smartcell_id "evision.ml.traindata"
@properties %{
"x" => %{
:type => :string,
},
"x_type" => %{
:type => :string,
:opts => [must_in: ["s32", "f32"]],
:default => "f32"
},
"y" => %{
:type => :string,
},
"y_type" => %{
:type => :string,
:opts => [must_in: ["s32", "f32"]],
:default => "s32"
},
"data_layout" => %{
:type => :string,
:opts => [must_in: ["row", "col"]],
:default => "row"
},
"split_ratio" => %{
:type => :number,
:opts => [minimum: 0.0, maximum: 1.0],
:default => 0.8
},
"shuffle_dataset" => %{
:type => :boolean,
:default => true
},
"to_variable" => %{
:type => :string,
:default => "dataset"
},
}
@spec id :: String.t()
def id, do: @smartcell_id
@spec properties :: map()
def properties, do: @properties
@spec defaults :: map()
def defaults do
Map.new(Enum.map(@properties, fn {field, field_specs} ->
{field, field_specs[:default]}
end))
end
@impl true
def init(attrs, ctx) do
fields =
Enum.map(@properties, fn {field, field_specs} ->
{field, attrs[field] || field_specs[:default]}
end)
{:ok, assign(ctx, fields: Map.new(fields), id: @smartcell_id)}
end
@impl true
def handle_connect(ctx) do
{:ok, %{fields: ctx.assigns.fields, id: @smartcell_id}, ctx}
end
@impl true
def handle_event("update_field", %{"field" => field, "value" => value}, ctx) do
updated_fields = to_updates(ctx.assigns.fields, field, value)
ctx = update(ctx, :fields, &Map.merge(&1, updated_fields))
broadcast_event(ctx, "update", %{"fields" => updated_fields})
{:noreply, ctx}
end
def to_updates(_fields, name, value) do
property = @properties[name]
%{name => ESCH.to_update(value, property[:type], Access.get(property, :opts))}
end
@impl true
def to_attrs(%{assigns: %{fields: fields}}) do
fields
end
@impl true
def to_source(attrs) do
get_quoted_code(attrs)
|> Kino.SmartCell.quoted_to_string()
end
def get_quoted_code(attrs) do
quote do
unquote(ESCH.quoted_var(attrs["to_variable"])) =
Evision.ML.TrainData.create(
Evision.Mat.from_nx(Nx.tensor(unquote(ESCH.quoted_var(attrs["x"])), type: unquote(String.to_atom(attrs["x_type"])), backend: Evision.Backend)),
unquote(data_layout(attrs["data_layout"])),
Evision.Mat.from_nx(Nx.tensor(unquote(ESCH.quoted_var(attrs["y"])), type: unquote(String.to_atom(attrs["y_type"])), backend: Evision.Backend))
)
|> Evision.ML.TrainData.setTrainTestSplitRatio(unquote(attrs["split_ratio"]), shuffle: unquote(attrs["shuffle_dataset"]))
IO.puts("#Samples: #{Evision.ML.TrainData.getNSamples(unquote(ESCH.quoted_var(attrs["to_variable"])))}")
IO.puts("#Training samples: #{Evision.ML.TrainData.getNTrainSamples(unquote(ESCH.quoted_var(attrs["to_variable"])))}")
IO.puts("#Test samples: #{Evision.ML.TrainData.getNTestSamples(unquote(ESCH.quoted_var(attrs["to_variable"])))}")
end
end
def data_layout("row") do
quote do
Evision.cv_ROW_SAMPLE()
end
end
def data_layout("col") do
quote do
Evision.cv_COL_SAMPLE()
end
end
def get_calc_error(module, traindata_var, to_variable) do
quote do
unquote(ESCH.quoted_var(to_variable))
|> unquote(module).calcError(unquote(ESCH.quoted_var(traindata_var)), false)
|> then(fn r ->
case r do
{:error, error_message} ->
raise error_message
{error, _} ->
IO.puts("Training Error: #{error}")
end
end)
unquote(ESCH.quoted_var(to_variable))
|> unquote(module).calcError(unquote(ESCH.quoted_var(traindata_var)), true)
|> then(fn r ->
case r do
{:error, error_message} ->
raise error_message
{error, _} ->
IO.puts("Test Error: #{error}")
end
end)
end
end
end
end