defmodule Mix.Tasks.DetectImage do
@moduledoc """
Image detection mix task: `mix help detect_image`
Command line arguments:
- `-m`, `--model`: *Required*. File path of .tflite file.
- `-i`, `--input`: *Required*. Image to process.
- `-l`, `--labels`: File path of labels file.
- `-t`, `--threshold`: Defaults to `0.4`. Score threshold for detected objects.
- `-c`, `--count`: Defaults to `1`. Number of times to run inference.
- `-j`, `--jobs`: Defaults to the number of online schedulers. Number of threads for the interpreter (only valid for CPU).
Code based on [detect_image.py](https://github.com/google-coral/pycoral/blob/master/examples/detect_image.py)
"""
use Mix.Task
alias TFLiteBEAM.Interpreter
alias TFLiteBEAM.InterpreterBuilder
alias TFLiteBEAM.TFLiteTensor
alias TFLiteBEAM.FlatBufferModel
@shortdoc "Object Detection"
@doc """
Runs object detection on one image and prints every detection whose score
reaches the threshold.

`argv` is the raw command-line argument list; the supported switches are
listed in the module documentation. Optional switches default to
`threshold: 0.4`, `count: 1` and `jobs: System.schedulers_online()`.

Raises `ArgumentError` when `--model` or `--input` is missing, when the model
does not expose exactly 4 output tensors, when the input tensor is not
`uint8`, or when the model declares more than one signature. Raises
`RuntimeError` when the input tensor shape has neither 3 nor 4 dimensions.
"""
@impl Mix.Task
def run(argv) do
  args = parse_args(argv)

  model = load_model(args[:model])
  input_image = %StbImage{shape: {h, w, _c}} = load_input(args[:input])
  labels = load_labels(args[:labels])

  interpreter = make_interpreter(model, args[:jobs])
  :ok = Interpreter.allocate_tensors(interpreter)

  [input_tensor_number | _] = Interpreter.inputs!(interpreter)
  output_tensor_numbers = Interpreter.outputs!(interpreter)

  if Enum.count(output_tensor_numbers) != 4 do
    raise ArgumentError, "Object detection models should have 4 output tensors"
  end

  %TFLiteTensor{} = input_tensor = Interpreter.tensor(interpreter, input_tensor_number)

  if input_tensor.type != {:u, 8} do
    raise ArgumentError, "Only support uint8 input type."
  end

  {height, width} =
    case input_tensor.shape do
      {_n, height, width, _c} ->
        {height, width}

      {_n, height, width} ->
        {height, width}

      shape ->
        raise RuntimeError, "not sure the input shape, got #{inspect(shape)}"
    end

  # Largest uniform scale at which the image still fits inside the model input.
  scale = min(height / h, width / w)
  {resized_h, resized_w} = {trunc(h * scale), trunc(w * scale)}

  resized_image =
    input_image
    |> StbImage.resize(resized_h, resized_w)
    |> StbImage.to_nx()
    |> Nx.new_axis(0)

  # Start from an all-zero tensor of the model's input shape and write the
  # resized image at the origin; whatever it does not cover stays zero.
  0
  |> Nx.broadcast(input_tensor.shape)
  |> Nx.as_type(:u8)
  |> Nx.put_slice([0, 0, 0, 0], resized_image)
  |> Nx.to_binary()
  |> then(&TFLiteTensor.set_data(input_tensor, &1))

  IO.puts("----INFERENCE TIME----")

  runs = args[:count]

  # The explicit `//1` step keeps the range empty when `runs` is 0 or negative;
  # a bare `1..0` would count downwards and invoke the model twice.
  for _ <- 1..runs//1 do
    # Monotonic time is unaffected by wall-clock adjustments during the run.
    start_time = System.monotonic_time(:microsecond)
    Interpreter.invoke!(interpreter)
    elapsed_ms = (System.monotonic_time(:microsecond) - start_time) / 1000.0
    IO.puts("#{Float.round(elapsed_ms, 1)}ms")
  end

  {count_tensor_id, scores_tensor_id, class_ids_tensor_id, boxes_tensor_id} =
    output_tensor_ids(interpreter, output_tensor_numbers)

  boxes = interpreter |> read_output(boxes_tensor_id) |> take_first_and_reshape()
  class_ids = interpreter |> read_output(class_ids_tensor_id) |> take_first_and_reshape()
  scores = interpreter |> read_output(scores_tensor_id) |> take_first_and_reshape()

  count =
    interpreter
    |> read_output(count_tensor_id)
    |> Nx.to_flat_list()
    |> hd()
    |> trunc()

  # x coordinates are scaled by the input width, y coordinates by the height.
  {sx, sy} = {width / scale, height / scale}

  # `//1` makes the range empty when nothing was detected; a bare `0..-1`
  # would visit indices 0 and -1.
  Enum.each(0..(count - 1)//1, fn index ->
    score =
      scores
      |> Nx.take(index)
      |> Nx.to_flat_list()
      |> hd()

    if score >= args[:threshold] do
      # NOTE(review): the box is multiplied by `scale` here and by
      # `dim / scale` just below, so the two factors cancel and the printed
      # box is in model-input pixels. The referenced detect_image.py reports
      # original-image pixels instead — confirm which one is intended.
      [ymin, xmin, ymax, xmax] =
        boxes
        |> Nx.take(index)
        |> Nx.multiply(scale)
        |> Nx.to_flat_list()

      {xmin, xmax} = {trunc(sx * xmin), trunc(sx * xmax)}
      {ymin, ymax} = {trunc(sy * ymin), trunc(sy * ymax)}

      class_id =
        class_ids
        |> Nx.take(index)
        |> Nx.to_flat_list()
        |> hd()
        |> trunc()

      # Fall back to the numeric id when there is no labels file or the file
      # has no entry at this index.
      label = labels && Enum.at(labels, class_id)
      class_str = label || class_id

      IO.puts("#{class_str}")
      IO.puts(" id : #{class_id}")
      IO.puts(" score: #{Float.round(score, 3)}")
      IO.puts(" bbox : #{inspect([ymin, xmin, ymax, xmax])}")
    end
  end)
end

# Parses the command-line switches and fills in defaults for the optional ones.
defp parse_args(argv) do
  {parsed, _rest, _invalid} =
    OptionParser.parse(argv,
      strict: [
        model: :string,
        input: :string,
        labels: :string,
        threshold: :float,
        count: :integer,
        jobs: :integer
      ],
      aliases: [
        m: :model,
        i: :input,
        l: :labels,
        t: :threshold,
        c: :count,
        j: :jobs
      ]
    )

  # Values from the second list win, so user-supplied switches override defaults.
  Keyword.merge([threshold: 0.4, count: 1, jobs: System.schedulers_online()], parsed)
end

# Resolves the `{count, scores, class_ids, boxes}` output tensor ids, preferring
# the model's single signature and falling back to the output-shape heuristic.
defp output_tensor_ids(interpreter, output_tensor_numbers) do
  signatures =
    case Interpreter.get_signature_defs!(interpreter) do
      nil -> []
      signature_defs -> Map.values(signature_defs)
    end

  case signatures do
    [] ->
      output_tensor_ids_by_shape(interpreter, output_tensor_numbers)

    [signature] ->
      # NOTE(review): assumes the signature and its outputs are keyed by
      # atoms, as the original lookups were written — confirm against
      # TFLiteBEAM's `get_signature_defs!/1`.
      outputs = signature[:outputs]
      {outputs[:output_0], outputs[:output_1], outputs[:output_2], outputs[:output_3]}

    _several ->
      raise ArgumentError, "Only support model with one signature."
  end
end

# Picks the `{count, scores, class_ids, boxes}` ids from the raw output list,
# choosing between two layouts by whether the 4th output has shape `{1}`.
defp output_tensor_ids_by_shape(interpreter, output_tensor_numbers) do
  [out0, out1, out2, out3] = Enum.to_list(output_tensor_numbers)
  %TFLiteTensor{} = last_output = Interpreter.tensor(interpreter, out3)

  if last_output.shape == {1} do
    # Output order: boxes, class_ids, scores, count.
    {out3, out2, out1, out0}
  else
    # Output order: scores, boxes, count, class_ids.
    {out2, out0, out3, out1}
  end
end

# Fetches output tensor `tensor_id` and converts it to an Nx tensor on the
# binary backend.
defp read_output(interpreter, tensor_id) do
  interpreter
  |> Interpreter.tensor(tensor_id)
  |> TFLiteTensor.to_nx(backend: Nx.BinaryBackend)
end
# Builds a FlatBuffer model from the bytes of the file at `model_path`.
# A `nil` path means `--model` was not given and raises `ArgumentError`.
defp load_model(nil), do: raise(ArgumentError, "empty value for argument '--model'")

defp load_model(model_path) do
  model_bytes = File.read!(model_path)
  FlatBufferModel.build_from_buffer(model_bytes)
end
# Reads the image at `input_path` with StbImage and returns it.
# A `nil` path means `--input` was not given and raises `ArgumentError`;
# a read failure raises `RuntimeError` carrying the reported reason.
defp load_input(nil) do
  raise ArgumentError, "empty value for argument '--input'"
end

defp load_input(input_path) do
  case StbImage.read_file(input_path) do
    {:ok, input_image} ->
      input_image

    {:error, reason} ->
      # `raise RuntimeError, term` only accepts a binary message (or a keyword
      # list); any other reason would fail with a FunctionClauseError and hide
      # the real cause, so non-binary reasons are inspected first.
      message = if is_binary(reason), do: reason, else: inspect(reason)
      raise RuntimeError, message
  end
end
# Loads class labels, one per line, from `label_file_path`.
# Returns `nil` when no labels file was given. Otherwise returns the list of
# lines, where the list index is used as the class id by the caller.
defp load_labels(nil), do: nil

defp load_labels(label_file_path) do
  label_file_path
  |> File.read!()
  # Accept both LF and CRLF line endings so labels from files with Windows
  # line endings do not carry a trailing "\r". Empty lines are kept so that
  # positions (class ids) are not shifted.
  |> String.split(["\r\n", "\n"])
end
# Creates an interpreter for `model` using the builtin op resolver and returns it.
# `num_jobs` is applied as the thread count on both the builder (before the
# build) and the interpreter (after the build).
defp make_interpreter(model, num_jobs) do
  op_resolver = TFLiteBEAM.Ops.Builtin.BuiltinResolver.new!()
  model_builder = InterpreterBuilder.new!(model, op_resolver)

  Interpreter.new!()
  |> tap(fn new_interpreter ->
    InterpreterBuilder.set_num_threads!(model_builder, num_jobs)
    :ok = InterpreterBuilder.build!(model_builder, new_interpreter)
    Interpreter.set_num_threads!(new_interpreter, num_jobs)
  end)
end
# Selects index 0 along the first axis of `tensor` and reshapes the result to
# the tensor's original shape with that leading dimension removed.
defp take_first_and_reshape(tensor) do
  trailing_dims =
    tensor
    |> Nx.shape()
    |> Tuple.delete_at(0)

  first_entry = Nx.take(tensor, 0)
  Nx.reshape(first_entry, trailing_dims)
end
end