defmodule TflInterp do
@timeout 300000
@moduledoc """
TensorFlow Lite interpreter for Elixir.
Deep Learning inference framework for embedded devices.
## Installation
This module is designed for Poncho-style projects, so it cannot be installed
simply by adding it to your project's dependency list; you need to clone and
build it first. Follow the steps below to install it.
Download `tfl_interp` to a directory of your choice. I recommend that you put
it in the same hierarchy as your Deep Learning project directory.
```shell
$ cd parent-of-your-project
$ git clone https://github.com/shoz-f/tfl_interp.git
```
Then you need to download the Google TensorFlow source set and build the
`tfl_interp` executable (an external program called by Elixir through a Port) into ./priv.
Don't worry, the `mix_cmake` utility will take care of this.
```shell
$ cd tfl_interp
$ mix deps.get
$ mix cmake --config
;-) It takes a few minutes to download and build TensorFlow.
```
Now you are ready. The figure below shows the directory structure of tfl_interp.
```
+- your-project
|
+- tfl_interp
+- _build
| +- .cmake_build --- Tensorflow is downloaded here
+- deps
+- lib
+- priv
| +- tfl_interp --- the port executable called by Elixir
+- src/
+- test/
+- CMakeLists.txt --- CMake configuration script
+- mix.exs --- includes parameter for mix-cmake task
+- msys2.patch --- Patch script for MSYS2/MinGW64
```
## Basic Usage
To use TflInterp in your project, add the path to the `tfl_interp` directory above
to your `mix.exs` dependencies:
```elixir:mix.exs
def deps do
  [
    {:tfl_interp, path: "../tfl_interp"}
  ]
end
```
Then put your trained TensorFlow Lite model in ./priv.
```shell
$ cp your-trained-model.tflite ./priv
```
The remaining task is to create a module that interfaces with your Deep
Learning model. Besides the inference itself, the module will probably have
pre-processing and post-processing, as in the code example below. TflInterp
provides only the inference step.
Put `use TflInterp` at the beginning of your module and specify the model path
in its optional arguments. The inference section involves putting data into the
model - `TflInterp.set_input_tensor/3`, executing it - `TflInterp.invoke/1`,
and extracting the results - `TflInterp.get_output_tensor/2`.
```elixir:your_model.ex
defmodule YourApp.YourModel do
  use TflInterp, model: "priv/your-trained-model.tflite"

  def predict(data) do
    # preprocess
    # convert the data to be inferred into the input format of the model.
    input_bin = convert_to_float32_binary(data)

    # inference
    # typical I/O data for TensorFlow Lite models is a serialized 32-bit float tensor.
    output_bin =
      __MODULE__
      |> TflInterp.set_input_tensor(0, input_bin)
      |> TflInterp.invoke()
      |> TflInterp.get_output_tensor(0)

    # postprocess
    # add your post-processing here.
    # you may need to reshape output_bin into a tensor first.
    tensor =
      output_bin
      |> Nx.from_binary({:f, 32})
      |> Nx.reshape({size_x, size_y, :auto})

    # your post-processing ...
  end
end
```
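
Since `use TflInterp` turns the module into a GenServer, it must be started
before inference. A minimal sketch (module and variable names are illustrative):
put the module in your application's supervision tree, or start it directly:

```elixir
# either add {YourApp.YourModel, []} to your supervision tree, or:
{:ok, _pid} = YourApp.YourModel.start_link([])

# then run inference
result = YourApp.YourModel.predict(input_data)
```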
"""
defmacro __using__(opts) do
  quote generated: true, location: :keep do
    use GenServer

    def start_link(opts) do
      GenServer.start_link(__MODULE__, opts, name: __MODULE__)
    end

    def init(opts) do
      executable = Application.app_dir(:tfl_interp, "priv/tfl_interp")

      # runtime options are merged over the options given to `use TflInterp`
      opts = Keyword.merge(unquote(opts), opts)
      tfl_model = Keyword.get(opts, :model)
      tfl_label = Keyword.get(opts, :label, "none")
      tfl_opts = Keyword.get(opts, :opts, "")

      # spawn the tfl_interp port program; messages are exchanged as
      # 4-byte length-prefixed binary packets
      port = Port.open({:spawn_executable, executable}, [
        {:args, String.split(tfl_opts) ++ [tfl_model, tfl_label]},
        {:packet, 4},
        :binary
      ])

      {:ok, %{port: port}}
    end

    def handle_call(cmd_line, _from, state) do
      # forward the command to the port and wait for its reply
      Port.command(state.port, cmd_line)

      response = receive do
        {_, {:data, <<result::binary>>}} -> {:ok, result}
      after
        Keyword.get(unquote(opts), :timeout, 300000) -> {:timeout}
      end

      {:reply, response, state}
    end

    def terminate(_reason, state) do
      Port.close(state.port)
    end
  end
end

@doc """
Get the properties of the tflite model.

## Parameters

* mod - interface module name (GenServer) defined with `use TflInterp`
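
A sketch of typical use, assuming an interface module `YourApp.YourModel`
defined with `use TflInterp`; the JSON reply from the port is decoded with
`Poison.decode/1`:

```elixir
{:ok, props} = TflInterp.info(YourApp.YourModel)
```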
"""
def info(mod) do
  cmd = 0

  case GenServer.call(mod, <<cmd::8>>, @timeout) do
    {:ok, result} -> Poison.decode(result)
    any -> any
  end
end

@doc """
Stop the tflite interpreter.

## Parameters

* mod - interface module name (GenServer) defined with `use TflInterp`
"""
def stop(mod) do
  GenServer.stop(mod)
end

@doc """
Put a flat binary into the input tensor of the interpreter.

## Parameters

* mod - interface module name (GenServer) defined with `use TflInterp`
* index - index of the input tensor in the model
* bin - input data - flat binary, cf. serialized tensor
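
A sketch (names are illustrative); `set_input_tensor/3` returns `mod`, so it
can be used in a pipeline as shown in the module documentation:

```elixir
input_bin = Nx.to_binary(input_tensor)  # flat float32 binary
TflInterp.set_input_tensor(YourApp.YourModel, 0, input_bin)
```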
"""
def set_input_tensor(mod, index, bin) do
  cmd = 1

  case GenServer.call(mod, <<cmd::8, index::8, bin::binary>>, @timeout) do
    {:ok, result} -> Poison.decode(result)
    any -> any
  end

  mod
end

@doc """
Invoke prediction.

## Parameters

* mod - interface module name (GenServer) defined with `use TflInterp`
"""
def invoke(mod) do
  cmd = 2

  case GenServer.call(mod, <<cmd::8>>, @timeout) do
    {:ok, result} -> Poison.decode(result)
    any -> any
  end

  mod
end

@doc """
Get a flat binary from the output tensor of the interpreter.

## Parameters

* mod - interface module name (GenServer) defined with `use TflInterp`
* index - index of the output tensor in the model
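
A sketch (names are illustrative); the returned binary can be turned into an
`Nx` tensor for post-processing:

```elixir
output_bin = TflInterp.get_output_tensor(YourApp.YourModel, 0)
tensor = Nx.from_binary(output_bin, {:f, 32})
```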
"""
def get_output_tensor(mod, index) do
  cmd = 3

  case GenServer.call(mod, <<cmd::8, index::8>>, @timeout) do
    {:ok, result} -> result
    any -> any
  end
end

@doc """
Execute post-processing: multi-class non-maximum suppression (NMS).

## Parameters

* mod - interface module name (GenServer) defined with `use TflInterp`
* num_boxes - number of candidate boxes
* num_class - number of category classes
* boxes - binary, serialized box tensor[`num_boxes`][4]; dtype: float32
* scores - binary, serialized score tensor[`num_boxes`][`num_class`]; dtype: float32
* iou_threshold - IOU threshold
* score_threshold - score cutoff threshold
* sigma - parameter for soft NMS
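
A call sketch (values and names are illustrative); `boxes` and `scores` are
flat float32 binaries, e.g. produced with `Nx.to_binary/1`, and the decoded
JSON result is returned as `{:ok, ...}`:

```elixir
{:ok, result} =
  TflInterp.non_max_suppression_multi_class(
    YourApp.YourModel,
    {num_boxes, num_class},
    boxes,
    scores,
    0.5,    # iou_threshold
    0.25,   # score_threshold
    0.0     # sigma
  )
```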
"""
def non_max_suppression_multi_class(mod, {num_boxes, num_class}, boxes, scores,
      iou_threshold \\ 0.5, score_threshold \\ 0.25, sigma \\ 0.0) do
  cmd = 4

  case GenServer.call(mod,
         <<cmd::8, num_boxes::little-integer-32, num_class::little-integer-32,
           iou_threshold::little-float-32, score_threshold::little-float-32,
           sigma::little-float-32>> <> boxes <> scores,
         @timeout) do
    {:ok, result} -> Poison.decode(result)
    any -> any
  end
end
end